vllm - ✅(Solved) Fix [Bug]: GPT OSS Router GEMM Causing NaNs [1 pull requests, 1 participants]

GitHub stats
vllm-project/vllm#38754 · Fetched 2026-04-08 02:22:53
Comments: 0
Participants: 1
Timeline: 2
Reactions: 0
Timeline (top): cross-referenced ×1, labeled ×1

PR fix notes

PR #37205: [Kernel] Add gpt-oss Router GEMM kernel

Description (problem / solution / changelog)

Purpose

This PR adds a Router GEMM kernel optimized for gpt-oss.

It improves output token throughput by 1% to 2% at batch size 1.

Test Plan

Added unit test.

pytest -s -v tests/kernels/moe/test_router_gemm.py

Test Result

Unit test passed.

Micro bench

python3 benchmarks/kernels/benchmark_router_gemm.py --model openai/gpt-oss-20b --max-batch-size 8192
openai/gpt-oss-20b router gemm throughput:
    batch_size  PyTorch (TFLOPs)  vLLM (TFLOPs)
0          1.0          0.031421       0.089570
1          2.0          0.064126       0.178294
2          4.0          0.128316       0.354806
3          8.0          0.255723       0.707328
4         16.0          0.489696       1.371641
5         32.0          0.968925       2.821921
6         64.0          1.911852       5.565304
7        128.0          3.720499      11.084153
8        256.0          7.380550      20.414985
9        512.0         14.636453      32.505290
10      1024.0         29.070115      33.019194
11      2048.0         56.395810      33.173499
12      4096.0        120.938349      34.309334
13      8192.0        152.726427      32.962142
python3 benchmarks/kernels/benchmark_router_gemm.py --model openai/gpt-oss-120b --max-batch-size 8192
openai/gpt-oss-120b router gemm throughput:
    batch_size  PyTorch (TFLOPs)  vLLM (TFLOPs)
0          1.0          0.123305       0.355313
1          2.0          0.254834       0.705630
2          4.0          0.505317       1.404891
3          8.0          1.004569       2.794942
4         16.0          1.918626       5.537719
5         32.0          3.781321      10.981122
6         64.0          7.482728      19.906920
7        128.0         14.560503      32.026677
8        256.0         28.719212      32.120986
9        512.0         56.351540      32.404140
10      1024.0        112.437718      32.985149
11      2048.0        201.149820      35.144443
12      4096.0        402.626180      37.998635
13      8192.0        452.585227      36.484008

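The gpt_oss_router_gemm kernel has higher throughput at low batch sizes: in the tables above it beats PyTorch up to batch size 1024 for gpt-oss-20b and up to batch size 256 for gpt-oss-120b. At larger batch sizes PyTorch is faster.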
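To make the crossover point concrete, the following sketch computes the vLLM-to-PyTorch throughput ratio from the gpt-oss-20b micro-benchmark table. The numbers are copied from the output above; the names `ROWS_20B`, `speedups`, and `largest_winning_batch` are illustrative only and are not part of the PR.

```python
# (batch_size, pytorch_tflops, vllm_tflops), copied from the gpt-oss-20b table above.
ROWS_20B = [
    (1, 0.031421, 0.089570),
    (2, 0.064126, 0.178294),
    (4, 0.128316, 0.354806),
    (8, 0.255723, 0.707328),
    (16, 0.489696, 1.371641),
    (32, 0.968925, 2.821921),
    (64, 1.911852, 5.565304),
    (128, 3.720499, 11.084153),
    (256, 7.380550, 20.414985),
    (512, 14.636453, 32.505290),
    (1024, 29.070115, 33.019194),
    (2048, 56.395810, 33.173499),
    (4096, 120.938349, 34.309334),
    (8192, 152.726427, 32.962142),
]


def speedups(rows):
    """Return {batch_size: vllm_tflops / pytorch_tflops}."""
    return {bs: vllm / pytorch for bs, pytorch, vllm in rows}


def largest_winning_batch(rows):
    """Largest batch size at which the custom kernel is still faster, or None."""
    winners = [bs for bs, ratio in speedups(rows).items() if ratio > 1.0]
    return max(winners) if winners else None


ratios = speedups(ROWS_20B)
for bs, ratio in ratios.items():
    print(f"batch_size={bs:5d}  vLLM/PyTorch = {ratio:.2f}x")
print("largest batch size where the custom kernel wins:",
      largest_winning_batch(ROWS_20B))
```

The same function can be applied to the gpt-oss-120b rows to locate the crossover for that model.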

Benchmark

vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16 \
  --no-enable-prefix-caching
vllm bench serve \
        --model openai/gpt-oss-20b \
        --dataset-name sharegpt \
        --dataset-path /tmp/ShareGPT_V3_unfiltered_cleaned_split.json \
        --sharegpt-output-len 300 \
        --num-prompts ${num_prompts} \
        --max-concurrency ${concurrency} \
        --num-warmups 50 \
        --ignore-eos \
        --temperature 0

Main:

concurrency=1

============ Serving Benchmark Result ============
Successful requests:                     60        
Failed requests:                         0         
Maximum request concurrency:             1         
Benchmark duration (s):                  87.73     
Total input tokens:                      15599     
Total generated tokens:                  18000     
Request throughput (req/s):              0.68      
Output token throughput (tok/s):         205.19    
Peak output token throughput (tok/s):    219.00    
Peak concurrent requests:                2.00      
Total token throughput (tok/s):          383.00    
---------------Time to First Token----------------
Mean TTFT (ms):                          32.70     
Median TTFT (ms):                        28.32     
P99 TTFT (ms):                           80.66     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          4.78      
Median TPOT (ms):                        4.75      
P99 TPOT (ms):                           5.17      
---------------Inter-token Latency----------------
Mean ITL (ms):                           4.78      
Median ITL (ms):                         4.74      
P99 ITL (ms):                            5.52      
==================================================

concurrency=16

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  444.34    
Total input tokens:                      217301    
Total generated tokens:                  288000    
Request throughput (req/s):              2.16      
Output token throughput (tok/s):         648.16    
Peak output token throughput (tok/s):    768.00    
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          1137.21   
---------------Time to First Token----------------
Mean TTFT (ms):                          190.85    
Median TTFT (ms):                        152.53    
P99 TTFT (ms):                           827.11    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.12     
Median TPOT (ms):                        23.99     
P99 TPOT (ms):                           26.65     
---------------Inter-token Latency----------------
Mean ITL (ms):                           24.12     
Median ITL (ms):                         23.72     
P99 ITL (ms):                            36.63     
==================================================

PR:

concurrency=1

============ Serving Benchmark Result ============
Successful requests:                     60        
Failed requests:                         0         
Maximum request concurrency:             1         
Benchmark duration (s):                  86.24     
Total input tokens:                      15599     
Total generated tokens:                  18000     
Request throughput (req/s):              0.70      
Output token throughput (tok/s):         208.73    
Peak output token throughput (tok/s):    222.00    
Peak concurrent requests:                2.00      
Total token throughput (tok/s):          389.62    
---------------Time to First Token----------------
Mean TTFT (ms):                          32.81     
Median TTFT (ms):                        28.40     
P99 TTFT (ms):                           80.30     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          4.70      
Median TPOT (ms):                        4.66      
P99 TPOT (ms):                           5.08      
---------------Inter-token Latency----------------
Mean ITL (ms):                           4.70      
Median ITL (ms):                         4.66      
P99 ITL (ms):                            5.44      
==================================================

concurrency=16

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  442.22    
Total input tokens:                      217301    
Total generated tokens:                  288000    
Request throughput (req/s):              2.17      
Output token throughput (tok/s):         651.26    
Peak output token throughput (tok/s):    784.00    
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          1142.64   
---------------Time to First Token----------------
Mean TTFT (ms):                          174.88    
Median TTFT (ms):                        147.44    
P99 TTFT (ms):                           829.79    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.06     
Median TPOT (ms):                        24.02     
P99 TPOT (ms):                           26.06     
---------------Inter-token Latency----------------
Mean ITL (ms):                           24.06     
Median ITL (ms):                         23.67     
P99 ITL (ms):                            42.65     
==================================================
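As a quick check of the serving numbers, this sketch computes the relative change between the Main and PR runs. The values are copied from the results above; the helper `relative_change_pct` and the `RESULTS` layout are illustrative assumptions, not part of `vllm bench serve`.

```python
def relative_change_pct(baseline, candidate):
    """Percentage change of candidate relative to baseline."""
    return (candidate - baseline) / baseline * 100.0


# (Main, PR) pairs copied from the serving benchmark results above.
RESULTS = {
    "concurrency=1": {
        "output_tok_s": (205.19, 208.73),
        "mean_tpot_ms": (4.78, 4.70),
    },
    "concurrency=16": {
        "output_tok_s": (648.16, 651.26),
        "mean_tpot_ms": (24.12, 24.06),
    },
}

for setting, metrics in RESULTS.items():
    for name, (main, pr) in metrics.items():
        print(f"{setting:15s} {name:13s} {relative_change_pct(main, pr):+.2f}%")
```

Output token throughput rises by roughly 1.7% at concurrency 1, in line with the 1% to 2% claim, and by under 0.5% at concurrency 16.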

Accuracy Testing

OPENAI_API_KEY=EMPTY python3 -m gpt_oss.evals --model openai/gpt-oss-20b --eval gpqa --n-threads 200 --reasoning-effort low

Main:

Writing report to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_045459.html
{'chars': np.float64(66.44823232323232), 'chars:std': np.float64(235.44711891411228), 'score': np.float64(0.5561868686868687), 'score:std': np.float64(0.49683300593576163)}
Writing results to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_045459.json
Writing all results to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_045459_allresults.json
[{'eval_name': 'gpqa', 'model_name': '__opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_045459', 'metric': 0.5561868686868687}]

PR:

Writing report to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_040258.html
{'chars': np.float64(73.21843434343434), 'chars:std': np.float64(258.6924049276393), 'score': np.float64(0.5662878787878788), 'score:std': np.float64(0.49558643759268023)}
Writing results to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_040258.json
Writing all results to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_040258_allresults.json
[{'eval_name': 'gpqa', 'model_name': '__opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_040258', 'metric': 0.5662878787878788}]
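Whether the gap between the two GPQA scores is meaningful depends on the sample count, which the logs do not report. The sketch below compares the difference against its standard error. `N_SAMPLES = 1584` is an assumption, inferred only from the fact that both scores are exact multiples of 1/1584; substitute the real count if it is known.

```python
import math

# ASSUMPTION: number of graded samples per run. Not stated in the source;
# inferred because 881/1584 and 897/1584 reproduce the two reported scores.
N_SAMPLES = 1584

# Scores and per-sample standard deviations copied from the logs above.
main_score, main_std = 0.5561868686868687, 0.49683300593576163
pr_score, pr_std = 0.5662878787878788, 0.49558643759268023


def standard_error(std, n):
    """Standard error of a mean over n independent samples."""
    return std / math.sqrt(n)


diff = pr_score - main_score
se_diff = math.sqrt(
    standard_error(main_std, N_SAMPLES) ** 2
    + standard_error(pr_std, N_SAMPLES) ** 2
)
z = diff / se_diff

print(f"score difference (PR - Main): {diff:+.4f}")
print(f"standard error of difference: {se_diff:.4f}")
print(f"difference in standard errors: {z:+.2f}")
```

Under that assumption the difference is well within one standard error, so this single run does not distinguish the two builds on accuracy.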


Changed files

  • CMakeLists.txt (modified, +1/-0)
  • benchmarks/kernels/benchmark_router_gemm.py (added, +134/-0)
  • csrc/moe/gpt_oss_router_gemm.cu (added, +144/-0)
  • csrc/moe/gpt_oss_router_gemm.cuh (added, +447/-0)
  • csrc/moe/moe_ops.h (modified, +4/-0)
  • csrc/moe/torch_bindings.cpp (modified, +6/-0)
  • tests/kernels/moe/test_router_gemm.py (added, +37/-0)
  • vllm/_custom_ops.py (modified, +13/-0)
  • vllm/lora/layers/__init__.py (modified, +2/-0)
  • vllm/lora/layers/gate_linear.py (added, +30/-0)
  • vllm/lora/utils.py (modified, +2/-0)
  • vllm/model_executor/layers/fused_moe/router/gate_linear.py (modified, +52/-6)
  • vllm/model_executor/models/gpt_oss.py (modified, +3/-7)
RAW_BUFFER

Your current environment

<details>

Reproduced on various commits since b1169d7be8add20ab1db4bc93c2b5c6336ef9754, on 1xB200, across multiple fresh installs with both cu129 and cu130.

</details>

🐛 Describe the bug

The custom router GEMM kernel is causing infrequent NaNs in an internal test, which poison the correctness of the EAGLE draft model. It seems to occur only in very specific scenarios involving CUDA graphs, prefix caching, and chunked prefills.

Root cause identified by manual bisection: https://github.com/vllm-project/vllm/pull/37205. The problem originally manifested as EAGLE3 acceptance rates going to zero. No special flags are needed to reproduce it, just running gpt-oss-120b on 1xB200 with EAGLE3 enabled.

Unfortunately I cannot share the script to reproduce the failure, because it is an internal script for multi-turn chat replay.


extent analysis

TL;DR

The custom router GEMM kernel may need to be revised or patched to prevent the infrequent NaNs seen when it is combined with CUDA graphs, prefix caching, and chunked prefills.

Guidance

  • Review the changes made in pull request #37205 to understand the identified root cause and potential fixes.
  • Investigate the interaction between the custom router GEMM kernel, CUDA graphs, prefix caching, and chunked prefills to identify the specific conditions that lead to NaNs.
  • Consider adding additional error checking or handling for NaNs in the custom router GEMM kernel to prevent poisoning of the EAGLE draft model.
  • Attempt to reproduce the issue using a simplified test case or script to further isolate the problem.
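The NaN-checking idea in the guidance above can be sketched as follows. This is a pure-Python illustration, not vLLM code: every function name is hypothetical, `flaky_gemm` merely simulates a kernel that occasionally emits a NaN, and the bias term is included only for illustration. It shows the shape of a guard that detects non-finite router logits per token and recomputes those rows with a reference implementation.

```python
import math


def reference_router_logits(hidden, weight, bias):
    """Reference path: logits[t][e] = dot(hidden[t], weight[e]) + bias[e]."""
    return [
        [sum(h * w for h, w in zip(token, expert_w)) + b
         for expert_w, b in zip(weight, bias)]
        for token in hidden
    ]


def non_finite_rows(logits):
    """Indices of tokens whose logits contain a NaN or an infinity."""
    return [i for i, row in enumerate(logits)
            if not all(math.isfinite(v) for v in row)]


def guarded_router_logits(custom_gemm, hidden, weight, bias):
    """Run custom_gemm, then recompute any non-finite rows with the reference."""
    logits = custom_gemm(hidden, weight, bias)
    bad = non_finite_rows(logits)
    for i in bad:
        logits[i] = reference_router_logits([hidden[i]], weight, bias)[0]
    return logits, bad


def flaky_gemm(hidden, weight, bias):
    """Hypothetical stand-in for a faulty kernel: corrupts one logit with NaN."""
    out = reference_router_logits(hidden, weight, bias)
    out[1][0] = float("nan")
    return out


hidden = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]  # 3 tokens, hidden size 2
weight = [[1.0, 0.0], [0.0, 1.0]]               # 2 experts
bias = [0.0, 0.5]

logits, bad = guarded_router_logits(flaky_gemm, hidden, weight, bias)
print("rows repaired:", bad)
print(logits)
```

In a real GPU path the same check would more likely be done with tensor operations such as `torch.isfinite`, and a guard of this kind is mainly useful for isolating when the NaNs appear rather than as a permanent fix.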

Notes

The lack of a reproducible script and specific details about the custom router GEMM kernel and EAGLE draft model may limit the ability to provide a comprehensive solution.

Recommendation

Apply workaround: add error checking or handling for NaNs in the custom router GEMM kernel to prevent poisoning of the EAGLE draft model. Note that pull request #37205 is the change the regression was bisected to (it introduced the kernel), not a fix for it.
