vllm - ✅(Solved) Fix [CI Failure]: PyTorch Compilation Passes Unit Tests [1 pull requests, 2 comments, 2 participants]

GitHub stats
vllm-project/vllm#40622 (fetched 2026-04-23 07:23:49)
Comments: 2 · Participants: 2 · Timeline events: 11 · Reactions: 0
Timeline (top): labeled ×3, added_to_project_v2 ×2, commented ×2, cross-referenced ×1

Error Message

The fix_functionalization test is flaky because it compares numerical outputs. Since it fails often, this may be a real error. The failures may also come from the different semantics of RMSNorm since the partial IR migration (which covers rms_norm but not fused_add_rms_norm), if the test uses RMSNorm rather than calling the kernel directly.

Root Cause

  • Flaky test
  • Can reproduce locally
  • Caused by external libraries (e.g. bug in transformers)

Fix Action

Fixed

PR fix notes

PR #40629: [Bugfix][CI] Fix wrong residual shape in TestFusedAddRMSNorm.example_inputs that causes flaky test

Description (problem / solution / changelog)

Latest Status Update [24 Apr]

Thanks to ProExpertProg for approving this PR! I added a repro script, and the PR is ready to merge.

Purpose

Closes https://github.com/vllm-project/vllm/issues/40622.

TestFusedAddRMSNorm.example_inputs was generating a residual tensor with shape (128, hidden_size=16), but the model's forward pass feeds it into RMSNorm(intermediate_size=32) together with mm = torch.mm(view, gate_proj.T), which has shape (128, 32). The fused_add_rms_norm kernel therefore reads residual as if it had intermediate_size=32 columns, which is an out-of-bounds memory read. The fix sizes residual by self.intermediate_size so that both mm and residual have shape (128, 32) when they reach RMSNorm.
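To make the indexing concrete, here is a minimal pure-Python sketch of the row loop, assuming (as described above) that the kernel reads `residual` with the norm's row width rather than residual's own. `fused_add_rms_norm_ref` is a hypothetical helper, not a vLLM API, and the sizes are scaled down. Unlike CUDA, a Python list raises `IndexError` instead of silently reading past the end.

```python
import math


def fused_add_rms_norm_ref(x_rows, residual_flat, weight, eps=1e-5):
    """Hypothetical reference for fused_add_rms_norm's row loop.

    x_rows: list of rows, each of length len(weight) (the RMSNorm hidden dim).
    residual_flat: residual as one flat, row-major list.
    Returns (normalized rows, updated residual as a flat list).
    """
    hidden = len(weight)  # row width comes from the norm, not from residual
    out, new_residual = [], []
    for r, row in enumerate(x_rows):
        # Flat indexing residual[r * hidden + c], as assumed for the kernel.
        # A Python list raises IndexError past its end; CUDA would not.
        z = [row[c] + residual_flat[r * hidden + c] for c in range(hidden)]
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in z) / hidden + eps)
        out.append([v * inv_rms * w for v, w in zip(z, weight)])
        new_residual.extend(z)
    return out, new_residual


N, INTER, BUGGY_HIDDEN = 4, 4, 2  # scaled-down stand-ins for 128, 32, 16
x = [[1.0] * INTER for _ in range(N)]
weight = [1.0] * INTER

# Fixed shape: residual has N * INTER elements, so every read is in bounds.
out, res = fused_add_rms_norm_ref(x, [0.0] * (N * INTER), weight)

# Buggy shape: residual has only N * BUGGY_HIDDEN elements.
try:
    fused_add_rms_norm_ref(x, [0.0] * (N * BUGGY_HIDDEN), weight)
    buggy_raised = False
except IndexError:
    buggy_raised = True
print(buggy_raised)  # True: row 2 starts at index 8, past the 8-element buffer
```

Note that rows 0 and 1 in the buggy case do not fail: they stay inside the allocation but silently combine values that belong to different residual rows.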

The test is intermittent because the memory past the end of residual holds arbitrary data: sometimes, by chance, the values read in the model_func and model_no_func runs are close enough that the outputs pass assert_close. Adding a bounds check to the kernel would be overly defensive, so this PR does not add one.
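As a hedged illustration of why a wrong read can still pass, assume the test uses `torch.testing.assert_close` with its documented float16 defaults (`rtol=1e-3`, `atol=1e-5`). The pure-Python helpers `allclose` and `simulated_output` below are hypothetical stand-ins, not vLLM or PyTorch APIs.

```python
# Default torch.testing.assert_close tolerances for float16 (assumption: the
# test relies on these defaults rather than passing its own rtol/atol).
FP16_RTOL, FP16_ATOL = 1e-3, 1e-5


def allclose(actual, expected, rtol=FP16_RTOL, atol=FP16_ATOL):
    """Elementwise |a - e| <= atol + rtol * |e|, the closeness criterion
    documented for torch.testing.assert_close."""
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def simulated_output(in_bounds, stale_tail):
    """Stand-in for the kernel: each output element mixes an in-bounds value
    with a stale value read from past the end of the allocation."""
    return [v + s for v, s in zip(in_bounds, stale_tail)]


in_bounds = [0.5, -1.25, 2.0]  # identical in both runs
run_a = simulated_output(in_bounds, [0.1, 0.2, 0.3])
run_b_lucky = simulated_output(in_bounds, [0.1, 0.2, 0.3000001])  # stale bytes nearly equal
run_b_unlucky = simulated_output(in_bounds, [0.1, 0.2, 7.0])      # stale bytes differ

print(allclose(run_b_lucky, run_a))    # passes by chance
print(allclose(run_b_unlucky, run_a))  # fails
```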

To Deterministically Reproduce this Bug

Because the test allocates only (128, 16) elements for residual while the fused_add_rms_norm kernel reads (128, 32), you can fill the memory past the end of residual with different values when running model_func and model_no_func. The outputs then diverge and the test fails deterministically.
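A GPU-free model of this experiment can be sketched in pure Python. It assumes only that the kernel reads `residual` as `INTER` columns; `make_storage`, `kernel_like_norm`, and `max_divergence` are hypothetical stand-ins for the reproduction script's `make_residual`, the CUDA kernel, and `measure_divergence`, with sizes scaled down.

```python
import math
import random

N, INTER = 8, 4  # scaled-down stand-ins for N=128, INTER=32


def make_storage(pattern_value, hidden):
    """Flat buffer of N * INTER slots: the first N * hidden hold a fixed
    residual (same seed every call); the rest are poisoned with pattern_value."""
    rng = random.Random(42)
    valid = [rng.gauss(0.0, 1.0) for _ in range(N * hidden)]
    return valid + [pattern_value] * (N * (INTER - hidden))


def kernel_like_norm(x_rows, storage, eps=1e-5):
    """Read residual as storage[r * INTER + c], like a kernel that assumes the
    residual has INTER columns, then apply RMSNorm (unit weight) per row."""
    out = []
    for r, row in enumerate(x_rows):
        z = [row[c] + storage[r * INTER + c] for c in range(INTER)]
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in z) / INTER + eps)
        out.append([v * inv_rms for v in z])
    return out


def max_divergence(hidden):
    """Run twice with tails poisoned by 0.0 and 100.0; return max |A - B|."""
    x = [[0.1 * (r + c) for c in range(INTER)] for r in range(N)]
    out_a = kernel_like_norm(x, make_storage(0.0, hidden))
    out_b = kernel_like_norm(x, make_storage(100.0, hidden))
    return max(abs(a - b) for ra, rb in zip(out_a, out_b) for a, b in zip(ra, rb))


print(max_divergence(hidden=2) > 1e-3)  # buggy: True, output depends on the poisoned tail
print(max_divergence(hidden=INTER))     # fixed: no tail, so the runs agree exactly
```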

Run the script below:

<details>
"""Deterministic reproduction of the flaky test fixed by PR #40629.

BACKGROUND
----------
tests/compile/passes/test_functionalization.py::TestFusedAddRMSNorm builds:

    norm          = RMSNorm(intermediate_size=32, ...)    # expects 32 cols
    residual      = torch.randn(128, hidden_size=16)      # only 16 cols  <-- BUG
    mm            = torch.mm((128, 16), (16, 32)) -> (128, 32)
    norm(mm, residual)  # -> fused_add_rms_norm kernel

The `fused_add_rms_norm` CUDA kernel iterates over 32 elements per row
(because that is RMSNorm's hidden dim) and indexes `residual` as
residual[row * 32 + col], so it reads N*INTER fp16 values starting at
residual's data pointer. Only N*HIDDEN of them belong to residual; the
remaining N*(INTER-HIDDEN) come from whatever happens to live past
residual's allocation. The test is flaky
because the two runs it compares (`model_func` vs `model_no_func`) each
allocate a fresh residual whose trailing bytes come from whatever the
CUDA caching allocator returned.

HOW THIS SCRIPT TRIGGERS IT DETERMINISTICALLY
---------------------------------------------
Instead of relying on the allocator, we back `residual` with an
over-sized storage and poison the tail with a known pattern:

    storage (size N*INTER): [ valid residual rows | pattern_value ... ]
                              ^ first N*HIDDEN      ^ N*(INTER-HIDDEN)
                                are always the same  -- what the kernel
                                (same seed every run) reads OOB

The in-bounds residual is byte-identical across runs; only the tail
pattern differs. If the output depends on the tail, the bug is real.

We run the same experiment twice:
  CASE A ("buggy shape"): residual declared as (128, HIDDEN=16)
  CASE B ("fixed shape"): residual declared as (128, HIDDEN=32 == INTER)

In CASE B there is no OOB region for the kernel to read, so the poison
pattern is irrelevant and the two runs must agree exactly.
"""

import sys

import torch

from vllm.config import (
    CompilationConfig,
    ModelConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.model_executor.layers.layernorm import RMSNorm

N = 128           # batch_size * seq_len
INTER = 32        # RMSNorm's hidden dim (what the kernel iterates over)
BUGGY_HIDDEN = 16  # residual's wrong declared shape (the bug)
FIXED_HIDDEN = 32  # residual's correct shape (PR #40629's fix)


def make_residual(pattern_value, hidden, dtype):
    """Build a contiguous (N, hidden) residual whose storage is size N*INTER.
    The first N*hidden elements are a fixed random residual; the remaining
    N*(INTER-hidden) elements are filled with `pattern_value`. When the
    fused_add_rms_norm kernel reads (N, INTER) elements from this pointer,
    it will see pattern_value in the OOB region (if hidden < INTER)."""
    storage = torch.full((N * INTER,), pattern_value, dtype=dtype, device="cuda")
    residual = torch.empty(0, dtype=dtype, device="cuda")
    residual.set_(storage.untyped_storage(), 0, (N, hidden), (hidden, 1))
    torch.manual_seed(42)
    residual.copy_(torch.randn((N, hidden), dtype=dtype, device="cuda"))
    return residual, storage  # keep storage alive


def run_once(pattern_value, hidden, dtype):
    """Build the TestFusedAddRMSNorm forward pass and run it once with the
    given residual-tail poison pattern."""
    torch.manual_seed(1)
    hidden_states = torch.randn((N, BUGGY_HIDDEN), dtype=dtype, device="cuda")
    torch.manual_seed(2)
    gate_proj = torch.empty((INTER, BUGGY_HIDDEN), dtype=dtype, device="cuda")
    torch.nn.init.normal_(gate_proj, std=0.02)

    norm = RMSNorm(INTER, 1e-5).to(device="cuda", dtype=dtype)
    norm.weight = torch.nn.Parameter(torch.ones(INTER, dtype=dtype, device="cuda"))

    residual, _anchor = make_residual(pattern_value, hidden, dtype)

    view = hidden_states.reshape(-1, BUGGY_HIDDEN)
    mm = torch.mm(view, gate_proj.permute(1, 0))  # (N, INTER)
    out, res_out = norm(mm, residual)
    return out.detach().clone(), res_out.detach().clone()


def measure_divergence(label, hidden, dtype):
    """Run twice with different poison patterns; return the max abs diff."""
    out_a, res_a = run_once(pattern_value=0.0, hidden=hidden, dtype=dtype)
    out_b, res_b = run_once(pattern_value=100.0, hidden=hidden, dtype=dtype)
    diff_out = (out_a.float() - out_b.float()).abs().max().item()
    diff_res = (res_a.float() - res_b.float()).abs().max().item()
    print(f"[{label}]  residual shape = (N={N}, hidden={hidden}), "
          f"kernel iterates over {INTER}")
    print(f"           max |out_A - out_B|  = {diff_out:.6f}   "
          f"(A/B differ only in OOB tail: 0.0 vs 100.0)")
    print(f"           max |res_A - res_B|  = {diff_res:.6f}")
    return diff_out, diff_res


def main():
    torch.set_default_device("cuda")
    dtype = torch.float16
    cfg = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        compilation_config=CompilationConfig(custom_ops=["all"]),
    )
    with set_current_vllm_config(cfg):
        print("=" * 70)
        print("CASE A: BUGGY shape (hidden=16, kernel expects 32)")
        print("        -> kernel reads 16 OOB fp16 per row; tail is poisoned.")
        print("=" * 70)
        buggy_diff_out, _ = measure_divergence("BUGGY", BUGGY_HIDDEN, dtype)
        print()
        print("=" * 70)
        print("CASE B: FIXED shape (hidden=32, matches kernel)")
        print("        -> no OOB; poison pattern is irrelevant.")
        print("=" * 70)
        fixed_diff_out, fixed_diff_res = measure_divergence("FIXED", FIXED_HIDDEN, dtype)

    print()
    print("=" * 70)
    print("SUMMARY")
    print("=" * 70)
    print(f"  buggy (hidden=16): out diverges by {buggy_diff_out:.6f}  "
          f"<- depends on uninitialized memory")
    print(f"  fixed (hidden=32): out diverges by {fixed_diff_out:.6f}  "
          f"<- deterministic")
    assert buggy_diff_out > 1e-3, (
        "Expected buggy case to diverge due to OOB read; got identical outputs."
    )
    assert fixed_diff_out == 0.0 and fixed_diff_res == 0.0, (
        "Expected fixed case to be bit-exact across poison patterns; "
        f"got out={fixed_diff_out}, res={fixed_diff_res}."
    )
    print("\nCONFIRMED: the flakiness is caused by residual's shape being "
          "smaller than RMSNorm's hidden dim.")
    return 0


if __name__ == "__main__":
    sys.exit(main())
</details>

The script prints the following output, which confirms the bug:

<details>
======================================================================
CASE A: BUGGY shape (hidden=16, kernel expects 32)
        -> kernel reads 16 OOB fp16 per row; tail is poisoned.
======================================================================
[BUGGY]  residual shape = (N=128, hidden=16), kernel iterates over 32
           max |out_A - out_B|  = 4.016113   (A/B differ only in OOB tail: 0.0 vs 100.0)
           max |res_A - res_B|  = 0.000000

======================================================================
CASE B: FIXED shape (hidden=32, matches kernel)
        -> no OOB; poison pattern is irrelevant.
======================================================================
[FIXED]  residual shape = (N=128, hidden=32), kernel iterates over 32
           max |out_A - out_B|  = 0.000000   (A/B differ only in OOB tail: 0.0 vs 100.0)
           max |res_A - res_B|  = 0.000000

======================================================================
SUMMARY
======================================================================
  buggy (hidden=16): out diverges by 4.016113  <- depends on uninitialized memory
  fixed (hidden=32): out diverges by 0.000000  <- deterministic

CONFIRMED: the flakiness is caused by residual's shape being smaller than RMSNorm's hidden dim.
</details>

Test Plan

On this branch, run these 2 commands:

.venv/bin/python -m pytest tests/compile/passes/test_functionalization.py::test_fix_functionalization -k "TestFusedAddRMSNorm" -v

.venv/bin/python -m pytest tests/compile/passes/test_functionalization.py -v

Test Result

The first command, which runs only the TestFusedAddRMSNorm cases, passes (H100, CUDA 13.0, torch 2.11.0+cu130):

<details>
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-True-dtype0] PASSED [ 25%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-True-dtype1] PASSED [ 50%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-False-dtype0] PASSED [ 75%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-False-dtype1] PASSED [100%]

  ================ 4 passed, 10 deselected, 17 warnings in 25.40s ================
</details>

The second command, which runs all of test_functionalization.py, also passes:

<details>
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestSiluMul-True-dtype0] PASSED [ 7%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestSiluMul-True-dtype1] PASSED [ 14%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestSiluMul-False-dtype0] PASSED [ 21%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestSiluMul-False-dtype1] PASSED [ 28%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-True-dtype0] PASSED [ 35%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-True-dtype1] PASSED [ 42%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-False-dtype0] PASSED [ 50%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-False-dtype1] PASSED [ 57%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestRotaryEmbedding-False-dtype0] PASSED [ 64%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestRotaryEmbedding-False-dtype1] PASSED [ 71%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestRotaryEmbeddingSliceScatter-False-dtype0] PASSED [ 78%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestRotaryEmbeddingSliceScatter-False-dtype1] PASSED [ 85%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFunctionWithMutatedArgsAndReturn-False-dtype0] PASSED [ 92%]
  tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFunctionWithMutatedArgsAndReturn-False-dtype1] PASSED [100%]

  ======================= 14 passed, 24 warnings in 31.48s =======================
</details>

Changed files

  • tests/compile/passes/test_functionalization.py (modified, +3/-3)

Name of failing test

tests/compile/passes/test_functionalization.py::test_fix_functionalization[TestFusedAddRMSNorm-True-dtype0]


📝 History of failing test

Last 1-3 weeks

CC List.

@zou3519

extent analysis

TL;DR

Investigate the usage of RMSNorm versus fused_add_rms_norm in the test_fix_functionalization test to determine if the flakiness is due to semantic differences.

Guidance

  • Review the test code in tests/compile/passes/test_functionalization.py to understand how RMSNorm is being used and compared.
  • Check if the test is using the rms_norm partial IR migration or the fused_add_rms_norm kernel call directly.
  • Investigate the history of changes to the test and related code in the last 1-3 weeks to see if any recent modifications could be causing the flakiness.
  • Consider adding additional logging or debugging statements to the test to help determine the cause of the flakiness.
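To quantify how often the test fails (and to check whether the fix brings that rate to zero), one option is to rerun it in a loop. This is a minimal sketch; `failure_rate` and `pytest_cmd` are hypothetical helpers that rely only on `subprocess.run` and `python -m pytest`.

```python
import subprocess
import sys


def pytest_cmd(node_id, k_expr=None):
    """Build a `python -m pytest -q` command for one test node (hypothetical helper)."""
    cmd = [sys.executable, "-m", "pytest", "-q", node_id]
    if k_expr:
        cmd += ["-k", k_expr]
    return cmd


def failure_rate(cmd, runs=20):
    """Run `cmd` `runs` times and return the fraction of non-zero exit codes."""
    failures = 0
    for _ in range(runs):
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            failures += 1
    return failures / runs
```

For example, `failure_rate(pytest_cmd("tests/compile/passes/test_functionalization.py::test_fix_functionalization", "TestFusedAddRMSNorm"), runs=50)` gives a rough failure rate on a GPU machine; if the flakiness is fully explained by a single cause, the rate would be expected to drop to 0.0 once that cause is fixed.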

Notes

The lack of information about the specific error message or failure mode makes it difficult to provide a more targeted solution. Further investigation is needed to determine the root cause of the issue.

Recommendation

Apply workaround: Investigate and potentially modify the test to use consistent semantics for RMSNorm and fused_add_rms_norm to reduce flakiness.
