vllm - 💡(How to fix) Fix [Bug]: Mamba prefix caching + MTP speculative decoding crashes on startup for NemotronH models [4 comments, 3 participants]

GitHub stats
vllm-project/vllm#39809 · Fetched 2026-04-16 06:36:32
Comments: 4 · Participants: 3 · Timeline events: 10 · Reactions: 0
Timeline (top): commented ×4, labeled ×2, closed ×1, mentioned ×1

Error Message

RuntimeError: The size of tensor a (6) must match the size of tensor b (7) at non-singleton dimension 1

Root Cause

Root cause: The pre-allocated cudagraph buffer for state_indices_tensor_d uses:

max_num_blocks = cdiv(max_model_len, block_size)

But the block table allocated by gpu_model_runner.py ~L6470 uses:

max_num_blocks_per_req = cdiv(max_model_len, block_size) + kv_cache_spec.num_speculative_blocks

Fix Action

Workaround

Disable MTP when prefix caching is enabled. For orchestrator workloads with 30k+ context, prefix caching eliminates full re-prefill every turn — a far larger latency win than MTP's ~10-15% decode throughput improvement.


Your current environment

  • vLLM version: 0.19.0 (official Docker image vllm/vllm-openai:v0.19.0)
  • GPU: NVIDIA B200 (178 GB VRAM), tested TP=1 through TP=8
  • Model: nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 (NemotronH hybrid Mamba2-Transformer MoE)
  • Python: 3.12

Model/config

vllm serve nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 \
  --enable-prefix-caching \
  --kv-cache-dtype fp8 \
  --dtype bfloat16 \
  --trust-remote-code \
  --gpu-memory-utilization 0.90 \
  --max-model-len 21525 \
  --tensor-parallel-size 4 \
  --speculative-config '{"method": "mtp", "num_speculative_tokens": 1}'

Key flags: --enable-prefix-caching + --speculative-config with MTP on a hybrid model that declares SupportsMambaPrefixCaching.

🐛 Describe the bug

Enabling both prefix caching (mamba_cache_mode="all") and MTP speculative decoding on NemotronH (hybrid Mamba2 architecture) causes three cascading bugs during cudagraph profiling at startup. Each bug manifests only after patching the previous one.

Without MTP, prefix caching works correctly. All Nano-30B configs (no MTP) pass with 120/120 valid results. The bugs are specific to the combination of SupportsMambaPrefixCaching + num_speculative_tokens > 0.

Bug 1: state_indices_tensor_d column count mismatch

File: vllm/v1/attention/backends/mamba_attn.py, __init__ ~L110

Error:

RuntimeError: The size of tensor a (6) must match the size of tensor b (7) at non-singleton dimension 1

Root cause: The pre-allocated cudagraph buffer for state_indices_tensor_d uses:

max_num_blocks = cdiv(max_model_len, block_size)

But the block table allocated by gpu_model_runner.py ~L6470 uses:

max_num_blocks_per_req = cdiv(max_model_len, block_size) + kv_cache_spec.num_speculative_blocks

The + num_speculative_blocks accounts for MTP draft tokens. The pre-allocated buffer is missing this term, so the runtime block table has 1 extra column.

Fix:

max_num_blocks = cdiv(
    self.vllm_config.model_config.max_model_len,
    self.kv_cache_spec.block_size,
) + self.kv_cache_spec.num_speculative_blocks

Note: The comment at L114 says "Speculative decoding not supported with prefix caching" — this bug confirms the combination was never tested.
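The column mismatch in Bug 1 can be checked with plain arithmetic. The sketch below assumes a block size of 4096 and one speculative block. Both values are illustrative assumptions chosen to reproduce the 6-versus-7 column counts in the error message; they are not read from the reporter's configuration. `cdiv` is reimplemented locally as ceiling division.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, written locally for this sketch."""
    return -(-a // b)


max_model_len = 21525        # from the serve command in this report
block_size = 4096            # assumption: illustrative value only
num_speculative_blocks = 1   # assumption: one extra block for MTP draft tokens

# Column count of the pre-allocated cudagraph buffer (speculative term missing).
buffer_cols = cdiv(max_model_len, block_size)

# Column count of the block table allocated by the model runner.
block_table_cols = cdiv(max_model_len, block_size) + num_speculative_blocks

# Column count of the buffer once the proposed fix is applied.
fixed_buffer_cols = cdiv(max_model_len, block_size) + num_speculative_blocks

print(buffer_cols, block_table_cols, fixed_buffer_cols)
```

Under these assumptions the unpatched buffer is one column narrower than the block table, and the patched expression gives matching widths.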

Bug 2: block_idx tensors sliced by token count instead of request count

File: vllm/v1/attention/backends/mamba_attn.py, _update_metadata_for_cudagraph_capture ~L527-536

Error (only visible after fixing Bug 1):

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 512 (input tensor's size at dimension 0), but got split_sizes=[256, 0]

Root cause: block_idx_last_scheduled_token and block_idx_last_computed_token are sliced by metadata.num_decode_tokens:

block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
    : metadata.num_decode_tokens
]

These are per-request tensors (one block index per decode request), but num_decode_tokens is a per-token count. With MTP (num_speculative_tokens=1), num_decode_tokens = num_decodes * 2, making the output tensor 2x too large.

The consumer in mamba_mixer2.py ~L626 splits by [num_decodes, num_prefills] (per-request), which doesn't match.

Why this only manifests on large GPUs: On GPUs with ≥70 GB VRAM, max_num_seqs defaults to 1024 (vs 256 on smaller GPUs), making decode_cudagraph_max_bs = min(1024, 512) = 512. The pre-allocated buffer is 512 elements, and [:num_decode_tokens] returns all 512. On smaller GPUs, decode_cudagraph_max_bs = 256, and PyTorch clamps [:512] to [:256], masking the bug.
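The masking effect described above can be imitated without a GPU. This is a minimal sketch using Python lists in place of tensors: `split_with_sizes` and `capture_slice` are hypothetical stand-ins written for this example, not vLLM or PyTorch functions, and the batch sizes are taken from the numbers quoted in this report.

```python
def split_with_sizes(values: list, split_sizes: list[int]) -> list[list]:
    """Pure-Python stand-in that enforces the same size check as the error above."""
    if sum(split_sizes) != len(values):
        raise RuntimeError(
            f"split_sizes must sum to {len(values)}, got {split_sizes}"
        )
    out, start = [], 0
    for size in split_sizes:
        out.append(values[start:start + size])
        start += size
    return out


def capture_slice(max_bs: int, num_decodes: int, num_speculative_tokens: int) -> list:
    """Mimic the buggy slice: a per-request buffer cut by a per-token count."""
    buffer = [0] * max_bs  # one entry per request
    num_decode_tokens = num_decodes * (1 + num_speculative_tokens)
    return buffer[:num_decode_tokens]  # slicing clamps to len(buffer)


num_decodes, num_prefills = 256, 0

# Smaller GPU: the buffer holds 256 entries, so [:512] silently clamps to 256.
small = capture_slice(max_bs=256, num_decodes=num_decodes, num_speculative_tokens=1)
parts = split_with_sizes(small, [num_decodes, num_prefills])  # succeeds

# Larger GPU: the buffer holds 512 entries, so [:512] really returns 512.
large = capture_slice(max_bs=512, num_decodes=num_decodes, num_speculative_tokens=1)
try:
    split_with_sizes(large, [num_decodes, num_prefills])
    large_failed = False
except RuntimeError:
    large_failed = True
```

The same over-long slice is requested in both cases; only the buffer length decides whether the consumer's per-request split still adds up.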

Fix: Slice by padded_bs (= metadata.num_reqs) and zero-fill padding rows:

block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[:padded_bs]
block_idx_last_scheduled_token[metadata.num_decodes:] = 0

This matches the pattern used for state_indices_tensor_d at L506-507 and num_accepted_tokens at L515-516 in the same function.
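The slice-then-zero-fill pattern can be sketched on a plain list. `slice_block_idx` is a hypothetical helper for illustration, and the sizes (512-entry buffer, `padded_bs=256`, `num_decodes=200`) are assumed values. In the real code the slice is a tensor view, so the zero-fill writes into the pre-allocated buffer; a list slice is a copy, which is why this version returns the result instead.

```python
def slice_block_idx(buffer: list[int], padded_bs: int, num_decodes: int) -> list[int]:
    """Take one entry per (padded) request and zero the padding entries."""
    out = buffer[:padded_bs]
    for i in range(num_decodes, len(out)):
        out[i] = 0  # padding rows must not carry stale block indices
    return out


buffer = list(range(1, 513))  # 512 stale, non-zero entries (assumed contents)
fixed = slice_block_idx(buffer, padded_bs=256, num_decodes=200)

print(len(fixed), fixed[199], fixed[200])
```

The result has one entry per padded request, so a per-request split such as `[num_decodes, num_prefills]` can sum to its length.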

Bug 3: Triton kernel illegal memory access in selective_state_update

File: vllm/model_executor/layers/mamba/ops/mamba_ssm.py ~L423, called from mamba_mixer2.py ~L874

Error (only visible after fixing Bugs 1 and 2):

RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

Stack trace:

mamba_mixer2.py:926 mamba_mixer2
mamba_mixer2.py:874 conv_ssm_forward
mamba_ssm.py:423    selective_state_update
triton/runtime/jit.py:743 run

Root cause: This appears to be a kernel-level incompatibility between the mamba prefix caching block indices and the speculative decode state management in the Triton SSM kernel. The Python-level tensor shapes are now correct (after Bugs 1 and 2 are fixed), but the kernel itself does not correctly handle the combination of prefix caching gather indices + MTP multi-token state updates.

We have not investigated the Triton kernel internals. This may require changes to the selective_state_update kernel to be aware of speculative tokens when prefix caching block indices are used.

Observed behavior matrix

| Model | MTP | Prefix Caching | GPU VRAM | Result |
|---|---|---|---|---|
| Nano-30B (no MTP) | No | Yes | 178 GB (B200) | Works, 120/120 configs |
| Super-120B | Yes (1 token) | No | 178 GB (B200) | Works (v0.18.0 baseline) |
| Super-120B | Yes (1 token) | Yes | 178 GB (B200) | Bug 1, Bug 2, Bug 3 (cascading) |
| Super-120B | No | Yes | 178 GB (B200) | Works (MTP disabled workaround) |
| Nano-30B (no MTP) | No | Yes | 48 GB (Ada) | Works |

Workaround

Disable MTP when prefix caching is enabled. For orchestrator workloads with 30k+ context, prefix caching eliminates full re-prefill every turn — a far larger latency win than MTP's ~10-15% decode throughput improvement.
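For deployments that assemble the `vllm serve` arguments programmatically, the workaround could be enforced with a small launcher-side guard. This is only a sketch: `drop_mtp_if_prefix_caching` is a hypothetical helper, not part of vLLM, and it removes any `--speculative-config` value rather than inspecting whether the method is MTP.

```python
def drop_mtp_if_prefix_caching(args: list[str]) -> list[str]:
    """Remove --speculative-config (and its value) when prefix caching is on."""
    if "--enable-prefix-caching" not in args:
        return list(args)
    cleaned, skip_next = [], False
    for arg in args:
        if skip_next:
            skip_next = False  # this was the JSON value of the dropped flag
            continue
        if arg == "--speculative-config":
            skip_next = True
            continue
        if arg.startswith("--speculative-config="):
            continue
        cleaned.append(arg)
    return cleaned


# Flags copied from the serve command in this report (subset).
serve_args = [
    "--enable-prefix-caching",
    "--max-model-len", "21525",
    "--tensor-parallel-size", "4",
    "--speculative-config", '{"method": "mtp", "num_speculative_tokens": 1}',
]
safe_args = drop_mtp_if_prefix_caching(serve_args)
print(safe_args)
```

Argument lists without `--enable-prefix-caching` pass through unchanged, so MTP stays available for configurations that do not hit this bug.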

Additional context

  • NemotronH declares SupportsMambaPrefixCaching in nemotron_h.py ~L792
  • The mamba_cache_mode is set to "all" automatically when prefix caching is enabled for models with this interface
  • The comment at mamba_attn.py L114 states "Speculative decoding not supported with prefix caching" — consistent with these findings
  • Bugs 1 and 2 are straightforward Python fixes (included above). Bug 3 requires Triton kernel investigation.

Patches for Bugs 1 and 2

<details> <summary>Python patch script (apply during Docker build)</summary>
"""Patch vLLM 0.19.0 mamba_attn.py for prefix caching + MTP compatibility."""

PATH = "/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mamba_attn.py"

with open(PATH) as f:
    code = f.read()

# Bug 1: Add num_speculative_blocks to max_num_blocks
old1 = """\
            max_num_blocks = cdiv(
                self.vllm_config.model_config.max_model_len,
                self.kv_cache_spec.block_size,
            )"""
new1 = old1 + " + self.kv_cache_spec.num_speculative_blocks"
if new1 not in code:  # already patched: skip so a second run does not append twice
    assert old1 in code, "Bug 1 patch target not found"
    code = code.replace(old1, new1)

# Bug 2: Fix block_idx slicing from num_decode_tokens to padded_bs
for var in ["block_idx_last_scheduled_token", "block_idx_last_computed_token"]:
    old = f"                {var} = self.{var}[\n                    : metadata.num_decode_tokens\n                ]"
    new = f"                {var} = self.{var}[\n                    :padded_bs\n                ]\n                {var}[metadata.num_decodes:] = 0"
    if new in code:  # already patched
        continue
    assert old in code, f"Bug 2 patch target not found for {var}"
    code = code.replace(old, new)

with open(PATH, "w") as f:
    f.write(code)

print("Patched mamba_attn.py")
</details>
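One way to gain confidence before rewriting the installed file is to run the same kind of string replacement on an in-memory sample first. In the sketch below, `replace_once` is a hypothetical helper and `sample` is copied from the Bug 1 target text shown in the script; nothing here touches the vLLM installation. The helper also skips text that already contains the replacement, so a second run is a no-op.

```python
def replace_once(code: str, old: str, new: str, label: str) -> str:
    """Replace `old` with `new`, skipping text that is already patched."""
    if new in code:
        return code
    if old not in code:
        raise ValueError(f"{label} patch target not found")
    return code.replace(old, new)


sample = (
    "            max_num_blocks = cdiv(\n"
    "                self.vllm_config.model_config.max_model_len,\n"
    "                self.kv_cache_spec.block_size,\n"
    "            )\n"
)
old1 = sample.rstrip("\n")
new1 = old1 + " + self.kv_cache_spec.num_speculative_blocks"

once = replace_once(sample, old1, new1, "Bug 1")
twice = replace_once(once, old1, new1, "Bug 1")  # second run changes nothing

try:
    replace_once("unrelated source text", old1, new1, "Bug 1")
    missing_target_detected = False
except ValueError:
    missing_target_detected = True
```

Failing loudly when the target text is absent matters here because the patch is tied to the exact formatting of one vLLM release.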

Related issues

| Issue | Relationship |
|---|---|
| #26201: Tracking: Prefix Caching for Hybrid Models | Parent tracking issue. Explicitly lists "enabling compatibility between mamba prefix caching and speculative decoding" as a TODO. Confirms this combination is known-incomplete. |
| #30114: Speculative decoding support for mamba models | Older V0 engine report that mamba + spec decode raises NotImplementedError. V1 engine partially supports it but with the bugs described here. |
| #35288: MTP produces corrupted output at concurrency >= 4 (V1) | Same V1 spec decode code path, same CUDA illegal memory access symptom. Affects Qwen3 MoE (not mamba), but likely a related class of state management bugs in the speculative decode path. |
| #39273: Ngram spec decode corrupted output on hybrid GDN models | SSM state rollback issue with speculative decoding on hybrid models. Likely relevant to Bug 3: the Triton kernel may not handle state rollback when speculative tokens are rejected with prefix caching block indices. |
| #38182: MTP reduces prefix cache hit rate on Qwen3.5-35B | MTP + prefix caching interaction on a different hybrid model (Qwen3.5-35B-A3B). Different symptom (degraded hit rate vs crash) but same feature combination. |
| #35031: MTP with NVFP4: Weight Shape Mismatch | MTP + NVFP4 quantization weight loading issue. Same model family (NVFP4) and spec decode method (MTP). |


extent analysis

TL;DR

The provided patches fix Bugs 1 and 2, but Bug 3 still needs investigation. Until it is resolved, disabling MTP when prefix caching is enabled is the only complete workaround.

Guidance

  • Apply the patches for Bugs 1 and 2 using the provided Python patch script. These fix the two shape errors but do not by themselves make prefix caching and MTP work together, because Bug 3 still crashes at startup.
  • Disable MTP when prefix caching is enabled as a temporary workaround to avoid the bugs.
  • Investigate the Triton kernel internals to resolve Bug 3, which appears to be a kernel-level incompatibility between mamba prefix caching block indices and speculative decode state management.
  • Review related issues, such as #39273, to understand potential similarities in state management bugs with speculative decoding on hybrid models.

Example

The patch for Bug 1 involves modifying the max_num_blocks calculation to include num_speculative_blocks:

max_num_blocks = cdiv(
    self.vllm_config.model_config.max_model_len,
    self.kv_cache_spec.block_size,
) + self.kv_cache_spec.num_speculative_blocks

Similarly, the patch for Bug 2 fixes the slicing of block_idx tensors:

block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[:padded_bs]
block_idx_last_scheduled_token[metadata.num_decodes:] = 0

Notes

The provided patches only address Bugs 1 and 2, and further investigation is required to resolve Bug 3. Disabling MTP when prefix caching is enabled may be a viable workaround, but it may impact performance.

Recommendation

Apply the workaround of disabling MTP when prefix caching is enabled, as it is a simpler and more immediate solution, while the investigation into Bug 3 and the Triton kernel internals continues.
