pytorch - ✅(Solved) Fix [vLLM][inductor][triton] decompose_triton_kernel_wrapper_functional AssertionError under dynamic shapes [1 pull requests, 1 participants]

GitHub stats
pytorch/pytorch#181735 (fetched 2026-04-29 06:11:10)
Comments: 0 · Participants: 1 · Timeline: 175 · Reactions: 0
Timeline (top): mentioned ×71, subscribed ×71, unsubscribed ×22, labeled ×9

Error Message

torch._inductor.exc.InductorError: AssertionError (empty message). It is raised by the node-count assertion in Match.replace_by_example (torch/_inductor/pattern_matcher.py, line 316), reached from the decompose_triton_kernel_wrapper_functional post-grad pass (torch/_inductor/fx_passes/post_grad.py, line 1255). The full traceback appears under Error logs below.

Fix Action

Fixed

PR fix notes

vLLM PR #41135: [Bugfix] fix inductor error for dpsk v4

Description (problem / solution / changelog)

Purpose

Fix https://github.com/vllm-project/vllm/issues/41106

Also see https://github.com/pytorch/pytorch/issues/181735

Test Plan

vllm serve deepseek-ai/DeepSeek-V4-Flash \
  --trust-remote-code \
  --kv-cache-dtype fp8 \
  --block-size 256 \
  --enable-expert-parallel \
  --data-parallel-size 1 --tensor-parallel-size 8 \
  --compilation-config '{"cudagraph_mode":"FULL_AND_PIECEWISE", "custom_ops":["all"]}' \
  --max-num-batched-tokens 8192 \
  --max-model-len auto \
  --max-num-seqs 128 \
  --gpu-memory-utilization 0.95 \
  --reasoning-parser deepseek_v4 --port 7888 --no-enable-flashinfer-autotune

lm_eval --model local-completions \
  --model_args "model=deepseek-ai/DeepSeek-V4-Flash,base_url=http://0.0.0.0:7888/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=256,timeout=5000,max_length=40960" \
  --tasks gsm8k --num_fewshot 5

Test Result

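|Tasks|Version|Filter          |n-shot|Metric     |Value |Stderr  |
|-----|------:|----------------|-----:|-----------|-----:|-------:|
|gsm8k|      3|flexible-extract|     5|exact_match|0.9515|± 0.0059|
|     |       |strict-match    |     5|exact_match|0.9515|± 0.0059|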

aime26: 100



Changed files

  • vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py (modified, +106/-36)


🐛 Describe the bug

torch._inductor.fx_passes.post_grad.decompose_triton_kernel_wrapper_functional calls Match.replace_by_example, which traces the decomp twice (once with the FakeTensors stored in node.meta["eager_input_vals"], once with the FakeTensors in node.meta["val"]) and then asserts the two FX graphs have the same number of nodes:

# torch/_inductor/pattern_matcher.py
# NB: This assertion might not be true in general, but it is true for
# the two use cases we have
# (triton_kernel_wrapper_functional, auto_functionalized)
assert len(graph_with_eager_vals.graph.nodes) == len(
    replacement.graph.nodes
)

This assertion is not universally true. The decomp body (triton_kernel_wrapper_functional_dense) clones every mutated tensor via clone_preserve_strides, which calls

compute_required_storage_length(shape, strides, offset)
# = 1 + offset + sum((dim - 1) * stride for dim, stride in zip(shape, strides))
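The following minimal pure-Python sketch makes the formula concrete. It applies the formula to the repro's layout for num_tokens = 64. In that case aligned_T = 64, so the view has shape (1, 64, 8) and strides (512, 1, 64). The helper names are ours, not PyTorch's, and the zero-size early return is our own addition for empty views.

```python
from math import prod


def required_storage_length(shape, strides, offset=0):
    """Storage elements an as_strided view needs (sketch of the formula above)."""
    if prod(shape) == 0:
        return 0  # an empty view touches no storage
    # 1 + offset + highest reachable element index (assuming non-negative strides)
    return 1 + offset + sum((dim - 1) * stride for dim, stride in zip(shape, strides))


def leading_terms(shape, strides):
    """Per-dimension contributions (dim - 1) * stride, in dimension order."""
    return [(dim - 1) * stride for dim, stride in zip(shape, strides)]


# Layout from the repro with num_tokens = 64.
num_tokens, inner = 64, 8
aligned_T = ((num_tokens + 3) // 4) * 4          # 64
shape = (1, num_tokens, inner)
strides = (inner * aligned_T, 1, aligned_T)      # (512, 1, 64)

print(leading_terms(shape, strides))             # [0, 63, 448]
print(required_storage_length(shape, strides))   # 512, the size passed to torch.empty
```

With a concrete stride, the first term is the integer 0 however large stride[0] is. Under dynamic shapes, stride[0] is a SymInt, so the same term becomes 0 * SymInt. Whether that product is folded away then depends on how each trace obtained the stride.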

For an as_strided view of shape (1, s, K), the leading term is (1 - 1) * sym_stride[0], i.e. 0 * SymInt. The two traces simplify this differently:

  • eager_input_vals trace: aten.sym_stride.int(buf, 0) returns a fresh SymInt node, so 0 * sym_stride is kept as a mul FX node.
  • val trace: the same expression folds to the constant 0, so no mul node is emitted.

Result: 25 vs. 24 nodes, so the assertion fires with an empty AssertionError message.
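The divergence can be modelled with a toy tracer. All names here are hypothetical, and this is not how PyTorch's SymInt tracing is implemented. In the toy, symbolic values are strings naming graph nodes and integers are trace-time constants. One trace may fold 0 * x and the other may not. Tracing the storage-length formula for shape (1, s0, 8) with each tracer yields graphs of different sizes, which is the situation the node-count assertion rejects.

```python
class ToyTracer:
    """Records ops; ints are trace-time constants, strs are symbolic node names."""

    def __init__(self, fold_zero_mul: bool):
        self.fold_zero_mul = fold_zero_mul
        self.nodes: list[str] = []

    def _emit(self, op: str, a, b) -> str:
        name = f"%{len(self.nodes)}"
        self.nodes.append(f"{name} = {op}({a}, {b})")
        return name

    def add(self, a, b):
        if isinstance(a, int) and isinstance(b, int):
            return a + b                      # constant-fold
        if a == 0:
            return b                          # 0 + x -> x in both traces
        if b == 0:
            return a
        return self._emit("add", a, b)

    def mul(self, a, b):
        if isinstance(a, int) and isinstance(b, int):
            return a * b
        if a == 1:
            return b                          # 1 * x -> x in both traces
        if b == 1:
            return a
        if self.fold_zero_mul and (a == 0 or b == 0):
            return 0                          # only the "val"-style trace folds 0 * x
        return self._emit("mul", a, b)


def trace_storage_length(tracer, shape, strides, offset=0):
    # 1 + offset + sum((dim - 1) * stride), one tracer op at a time
    total = tracer.add(1, offset)
    for dim, stride in zip(shape, strides):
        total = tracer.add(total, tracer.mul(tracer.add(dim, -1), stride))
    return total


shape = (1, "s0", 8)               # size-1 leading dim, symbolic token count
strides = ("stride0", 1, "aT")     # stride[0] is an opaque symbolic value

eager_like = ToyTracer(fold_zero_mul=False)
val_like = ToyTracer(fold_zero_mul=True)
trace_storage_length(eager_like, shape, strides)
trace_storage_length(val_like, shape, strides)

print(len(eager_like.nodes), len(val_like.nodes))    # 6 4
print(len(eager_like.nodes) == len(val_like.nodes))  # False: the assertion would fire
```

In the issue the gap is a single node (25 vs. 24). In this toy it is two nodes, because the kept mul node also prevents the following add from folding. Either way, a check that assumes both traces produce graphs of equal size cannot hold once one trace folds a term the other keeps.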

Triggering the bug requires three ingredients:

  1. dynamic shapes on (so stride[0] becomes a SymInt);
  2. an as_strided view with a size-1 leading dim;
  3. a triton kernel that mutates that view, so it ends up in tensors_to_clone and runs through clone_preserve_strides.
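Ingredients 2 and 3 reduce to a layout property: any dimension of size 1 whose stride is symbolic contributes a (1 - 1) * stride term that one trace may fold and the other may keep. The sketch below checks that property. It is a hypothetical helper, not a PyTorch API, and it models symbolic values as non-int placeholders (strings). On a real FakeTensor, one would presumably test for torch.SymInt instead.

```python
def degenerate_symbolic_dims(shape, strides):
    """Indices of size-1 dims whose stride is symbolic (modelled as a non-int)."""
    return [
        d
        for d, (size, stride) in enumerate(zip(shape, strides))
        if size == 1 and not isinstance(stride, int)
    ]


def may_hit_node_count_assert(shape, strides, dynamic: bool, mutated_by_triton: bool) -> bool:
    # Per the issue, all three ingredients are needed; this only flags the risk.
    return dynamic and mutated_by_triton and bool(degenerate_symbolic_dims(shape, strides))


print(degenerate_symbolic_dims((1, "s0", 8), ("8*aT", 1, "aT")))   # [0]
print(degenerate_symbolic_dims((2, "s0", 8), ("8*aT", 1, "aT")))   # []
```

Because the three ingredients are necessary rather than sufficient, a True result marks a layout worth testing under torch.compile. It does not guarantee a failure.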

Here is a minimal reproducer:

import torch
import triton
import triton.language as tl


@triton.jit
def _fill_kernel(
    buf_ptr,
    num_tokens,
    stride_t,
    stride_i,
    INNER: tl.constexpr,
):
    pid = tl.program_id(0)
    if pid >= num_tokens:
        return
    for i in range(INNER):
        tl.store(buf_ptr + pid * stride_t + i * stride_i, 1.0)


def fn(x: torch.Tensor) -> torch.Tensor:
    num_tokens = x.shape[0]
    n_groups = 1            # ← (1) size-1 leading dim
    inner = 8
    aligned_T = ((num_tokens + 3) // 4) * 4   # padded length, SymInt-derived

    # (2) as_strided view; stride[0] = inner * aligned_T is a SymInt expression.
    buf = torch.empty(
        n_groups * inner * aligned_T,
        dtype=torch.float32,
        device=x.device,
    ).as_strided(
        (n_groups, num_tokens, inner),
        (inner * aligned_T, 1, aligned_T),
    )

    # (3) triton kernel mutates `buf` -> goes into tensors_to_clone.
    _fill_kernel[(num_tokens,)](
        buf,
        num_tokens,
        stride_t=buf.stride(1),
        stride_i=buf.stride(2),
        INNER=inner,
    )
    return buf + x.sum()


def main() -> None:
    torch.set_default_device("cuda")
    compiled = torch.compile(fn, dynamic=True, fullgraph=True)
    for n in (64, 128, 256):
        x = torch.randn(n, 8)
        out = compiled(x)
        torch.cuda.synchronize()
        print(f"n={n} ok, sum={out.sum().item():.2f}")


if __name__ == "__main__":
    main()

Error logs

Traceback (most recent call last):
  File "/tmp/repro_inductor_bug.py", line 94, in <module>
    main()
  File "/tmp/repro_inductor_bug.py", line 88, in main
    out = compiled(x)
          ^^^^^^^^^^^
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1038, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1053, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1037, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1798, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1344, in codegen_and_compile
    _recursive_post_grad_passes(gm, is_inference=is_inference)
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 583, in _recursive_post_grad_passes
    post_grad_passes(gm, is_inference)
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py", line 357, in post_grad_passes
    ).apply_graph_pass(decompose_triton_kernel_wrapper_functional)
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/fx/passes/graph_transform_observer.py", line 103, in apply_graph_pass
    return pass_fn(self.gm.graph)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py", line 1255, in decompose_triton_kernel_wrapper_functional
    graph_pass.apply(graph)
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/pattern_matcher.py", line 2063, in apply
    entry.apply(m, graph, node)
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/pattern_matcher.py", line 1132, in apply
    self.handler(match, *match.args, **match.kwargs)
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/fx_passes/post_grad.py", line 1253, in _
    match.replace_by_example(decomp, flat_args, run_functional_passes=False)
  File "/home/zjy/vllm/.venv/lib/python3.12/site-packages/torch/_inductor/pattern_matcher.py", line 316, in replace_by_example
    assert len(graph_with_eager_vals.graph.nodes) == len(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._inductor.exc.InductorError: AssertionError: 

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Related vLLM issue: https://github.com/vllm-project/vllm/issues/41106

PyTorch version: 2.11.0+cu130

cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @bobrenjc93 @aditvenk @laithsakka @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @muchulee8 @amjames @aakhundov @coconutruben @jataylo @oulgen @davidberard98 @zou3519

extent analysis

TL;DR

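The assertion that fires lives in Match.replace_by_example (torch/_inductor/pattern_matcher.py), which decompose_triton_kernel_wrapper_functional calls. A PyTorch-side fix would need that check to tolerate graphs whose node counts differ only because one trace folded a 0 * SymInt term. On the vLLM side, PR #41135 resolved the reported failure with a change confined to vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py.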

Guidance

  • Identify the source of the node-count discrepancy between the two FX graphs: a 0 * SymInt term that is folded in the val trace but kept as a mul node in the eager_input_vals trace.
  • Change Match.replace_by_example (reached from decompose_triton_kernel_wrapper_functional) to account for this discrepancy, for example by relaxing the assertion or special-casing folded SymInt expressions.
  • Verify that the modified code behaves correctly when the two graphs have different node counts, including any step that pairs nodes of the two graphs by position.
  • Run the minimal reproducer from the issue to confirm that the error is gone.

Example

# Hypothetical sketch, not PyTorch's actual code. The failing assertion is in
# Match.replace_by_example (torch/_inductor/pattern_matcher.py), which
# decompose_triton_kernel_wrapper_functional calls; names below are illustrative.
import logging

log = logging.getLogger(__name__)


def copy_eager_input_vals(graph_with_eager_vals, replacement):
    old_nodes = list(graph_with_eager_vals.graph.nodes)
    new_nodes = list(replacement.graph.nodes)
    if len(old_nodes) != len(new_nodes):
        # e.g. 0 * SymInt folded in only one trace: pairing nodes by position
        # would attach metadata to the wrong nodes, so skip the copy and warn.
        log.warning("traces differ: %d vs %d nodes", len(old_nodes), len(new_nodes))
        return replacement
    for old, new in zip(old_nodes, new_nodes):
        if "eager_input_vals" in old.meta:
            new.meta["eager_input_vals"] = old.meta["eager_input_vals"]
    return replacement

Notes

The provided solution is a potential workaround, and the root cause of the issue may require further investigation. The modification to the decompose_triton_kernel_wrapper_functional function may have unintended consequences, and thorough testing is necessary to ensure that it resolves the issue without introducing new problems.

Recommendation

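Relax the node-count check in Match.replace_by_example (reached from decompose_triton_kernel_wrapper_functional) only as a thoroughly tested workaround. Otherwise, fix the calling code, which is where vLLM PR #41135 made its change.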
