vllm - ✅(Solved) Fix mm_fp4 trtllm backend leaks padding scales into real rows (use_8x4_sf_layout=True) [1 pull request, 2 comments, 1 participant]

GitHub stats
vllm-project/vllm#37563 (fetched 2026-04-08 01:04:36)
Comments: 2, Participants: 1, Timeline events: 5, Reactions: 0
Timeline (top): commented ×2, referenced ×2, cross-referenced ×1

When mm_fp4 is called with backend="trtllm" and use_8x4_sf_layout=True (triggered when m <= 32), the kernel reads scale factors from padding rows and mixes them into real rows' output. This causes:

  1. NaN contamination: If padding scales contain NaN (e.g. from stale torch.empty memory under CUDA graphs), real rows' output becomes NaN.
  2. Silent numerical inaccuracy: Even with finite padding scale values, real rows' output changes (max_diff ~500-700 in bf16).

The CUTLASS backend does not have this issue — it perfectly isolates rows within a tile.

Root Cause

The TRT-LLM 8×4 scale factor layout path in FlashInfer's mm_fp4 kernel reads scales from padding rows (rows beyond the actual m tokens but within the padded tile) and applies them to real rows' computation. This is a kernel-level bug in FlashInfer.

Fix Action

Workaround

Setting padding scales to zero (0x00 in float8_e4m3fn) neutralizes the bug: even though the kernel still reads the wrong scales, 0 * data = 0 contributes nothing to real rows' output. This is the same principle behind the torch.zeros fix in create_fp4_scale_tensor for the CUTLASS MoE path.
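To make the byte values concrete, here is a minimal pure-Python sketch that decodes a float8_e4m3fn byte (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits). The helper name decode_e4m3fn is hypothetical and exists only for illustration; it is not part of vLLM or FlashInfer.

```python
import math


def decode_e4m3fn(byte: int) -> float:
    """Decode one float8_e4m3fn byte into a Python float (illustrative helper)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        # The "fn" variant has no infinities; only S.1111.111 encodes NaN.
        return math.nan
    if exponent == 0:
        # Subnormal range (includes zero).
        return sign * (mantissa / 8.0) * 2.0 ** -6
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


# 0x00 is the zero scale used by the workaround, 0x7F is the NaN used to
# poison padding scales, 0x7E is the largest finite value, 0x38 is 1.0.
for byte in (0x00, 0x38, 0x7E, 0x7F):
    scale = decode_e4m3fn(byte)
    print(f"0x{byte:02X} -> scale={scale}  scale * 3.0 = {scale * 3.0}")
```

A zero scale multiplies any finite data value to zero, while a NaN scale turns every product it touches into NaN, which is why the byte left in padding slots matters.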

PR fix notes

PR #37564: [Bugfix] Zero-init NVFP4 padding scales to prevent NaN contamination

Description (problem / solution / changelog)

Summary

  • Zero-initialize NVFP4 scale tensors (torch.zeros instead of torch.empty) at three allocation sites to prevent NaN/garbage from stale memory leaking into real rows' output

Problem

When NVFP4 scale tensors are allocated with torch.empty, padding rows (beyond actual token count but within the tile-aligned allocation) contain uninitialized memory. Under CUDA graphs, this stale memory can contain NaN bytes.
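A rough sketch of how much of a scale allocation is padding at small batch sizes. It assumes rows are aligned to 128 (matching TILE_ROWS in the reproduction script) and scale columns to 4; the column alignment and the helper name scale_bytes are assumptions made for illustration, not values confirmed by this issue.

```python
def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple


def scale_bytes(num_real_rows, hidden_dim, block_size=16, row_tile=128, col_tile=4):
    """Return (real, padding) scale-byte counts for one activation tensor.

    row_tile and col_tile are assumed alignment values, not taken from the issue.
    """
    cols = hidden_dim // block_size  # one fp8 scale per block of values
    padded_rows = round_up(num_real_rows, row_tile)
    padded_cols = round_up(cols, col_tile)
    total = padded_rows * padded_cols
    real = num_real_rows * cols
    return real, total - real


for m in (1, 8, 32, 64, 128):
    real, padding = scale_bytes(m, hidden_dim=512)
    print(f"m={m:>3}  real_scale_bytes={real:>4}  padding_scale_bytes={padding:>4}")
```

Under these assumptions a single-token decode step uses 32 of 4096 scale bytes, so nearly the whole buffer is whatever the allocator left behind unless it is zero-initialized.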

The mm_fp4 TRT-LLM backend with use_8x4_sf_layout=True (_custom_ops.py:1732) reads padding scales and applies them to real rows, causing:

  1. NaN contamination when padding scales contain NaN
  2. Silent numerical errors (max_diff ~500-700 in bf16) even with finite padding scale values

See #37563 for detailed reproduction and test results.

Fix

Replace torch.empty with torch.zeros for all NVFP4 scale allocations:

  • create_fp4_scale_tensor — attention linear layers
  • scaled_fp4_experts_quant — MoE Gemm1 input quantization
  • silu_and_mul_scaled_fp4_experts_quant — MoE Gemm2 input quantization

Zero scales ensure that even if a kernel incorrectly reads padding rows, the contribution is 0 * data = 0, which is harmless.
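The argument can be illustrated with a toy model in plain Python. This is not the FlashInfer kernel: the issue does not describe the kernel's actual access pattern, so leaky_row_output below simply assumes that a real row's result also accumulates terms weighted by a padding row's scales. Both function names are hypothetical.

```python
import math


def block_scaled_dot(x_blocks, w_blocks, scales):
    """Sum over blocks of scale * dot(x_block, w_block)."""
    return sum(
        s * sum(a * b for a, b in zip(xb, wb))
        for xb, wb, s in zip(x_blocks, w_blocks, scales)
    )


def leaky_row_output(x_blocks, w_blocks, real_scales, padding_scales):
    """Toy bug model (assumption): padding scales also weight the real row's data."""
    correct = block_scaled_dot(x_blocks, w_blocks, real_scales)
    leaked = block_scaled_dot(x_blocks, w_blocks, padding_scales)
    return correct + leaked


x = [[1.0, 2.0], [3.0, 4.0]]
w = [[0.5, 0.5], [1.0, -1.0]]
real_scales = [2.0, 4.0]

correct = block_scaled_dot(x, w, real_scales)
with_nan = leaky_row_output(x, w, real_scales, [math.nan, math.nan])
with_ones = leaky_row_output(x, w, real_scales, [1.0, 1.0])
with_zeros = leaky_row_output(x, w, real_scales, [0.0, 0.0])

print("correct   :", correct)
print("NaN pad   :", with_nan)    # NaN propagates into the real row
print("finite pad:", with_ones)   # silently differs from the correct value
print("zero pad  :", with_zeros)  # matches the correct value
```

In this model, zeroed padding scales hide the stray reads without removing them, which matches the PR's framing of zero-initialization as a safeguard rather than a kernel fix.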

Test plan

  • Verify NaN contamination test passes (padding NaN no longer leaks into real rows)
  • Verify finite contamination test shows SAME for all backends
  • Run existing test_flashinfer_nvfp4_scaled_mm tests
  • Benchmark to confirm no perf regression from torch.zeros vs torch.empty

Closes #37563

🤖 Generated with Claude Code

Changed files

  • vllm/_custom_ops.py (modified, +9/-2)


Bug: TRT-LLM mm_fp4 with use_8x4_sf_layout=True reads padding rows' scales and applies them to real rows


Reproduction

The test uses swizzle-aware padding detection: quantize the same real rows with two different padding fillers (zeros vs 448.0), diff the resulting scale bytes to find which bytes correspond to padding rows in the swizzled layout, then poison only those bytes.
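The detection trick can be shown without a GPU. In the sketch below, a fixed random permutation stands in for the swizzled layout and toy_quantize stands in for scaled_fp4_quant; both are hypothetical simplifications, and only the diff-then-poison logic mirrors the reproduction script.

```python
import random

ROWS, COLS, REAL_ROWS = 8, 4, 3

# Hypothetical stand-in for a swizzled layout: an opaque permutation of byte slots.
rng = random.Random(0)
perm = list(range(ROWS * COLS))
rng.shuffle(perm)


def toy_quantize(row_values):
    """Emit COLS 'scale bytes' per row, derived from the row value, in swizzled order."""
    linear = [min(int(abs(v)), 255) for v in row_values for _ in range(COLS)]
    swizzled = [0] * len(linear)
    for src, dst in enumerate(perm):
        swizzled[dst] = linear[src]
    return swizzled


real_rows = [5.0, 7.0, 9.0]
num_pad = ROWS - REAL_ROWS

# Same real rows, two different padding fillers.
scales_zero = toy_quantize(real_rows + [0.0] * num_pad)
scales_448 = toy_quantize(real_rows + [448.0] * num_pad)

# Bytes that differ can only have come from padding rows.
padding_mask = [a != b for a, b in zip(scales_zero, scales_448)]

# Poison only the padding bytes with 0x7F (NaN in float8_e4m3fn).
poisoned = [0x7F if is_pad else b for b, is_pad in zip(scales_zero, padding_mask)]

print("padding slots:", [i for i, is_pad in enumerate(padding_mask) if is_pad])
```

The approach never needs to know the layout itself. It only relies on the two fillers producing different scale bytes while the real rows produce identical ones.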

NaN contamination test:

import torch
from vllm._custom_ops import scaled_fp4_quant
from flashinfer import mm_fp4

TILE_ROWS = 128
HIDDEN_DIM = 512
OUT_DIM = 256
BLOCK_SCALE_SIZE = 16

def _round_up(x, multiple):
    return ((x + multiple - 1) // multiple) * multiple

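def test(num_real_rows, backend):
    global_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    # "+ 1" guarantees at least one padding row, even when num_real_rows
    # is already a multiple of TILE_ROWS.
    padded_rows = _round_up(num_real_rows + 1, TILE_ROWS)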

    X_pad_zero = torch.zeros(padded_rows, HIDDEN_DIM, dtype=torch.bfloat16, device="cuda")
    X_pad_zero[:num_real_rows] = torch.randn(
        num_real_rows, HIDDEN_DIM, dtype=torch.bfloat16, device="cuda"
    )
    X_pad_448 = X_pad_zero.clone()
    X_pad_448[num_real_rows:] = 448.0

    fp4_data, scales_zero = scaled_fp4_quant(
        X_pad_zero, global_scale, is_sf_swizzled_layout=True, backend=backend
    )
    _, scales_448 = scaled_fp4_quant(
        X_pad_448, global_scale, is_sf_swizzled_layout=True, backend=backend
    )

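    # Real rows quantize identically in both runs, so any scale byte that
    # differs must belong to a padding row in the swizzled layout.
    padding_mask = (
        scales_zero.view(torch.uint8) != scales_448.view(torch.uint8)
    )

    # Poison only the padding bytes; real rows keep their original scales.
    scales_poisoned = scales_zero.view(torch.uint8).clone()
    scales_poisoned[padding_mask] = 0x7F  # NaN in float8_e4m3fn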

    # torch.randint's upper bound is exclusive; 256 covers every packed byte value.
    weight_fp4 = torch.randint(
        0, 256, (OUT_DIM, HIDDEN_DIM // 2), dtype=torch.uint8, device="cuda"
    )
    weight_scales = torch.ones(
        OUT_DIM, HIDDEN_DIM // BLOCK_SCALE_SIZE,
        dtype=torch.float8_e4m3fn, device="cuda",
    )
    alpha = torch.tensor([1.0], dtype=torch.float32, device="cuda")
    use_8x4 = (backend == "trtllm" and num_real_rows <= 32)

    output = mm_fp4(
        fp4_data,
        weight_fp4.t(),
        scales_poisoned.view(torch.float8_e4m3fn),
        weight_scales.t(),
        alpha,
        torch.bfloat16,
        block_size=BLOCK_SCALE_SIZE,
        use_8x4_sf_layout=use_8x4,
        backend=backend,
    )

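    # Count real rows that contain at least one NaN.
    nan_rows = torch.isnan(output[:num_real_rows]).any(dim=-1).sum().item()
    tag = "FAIL" if nan_rows else "OK"
    # Fixed-width fields keep the per-backend lines aligned.
    print(f"{tag:<4}  m={num_real_rows:>3}  backend={backend:<7}  nan_rows={nan_rows}/{num_real_rows}")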

for m in [1, 2, 4, 8, 16, 32, 64]:
    for backend in ["cutlass", "trtllm"]:
        test(m, backend)

Test results

NaN contamination (padding scales = NaN):

OK    m=  1  backend=cutlass  nan_rows=0/1
FAIL  m=  1  backend=trtllm   nan_rows=1/1
OK    m=  2  backend=cutlass  nan_rows=0/2
FAIL  m=  2  backend=trtllm   nan_rows=2/2
OK    m=  4  backend=cutlass  nan_rows=0/4
FAIL  m=  4  backend=trtllm   nan_rows=4/4
OK    m=  8  backend=cutlass  nan_rows=0/8
FAIL  m=  8  backend=trtllm   nan_rows=8/8
OK    m= 16  backend=cutlass  nan_rows=0/16
FAIL  m= 16  backend=trtllm   nan_rows=14/16
OK    m= 32  backend=cutlass  nan_rows=0/32
FAIL  m= 32  backend=trtllm   nan_rows=24/32
OK    m= 64  backend=cutlass  nan_rows=0/64
OK    m= 64  backend=trtllm   nan_rows=0/64

Finite contamination (padding scales = 1.0 vs original):

SAME  m=  1  backend=cutlass  max_diff=0.000000
DIFF  m=  1  backend=trtllm   max_diff=514.500000
SAME  m=  2  backend=cutlass  max_diff=0.000000
DIFF  m=  2  backend=trtllm   max_diff=692.687500
...
SAME  m= 64  backend=cutlass  max_diff=0.000000
SAME  m= 64  backend=trtllm   max_diff=0.000000

Key observations

  • CUTLASS backend: perfectly clean in all cases — no cross-row contamination
  • TRT-LLM, m ≤ 32 (use_8x4_sf_layout=True): broken — padding scales leak into real rows
  • TRT-LLM, m > 32 (use_8x4_sf_layout=False): clean — bug is specific to 8×4 SF layout path
  • The bug persists even when scales are produced by the correct flashinfer_quant_nvfp4_8x4_sf_layout quantization function (passing backend="trtllm" to scaled_fp4_quant)
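The observations above reduce to a small predicate. The sketch below copies the layout-selection rule from the reproduction script (backend == "trtllm" and m <= 32); the function names are hypothetical and not part of vLLM or FlashInfer.

```python
def uses_8x4_sf_layout(backend: str, num_real_rows: int) -> bool:
    """Selection rule taken from the reproduction script in this issue."""
    return backend == "trtllm" and num_real_rows <= 32


def leaks_padding_scales(backend: str, num_real_rows: int) -> bool:
    """Per the reported results, only the 8x4 SF layout path is affected."""
    return uses_8x4_sf_layout(backend, num_real_rows)


for m in (1, 16, 32, 33, 64):
    for backend in ("cutlass", "trtllm"):
        status = "affected" if leaks_padding_scales(backend, m) else "clean"
        print(f"m={m:>3}  backend={backend:<7}  {status}")
```

A check like this could decide where padding scales must be zeroed, although zeroing unconditionally, as the linked PR does, avoids depending on the exact threshold.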


Production impact

In R1 NVFP4 decode, production traces show all mm_fp4 GEMMs use the CUTLASS backend (which is unaffected). The TRT-LLM backend is only used for cvt_fp16_to_fp4_expert (quantization, not GEMM). So this bug does not currently affect R1 production, but it would affect any deployment that routes through flashinfer_scaled_fp4_mm with backend="trtllm" and small batch sizes (m ≤ 32).

Relevant code

Environment

  • GB200 (SM 100, Blackwell)
  • vLLM main branch

extent analysis

Fix Plan

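A root-cause fix would modify the mm_fp4 kernel in FlashInfer so that the 8x4 scale factor layout path (use_8x4_sf_layout=True) reads only the scales of real rows. PR #37564 instead works around the bug on the vLLM side by zero-initializing the scale tensors.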

Here are the steps:

  • Identify the lines of code in the mm_fp4 kernel where the scale factors are read from memory.
  • Modify the indexing to only read scale factors for real rows, excluding padding rows.
  • Verify that the modified kernel produces the correct results for both use_8x4_sf_layout=True and use_8x4_sf_layout=False.

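Example code snippet (illustrative host-side Python; the actual change belongs in the kernel's scale indexing):

# Assumes 'scales' is a row-major tensor with one row of scale factors per padded
# input row, and 'num_real_rows' is the number of real rows. In the swizzled
# layout that mm_fp4 consumes, padding bytes are not necessarily contiguous,
# so a plain row slice like this does not apply there.
padded_rows = _round_up(num_real_rows + 1, TILE_ROWS)
padding_rows = padded_rows - num_real_rows

# Keep only the scale factors that belong to real rows
real_row_scales = scales[:num_real_rows]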

As a stopgap in the flashinfer_scaled_fp4_mm wrapper, zero the padding rows of the scales tensor when use_8x4_sf_layout=True. The slice below assumes row-major scales and that num_real_rows is known in the wrapper; for swizzled scales, zero-initializing the buffer at allocation (as PR #37564 does) is the more reliable option:

if use_8x4_sf_layout:
    # Write bytes through a uint8 view, as the reproduction script does;
    # byte 0x00 decodes to 0.0 in float8_e4m3fn.
    scales.view(torch.uint8)[num_real_rows:] = 0

Verification

To verify the fix, run the test cases provided in the issue body with the modified mm_fp4 kernel and flashinfer_scaled_fp4_mm wrapper. The test cases should pass without any NaN contamination or silent numerical inaccuracy.

Extra Tips

  • Make sure to test the modified kernel with different batch sizes and input shapes to ensure that the fix is robust.
  • Consider adding additional test cases to cover different scenarios and edge cases.
  • If possible, try to reproduce the issue in a minimal example to better understand the root cause of the bug.
