vllm - ✅(Solved) Fix [Bug]: FlashInfer TRTLLM monolithic MoE produces 0% accuracy for Qwen3.5-35B/122B FP8 [3 pull requests, 3 comments, 2 participants]

vllm-project/vllm#37591 (fetched 2026-04-08 01:04:31)

The flashinfer_trtllm MoE backend selects the monolithic kernel (trtllm_fp8_block_scale_moe) for models using Renormalize routing. This kernel has a bug where all-negative router logits cause incorrect expert routing, producing 0% GSM8K accuracy for Qwen3.5-35B-A3B-FP8 and Qwen3.5-122B-A10B-FP8 with any DP+EP configuration.

Upstream FlashInfer issue: https://github.com/flashinfer-ai/flashinfer/issues/2822
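Why "all-negative logits" should not matter: softmax is invariant to adding the same constant to every logit, so shifting all router logits below zero must not change which experts Renormalize routing selects or their weights. The sketch below is a plain-Python reference of softmax, top-k, then renormalize (the float32 path the modular backend is described as using). It is an illustration, not FlashInfer's kernel code; the logits are made up.

```python
import math


def renormalize_route(logits, top_k):
    """Reference Renormalize routing: softmax -> top-k -> renormalize."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices of the top_k largest probabilities (ties broken by lower index).
    ids = sorted(range(len(probs)), key=lambda i: (-probs[i], i))[:top_k]
    picked = [probs[i] for i in ids]
    s = sum(picked)
    # Renormalize so the selected experts' weights sum to 1.
    return ids, [p / s for p in picked]


logits = [0.7, -1.2, 2.5, 0.1, -3.0, 1.9, 0.0, -0.4]
shifted = [x - 10.0 for x in logits]  # every logit is now negative

ids_a, w_a = renormalize_route(logits, top_k=2)
ids_b, w_b = renormalize_route(shifted, top_k=2)
print(ids_a, ids_b)  # [2, 5] [2, 5]
```

A correct kernel should therefore produce the same experts for both inputs; the bug described here is that the monolithic kernel did not.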

Fix Action

Workaround

Remove Renormalize / RenormalizeNaive from TrtLlmFp8ExpertsMonolithic._supports_routing_method() in vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py. Backend selection then falls through to the modular variant (TrtLlmFp8ExpertsModular), which performs routing in Python and is not affected. With this workaround, accuracy is restored to 75-80%.
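The fall-through can be pictured as "the first experts class that claims the routing method wins". The sketch below is a minimal, hypothetical model of that selection: `RoutingMethod`, `MonolithicExperts`, `ModularExperts`, and `select_experts` are illustrative names, not vLLM's actual classes or selection code, and the assumption that the modular class accepts every method is made only for the example.

```python
from enum import Enum, auto


class RoutingMethod(Enum):  # hypothetical stand-in for vLLM's routing-method enum
    RENORMALIZE = auto()
    RENORMALIZE_NAIVE = auto()
    OTHER = auto()  # placeholder for any other routing method


class MonolithicExperts:  # stands in for TrtLlmFp8ExpertsMonolithic
    @staticmethod
    def supports_routing_method(method):
        # After the workaround, the Renormalize variants are no longer claimed.
        return method not in (RoutingMethod.RENORMALIZE,
                              RoutingMethod.RENORMALIZE_NAIVE)


class ModularExperts:  # stands in for TrtLlmFp8ExpertsModular
    @staticmethod
    def supports_routing_method(method):
        # Assumption for this sketch: routing happens in Python, so any method works.
        return True


def select_experts(method, candidates=(MonolithicExperts, ModularExperts)):
    """Return the first candidate class that supports the routing method."""
    for cls in candidates:
        if cls.supports_routing_method(method):
            return cls
    raise ValueError(f"no experts class supports {method}")
```

With this ordering, `select_experts(RoutingMethod.RENORMALIZE)` returns `ModularExperts`, while other methods still use the monolithic path.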

PR fix notes

PR #37605: [Bugfix] Disable monolithic TRTLLM MoE for Renormalize routing (#37591)

Description (problem / solution / changelog)

Summary

The FlashInfer TRTLLM monolithic MoE kernel (trtllm_fp8_block_scale_moe) produces incorrect expert routing when all router logits are negative. This causes 0% GSM8K accuracy for Qwen3.5-35B-A3B-FP8 and Qwen3.5-122B-A10B-FP8 with DP+EP.

This PR removes Renormalize and RenormalizeNaive from TrtLlmFp8ExpertsMonolithic._supports_routing_method(), causing the backend selection to fall through to TrtLlmFp8ExpertsModular. The modular variant performs routing in Python (float32 softmax) and calls the routed kernel, which is not affected by the bug.
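In the reproduction scripts further down, the Python-side routing result is handed to the routed kernel as one int32 per (token, expert) slot: the expert index in the high 16 bits and the bfloat16 bit pattern of its weight in the low 16 bits (`(ti << 16) | bf16_bits`). The sketch below reproduces that layout in plain Python, assuming round-to-nearest-even for the float-to-bfloat16 conversion; it does not handle NaN and is not FlashInfer code.

```python
import struct


def float_to_bf16_bits(x):
    """Round a float to bfloat16 (round-to-nearest-even) and return its 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    return (bits >> 16) & 0xFFFF


def bf16_bits_to_float(b):
    # bfloat16 is the top half of a float32, so pad the low 16 bits with zeros.
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def pack_topk(expert_id, weight):
    # Expert index in the high 16 bits, bf16 weight bits in the low 16 bits.
    return (expert_id << 16) | float_to_bf16_bits(weight)


def unpack_topk(packed):
    return packed >> 16, bf16_bits_to_float(packed & 0xFFFF)
```

For example, `pack_topk(3, 0.5)` gives `0x33F00`. The FP8 script ORs the int16 view without masking while the BF16 script masks with `0xFFFF`; because renormalized weights are non-negative, the bfloat16 sign bit is 0 and both forms yield the same value.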

Fixes #37591, https://github.com/flashinfer-ai/flashinfer/issues/2822

Test plan

  • python -m pytest test_trtllm_ep_bug.py -v -s -k "shift" reproduces the kernel routing divergence (cosine=0.24 with all-negative logits)
  • GSM8K eval: Qwen3.5-35B-A3B-FP8 DP2+EP accuracy restored from 0% to 86%
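The cosine figure in this test plan compares the flattened outputs of the monolithic and modular kernels, and the reproduction scripts treat anything at or below 0.9 as divergence. A plain-Python sketch of that check follows; the zero-output guard mirrors `_run_both_kernels`, while the `1e-8` denominator floor only approximates how `torch.nn.functional.cosine_similarity` handles zero norms.

```python
import math


def output_cosine(mono, mod, eps=1e-6):
    """Cosine similarity of two flattened kernel outputs."""
    nm = math.sqrt(sum(x * x for x in mono))
    nr = math.sqrt(sum(x * x for x in mod))
    if nm < eps and nr < eps:
        return 1.0  # both outputs are (near) zero: treat as agreeing
    dot = sum(a * b for a, b in zip(mono, mod))
    return dot / max(nm * nr, 1e-8)  # floor avoids dividing by zero


def kernels_agree(mono, mod, threshold=0.9):
    # Same pass criterion as the scripts: cosine must exceed 0.9.
    return output_cosine(mono, mod) > threshold
```

Under this criterion, the reported cosine of 0.24 with all-negative logits is a clear failure.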

Changed files

  • .buildkite/test_areas/lm_eval.yaml (modified, +16/-0)
  • tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-DEP2.yaml (added, +8/-0)
  • tests/evals/gsm8k/configs/Qwen3.5-35B-A3B-FP8-DEP2.yaml (added, +9/-0)
  • tests/evals/gsm8k/configs/models-qwen35-blackwell.txt (added, +1/-0)
  • vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py (modified, +8/-5)

PR #36838: enable flashinfer moe kernel for DP + EP

Description (problem / solution / changelog)

Purpose

Previously, the BF16 FlashInfer MoE kernel was disabled when dp > 1. The kernel itself should be able to support this configuration; it only needs to be enabled on the vLLM side. This PR also adds a test to verify that the kernel selection logic works as intended.

Test Plan

Run pytest tests/kernels/moe/test_unquantized_backend_selection.py, then run GSM8K with BF16 Qwen3-30B-A3B on 2x B200 (DP2, EP2) and compare the results across MoE backends.

Server commands

# triton/default backend
vllm serve Qwen/Qwen3-30B-A3B \
  --data-parallel-size 2 \
  --enable-expert-parallel \
  --trust-remote-code \
  --port 8000

# flashinfer cutlass
VLLM_USE_FLASHINFER_MOE_FP16=1 VLLM_FLASHINFER_MOE_BACKEND=throughput vllm serve Qwen/Qwen3-30B-A3B \
  --data-parallel-size 2 \
  --enable-expert-parallel \
  --trust-remote-code \
  --port 8000

# flashinfer trtllm
VLLM_USE_FLASHINFER_MOE_FP16=1 VLLM_FLASHINFER_MOE_BACKEND=latency vllm serve Qwen/Qwen3-30B-A3B \
  --data-parallel-size 2 \
  --enable-expert-parallel \
  --trust-remote-code \
  --port 8000
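The three serve commands differ only in two environment variables. If you want to script the comparison, a small helper can generate each command line from one table of overrides; `BASE_ARGS`, `BACKENDS`, and `serve_command` are hypothetical names, and the values are copied from the commands above rather than from vLLM's documentation.

```python
import shlex

# Arguments shared by all three runs (copied from the commands above).
BASE_ARGS = ["vllm", "serve", "Qwen/Qwen3-30B-A3B",
             "--data-parallel-size", "2", "--enable-expert-parallel",
             "--trust-remote-code", "--port", "8000"]

# Per-backend environment overrides (copied from the commands above).
BACKENDS = {
    "triton": {},
    "flashinfer_cutlass": {"VLLM_USE_FLASHINFER_MOE_FP16": "1",
                           "VLLM_FLASHINFER_MOE_BACKEND": "throughput"},
    "flashinfer_trtllm": {"VLLM_USE_FLASHINFER_MOE_FP16": "1",
                          "VLLM_FLASHINFER_MOE_BACKEND": "latency"},
}


def serve_command(backend):
    """Build the shell command line for one backend."""
    env = " ".join(f"{k}={v}" for k, v in BACKENDS[backend].items())
    cmd = shlex.join(BASE_ARGS)  # quotes arguments only if needed
    return f"{env} {cmd}" if env else cmd
```

To launch a run from Python instead of a shell string, pass `BASE_ARGS` to `subprocess.run` with `env={**os.environ, **BACKENDS[name]}`.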

Eval command

python -m lm_eval \
  --model local-completions \
  --model_args "model=Qwen/Qwen3-30B-A3B,base_url=http://localhost:8000/v1/completions,num_concurrent=128,max_retries=5,tokenized_requests=False,tokenizer=Qwen/Qwen3-30B-A3B" \
  --tasks gsm8k_cot \
  --batch_size auto \
  --log_samples \
  --output_path /tmp/lm_eval_qwen_dp2_ep

Test Result

| Backend | gsm8k_cot (flexible-extract) | Stderr |
|---|---|---|
| Triton (default) | 0.8870 | ±0.0087 |
| FlashInfer CUTLASS (throughput) | 0.8992 | ±0.0083 |
| FlashInfer TRTLLM (latency) | 0.9007 | ±0.0082 |

pytest tests/kernels/moe/test_unquantized_backend_selection.py passes.
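A quick way to read the accuracy results: divide each pairwise accuracy gap by the combined standard error. The sketch below uses the reported numbers and, as a simplification, treats the runs as independent (they score the same GSM8K questions, so a paired test would be more precise). Every pairwise z-score comes out below 2, which suggests the three backends are statistically indistinguishable on this eval.

```python
import math
from itertools import combinations

# (accuracy, stderr) for gsm8k_cot flexible-extract, as reported above.
RESULTS = {
    "Triton (default)": (0.8870, 0.0087),
    "FlashInfer CUTLASS (throughput)": (0.8992, 0.0083),
    "FlashInfer TRTLLM (latency)": (0.9007, 0.0082),
}


def z_score(a, b):
    """Accuracy gap divided by the root-sum-square of the two stderrs."""
    (acc_a, se_a), (acc_b, se_b) = a, b
    return (acc_a - acc_b) / math.hypot(se_a, se_b)


for x, y in combinations(RESULTS, 2):
    print(f"{y} vs {x}: z = {z_score(RESULTS[y], RESULTS[x]):.2f}")
```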


<details> <summary> Essential Elements of an Effective PR Description Checklist </summary>
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.
</details>

Changed files

  • tests/kernels/moe/test_unquantized_backend_selection.py (modified, +88/-2)
  • vllm/model_executor/layers/fused_moe/oracle/unquantized.py (modified, +1/-7)
  • vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py (modified, +8/-0)

PR #38859: [Bugfix] Re-enable Renormalize routing for TRT-LLM MoE experts

Description (problem / solution / changelog)

Purpose

Re-enable Renormalize and RenormalizeNaive routing for TRT-LLM MoE experts (BF16 and FP8).

These were disabled for #37591 (in PR #37605) because the monolithic kernel's internal Renormalize routing produced output uncorrelated with the modular/Triton kernel for Qwen3.5 models. The root cause was a FlashInfer bug (https://github.com/flashinfer-ai/flashinfer/issues/2822), fixed in FlashInfer 0.6.7.

Test Plan

Verified with reproduction scripts that compare monolithic vs modular kernel output (cosine similarity) under all-negative router logits — the condition that triggers the bug in Qwen3.5 models.

<details> <summary>test_trtllm_ep_bug_fp8.py (from flashinfer-ai/flashinfer#2822)</summary>
"""
Reproduce: trtllm_fp8_block_scale_moe monolithic kernel produces incorrect
routing when all router logits are negative (flat softmax distribution).

The monolithic kernel's internal Renormalize routing diverges from the
modular kernel (Python f32 routing) when all 256 expert logits are negative.
This occurs naturally in Qwen3.5 models where the gate layer produces
all-negative logits.

Trigger: when max(router_logits) < 0, the kernel's internal softmax
routing selects different top-8 experts than Python's float32 softmax,
producing uncorrelated MoE output (cosine ~0.3).
"""
import pytest
import torch


@pytest.fixture
def model_weights():
    """Load real Qwen3.5-35B-A3B-FP8 layer-0 EP0 weights if available,
    otherwise use random weights (bug still reproduces)."""
    from pathlib import Path
    dump = Path(__file__).parent / "kernel_dump_ep0.pt"
    if dump.exists():
        d = torch.load(str(dump), map_location="cpu", weights_only=False)
        return (d["w1"].cuda(), d["w1s"].cuda(),
                d["w2"].cuda(), d["w2s"].cuda())

    torch.manual_seed(0)
    device = "cuda"
    ne, inter, hid = 128, 512, 2048
    w1 = torch.randn(ne, 2*inter, hid, dtype=torch.bfloat16, device=device
                      ).to(torch.float8_e4m3fn)
    w1s = torch.ones(ne, 2*inter//128, hid//128, dtype=torch.float32,
                      device=device)
    w2 = torch.randn(ne, hid, inter, dtype=torch.bfloat16, device=device
                      ).to(torch.float8_e4m3fn)
    w2s = torch.ones(ne, hid//128, inter//128, dtype=torch.float32,
                      device=device)
    return w1, w1s, w2, w2s


def _run_both_kernels(logits, hidden_fp8, scale, w1, w1s, w2, w2s,
                      num_experts=256, local_ne=128, offset=0, top_k=8):
    import flashinfer
    from flashinfer.fused_moe import Fp8QuantizationType

    mono = flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
        routing_logits=logits, routing_bias=None,
        hidden_states=hidden_fp8, hidden_states_scale=scale,
        gemm1_weights=w1, gemm1_weights_scale=w1s,
        gemm2_weights=w2, gemm2_weights_scale=w2s,
        num_experts=num_experts, top_k=top_k,
        n_group=0, topk_group=0,
        intermediate_size=w2.shape[2],
        local_expert_offset=offset, local_num_experts=local_ne,
        routed_scaling_factor=1.0, routing_method_type=1,
        use_shuffled_weight=False,
        fp8_quantization_type=Fp8QuantizationType.DeepSeekFp8,
    )

    scores = torch.softmax(logits.float(), dim=-1)
    tw, ti = torch.topk(scores, k=top_k, dim=-1)
    tw = tw / tw.sum(dim=-1, keepdim=True)
    packed = (ti.int() << 16) | tw.to(torch.bfloat16).view(torch.int16)

    mod = flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe(
        topk_ids=packed, routing_bias=None,
        hidden_states=hidden_fp8, hidden_states_scale=scale,
        gemm1_weights=w1, gemm1_weights_scale=w1s,
        gemm2_weights=w2, gemm2_weights_scale=w2s,
        num_experts=num_experts, top_k=top_k,
        n_group=None, topk_group=None,
        intermediate_size=w2.shape[2],
        local_expert_offset=offset, local_num_experts=local_ne,
        routed_scaling_factor=None, routing_method_type=1,
        use_shuffled_weight=False, weight_layout=0,
        fp8_quantization_type=Fp8QuantizationType.DeepSeekFp8,
    )

    mf, rf = mono.float(), mod.float()
    if mf.norm() < 1e-6 and rf.norm() < 1e-6:
        return 1.0
    return torch.nn.functional.cosine_similarity(
        mf.reshape(1, -1), rf.reshape(1, -1)).item()


@pytest.mark.parametrize("logit_shift", [0.0, -6.0, -10.0])
def test_all_negative_logits_trigger_routing_divergence(model_weights,
                                                         logit_shift):
    """
    When all router logits are negative (shifted below zero), the monolithic
    kernel's Renormalize routing diverges from Python's float32 routing.

    shift=0:   logits in [-8, +9] -> cosine ~1.0 (OK)
    shift=-6:  logits in [-14, +3] -> cosine ~1.0 (OK, still some positive)
    shift=-10: logits in [-18, -1] -> cosine ~0.3 (BUG: all negative)

    FAILS for shift=-10 because cosine drops below 0.9.
    """
    w1, w1s, w2, w2s = model_weights
    device = "cuda"
    torch.manual_seed(42)

    logits = (torch.randn(4, 256, dtype=torch.bfloat16, device=device) * 3.0
              + logit_shift)
    hidden = torch.randn(4, 2048, dtype=torch.bfloat16, device=device
                         ).to(torch.float8_e4m3fn)
    scale = torch.ones(16, 4, dtype=torch.float32, device=device) * 0.01

    cos = _run_both_kernels(logits, hidden, scale, w1, w1s, w2, w2s)

    assert cos > 0.9, (
        f"Monolithic kernel routing diverges from modular (cosine={cos:.4f}) "
        f"when logits are shifted by {logit_shift} "
        f"(range [{logits.float().min():.1f}, {logits.float().max():.1f}]). "
        f"The kernel's internal Renormalize routing fails with all-negative "
        f"logits, producing incorrect expert selections."
    )
</details> <details> <summary>test_trtllm_ep_bug_bf16.py</summary>
"""
Reproduce: trtllm_bf16_moe monolithic kernel produces incorrect
routing when all router logits are negative (flat softmax distribution).

The monolithic kernel's internal Renormalize routing diverges from the
modular kernel (Python f32 routing) when all 256 expert logits are negative.
This occurs naturally in Qwen3.5 models where the gate layer produces
all-negative logits.

Trigger: when max(router_logits) < 0, the kernel's internal softmax
routing selects different top-8 experts than Python's float32 softmax,
producing uncorrelated MoE output (cosine ~0.3).
"""
import pytest
import torch


def _convert_to_block_layout(w13: torch.Tensor, w2: torch.Tensor):
    """Convert BF16 expert weights to the BlockMajorK layout required by the
    trtllm_bf16_moe kernel.

    Copied from flashinfer.fused_moe.core (private helpers) and
    vllm/.../flashinfer_utils.py:convert_moe_weights_to_flashinfer_trtllm_block_layout.
    """
    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w3_w1_permute_indices,
        convert_to_block_layout,
        get_w2_permute_indices_with_cache,
    )

    epilogue_tile_m = 128
    block_k = 128
    ne = w13.shape[0]
    cache: dict[torch.Size, torch.Tensor] = {}

    w13_out, w2_out = [], []
    for i in range(ne):
        # Steps 1+2 for W13: gated-act interleave then epilogue-tile shuffle
        pi = _maybe_get_cached_w3_w1_permute_indices(
            cache, w13[i].view(torch.uint8), epilogue_tile_m)
        tmp1 = w13[i].clone().view(torch.uint8)[pi.to(w13.device)].contiguous()

        # Step 2 for W2: epilogue-tile shuffle only (no gated-act reorder)
        pi2 = get_w2_permute_indices_with_cache(
            cache, w2[i].view(torch.uint8), epilogue_tile_m)
        tmp2 = w2[i].clone().view(torch.uint8)[pi2.to(w2.device)].contiguous()

        # Step 3: [M, K] → [K/128, M, 128] block tiling
        w13_out.append(
            convert_to_block_layout(tmp1, block_k).view(torch.bfloat16))
        w2_out.append(
            convert_to_block_layout(tmp2, block_k).view(torch.bfloat16))

    return (torch.stack(w13_out).contiguous(),
            torch.stack(w2_out).contiguous())


@pytest.fixture
def model_weights():
    """Generate random BF16 weights and convert to BlockMajorK layout."""
    torch.manual_seed(0)
    device = "cuda"
    ne, inter, hid = 128, 512, 2048
    w1 = torch.randn(ne, 2 * inter, hid, dtype=torch.bfloat16, device=device)
    w2 = torch.randn(ne, hid, inter, dtype=torch.bfloat16, device=device)
    w1, w2 = _convert_to_block_layout(w1, w2)
    return w1, w2, inter


def _run_both_kernels(logits, hidden, w1, w2, intermediate_size,
                      num_experts=256, local_ne=128, offset=0, top_k=8):
    import flashinfer

    mono = flashinfer.fused_moe.trtllm_bf16_moe(
        routing_logits=logits,
        routing_bias=None,
        hidden_states=hidden,
        gemm1_weights=w1,
        gemm2_weights=w2,
        num_experts=num_experts,
        top_k=top_k,
        n_group=0,
        topk_group=0,
        intermediate_size=intermediate_size,
        local_expert_offset=offset,
        local_num_experts=local_ne,
        routed_scaling_factor=1.0,
        routing_method_type=1,
        tune_max_num_tokens=4,
    )

    scores = torch.softmax(logits.float(), dim=-1)
    tw, ti = torch.topk(scores, k=top_k, dim=-1)
    tw = tw / tw.sum(dim=-1, keepdim=True)
    packed = (ti.to(torch.int32) << 16) | (
        tw.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF
    )
    mod = flashinfer.fused_moe.trtllm_bf16_routed_moe(
        topk_ids=packed,
        hidden_states=hidden,
        gemm1_weights=w1,
        gemm2_weights=w2,
        num_experts=num_experts,
        top_k=top_k,
        n_group=None,
        topk_group=None,
        intermediate_size=intermediate_size,
        local_expert_offset=offset,
        local_num_experts=local_ne,
        routed_scaling_factor=None,
        routing_method_type=1,
        tune_max_num_tokens=4,
    )

    mf, rf = mono.float(), mod.float()
    if mf.norm() < 1e-6 and rf.norm() < 1e-6:
        return 1.0
    return torch.nn.functional.cosine_similarity(
        mf.reshape(1, -1), rf.reshape(1, -1)).item()


@pytest.mark.parametrize("logit_shift", [0.0, -6.0, -10.0])
def test_all_negative_logits_trigger_routing_divergence(model_weights,
                                                         logit_shift):
    """
    When all router logits are negative (shifted below zero), the monolithic
    kernel's Renormalize routing diverges from Python's float32 routing.

    shift=0:   logits in [-8, +9] -> cosine ~1.0 (OK)
    shift=-6:  logits in [-14, +3] -> cosine ~1.0 (OK, still some positive)
    shift=-10: logits in [-18, -1] -> cosine ~0.3 (BUG: all negative)

    FAILS for shift=-10 because cosine drops below 0.9.
    """
    w1, w2, inter = model_weights
    device = "cuda"
    torch.manual_seed(42)

    logits = (torch.randn(4, 256, dtype=torch.bfloat16, device=device) * 3.0
              + logit_shift)
    hidden = torch.randn(4, 2048, dtype=torch.bfloat16, device=device)

    cos = _run_both_kernels(logits, hidden, w1, w2, inter)

    assert cos > 0.9, (
        f"Monolithic kernel routing diverges from modular (cosine={cos:.4f}) "
        f"when logits are shifted by {logit_shift} "
        f"(range [{logits.float().min():.1f}, {logits.float().max():.1f}]). "
        f"The kernel's internal Renormalize routing fails with all-negative "
        f"logits, producing incorrect expert selections."
    )
</details>

Test Result

Both scripts fail on FlashInfer 0.6.6 (flashinfer-python 0.6.6 and flashinfer-jit-cache 0.6.6+cu13) and pass on FlashInfer 0.6.7+cu130; they were also verified to pass with 0.6.7+cu128.

The GSM8K configs Qwen3.5-35B-A3B-DEP2.yaml and Qwen3.5-35B-A3B-FP8-DEP2.yaml also pass with the TRTLLM backend.
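If a deployment may still run FlashInfer 0.6.6, one option is to keep the workaround unless the installed FlashInfer is at least 0.6.7. The guard below is a hypothetical sketch, not part of this PR: the names are illustrative, the version string would come from something like `importlib.metadata.version("flashinfer-python")`, and the minimal parser ignores local suffixes such as `+cu130` but does not implement full PEP 440 ordering (pre-releases like `0.6.7rc1` are treated as `0.6.7`). `packaging.version.Version` is the more robust choice when that package is available.

```python
FIXED_IN = (0, 6, 7)  # first FlashInfer release reported to pass the repro scripts


def release_tuple(version):
    """'0.6.7+cu130' -> (0, 6, 7); the local '+...' suffix is dropped."""
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        digits = ""
        for ch in piece:  # keep only the leading digits of each component
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def monolithic_renormalize_is_safe(flashinfer_version):
    # Hypothetical guard: only use monolithic Renormalize routing on fixed releases.
    return release_tuple(flashinfer_version) >= FIXED_IN
```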


<details> <summary> Essential Elements of an Effective PR Description Checklist </summary>
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.
</details>

Changed files

  • vllm/model_executor/layers/fused_moe/experts/trtllm_bf16_moe.py (modified, +2/-5)
  • vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py (modified, +5/-7)
Fix Plan

To resolve the issue with incorrect expert routing in the flashinfer_trtllm MoE backend, we will modify the TrtLlmFp8ExpertsMonolithic class.

Steps

  • Modify the _supports_routing_method() function in trtllm_fp8_moe.py to exclude Renormalize and RenormalizeNaive routing methods.
  • This change will cause the backend to select the modular variant (TrtLlmFp8ExpertsModular) for models using these routing methods.

Example Code

class TrtLlmFp8ExpertsMonolithic:
    # ... (other members unchanged)

    def _supports_routing_method(self, routing_method):
        # Renormalize / RenormalizeNaive are no longer claimed by the monolithic
        # kernel, so backend selection falls through to TrtLlmFp8ExpertsModular.
        # String names are illustrative; the real code may compare enum members.
        if routing_method in ("Renormalize", "RenormalizeNaive"):
            return False
        # ... existing code ...

Equivalently, filter the existing list of supported methods:

class TrtLlmFp8ExpertsMonolithic:
    # ...

    def _supports_routing_method(self, routing_method):
        supported_methods = [...]  # existing supported methods
        excluded = ("Renormalize", "RenormalizeNaive")
        return routing_method in supported_methods and routing_method not in excluded

Verification

After applying the fix, verify that GSM8K accuracy for the affected models (Qwen3.5-35B-A3B-FP8 and Qwen3.5-122B-A10B-FP8) recovers under DP+EP: the issue reports 75-80% with the workaround, and PR #37605 reports 86% for Qwen3.5-35B-A3B-FP8 with DP2+EP.
