vllm - ✅(Solved) Fix [Bug]: `ROCM_AITER_MLA_SPARSE` prefill produces garbage for prompt_len > ~20K tokens on gfx950 (GLM-5.1-FP8) [1 pull requests, 3 comments, 3 participants]

GitHub stats
vllm-project/vllm#40018 · Fetched 2026-04-17 08:27:36
Comments: 3 · Participants: 3 · Timeline events: 15 · Reactions: 0
Timeline (top): commented ×3, mentioned ×3, subscribed ×3, labeled ×2

The investigation and this report were generated with the help of Claude Code.

The ROCM_AITER_MLA_SPARSE attention backend — the only backend available for glm_moe_dsa sparse-MLA models on ROCm — silently produces corrupt output when the prefill prompt exceeds approximately 20,000 tokens. This is a distinct bug from the decode-path regression reported in #39303 (which was fixed by PR #39509 reverting aiter to v0.1.10.post3). The new bug persists on the fixed nightly and is in the prefill path, not the decode path.

| | #39303 (decode, FIXED) | This issue (prefill, NEW) |
|---|---|---|
| Path | Decode (paged attention) | Prefill (indexer gather + logits) |
| Kernel | deepgemm_fp8_paged_mqa_logits_stage1 | rocm_fp8_mqa_logits via rocm_aiter_mla_sparse.py |
| Trigger | Running decode context > 2048 tokens | Prompt (prefill) length > ~20K tokens |
| aiter version | v0.1.12 only (fixed in v0.1.10.post3) | v0.1.10.post3 (still broken) |
| Symptom | Token salad after ~2K generated tokens | Token salad from first generated token |
| Root cause | aiter kernel regression in v0.1.12 | Missing skip_kv_gather + workspace reuse in ROCm path |

PR fix notes

PR #40049: fix(vllm): port skip_kv_gather check and workspace reservation to ROCm MLA sparse attention

Description (problem / solution / changelog)

Summary

Fixes ROCM_AITER_MLA_SPARSE prefill producing garbage for prompt_len > ~20K tokens on gfx950 (MI355X) by porting two missing patterns from the CUDA reference implementation in sparse_attn_indexer.py.

Root Cause

Two bugs in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py:

  1. Missing skip_kv_gather check: The CUDA path skips the KV gather for sub-chunks marked skip_kv_gather=True (chunks after the first that re-use prior KV data). The ROCm path always re-gathered into fresh torch.empty buffers, causing memory competition during CUDA graph replay.

  2. Missing workspace reservation during profiling: The rocm_aiter_sparse_attn_indexer_fake path allocated buffers with torch.empty directly instead of reserving workspace via current_workspace_manager().get_simultaneous(), leading to incorrect memory accounting during profiling.

Changes

rocm_aiter_sparse_attn_indexer_fake

  • Reserve workspace buffers via current_workspace_manager().get_simultaneous() during profiling, matching the CUDA sparse_attn_indexer.py pattern (lines 55-60)

rocm_aiter_sparse_attn_indexer prefill path

  • Get shared persistent workspace buffers once via workspace_manager.get_simultaneous() and slice per-chunk (avoid per-iteration torch.empty allocation)
  • Add if not chunk.skip_kv_gather: guard around ops.cp_gather_indexer_k_quant_cache call, matching sparse_attn_indexer.py lines 113-130

Testing

Cannot test on AMD MI355X hardware directly. The fix is a direct port of the confirmed-correct CUDA pattern. The issue author noted their patch "helps but doesn't fully fix it" — this port of the complete CUDA pattern should close the gap.

Changed files

  • vllm/v1/attention/ops/rocm_aiter_mla_sparse.py (modified, +32/-24)

Raw issue body

Environment

  • Docker image: docker.io/rocm/vllm-dev:nightly (image id 38936227491d)
  • vLLM version: 0.19.1rc1.dev296+gbcc2306ce
  • aiter version: 0.1.10.post3 (the post-#39509 fix)
  • Model: zai-org/GLM-5.1-FP8 (754B MoE, FP8 E4M3, glm_moe_dsa architecture)
  • Hardware: 4× AMD Instinct MI355X OAM (gfx950:sramecc+:xnack-), 288 GiB HBM3e each
  • TP: 4 (required — aiter ASM MLA kernels only exist for gqa=16, see #39303)
  • Node: mi355-gpu-37
  • Date: 2026-04-16

Patches applied

Two patches are required on this nightly for the model to start at all:

  1. pa_mqa_logits.py ZeroDivisionError clamp: GLM at TP=4 has heads_per_rank=16 < ChunkQ=64, causing heads // ChunkQ == 0 → TileQCount == 0 → ZeroDivisionError during CUDA graph capture. Fix: if heads < ChunkQ: ChunkQ = heads; TileQCount = max(1, ...). This is needed on every nightly tested and is orthogonal to this bug.

  2. rocm_aiter_mla_sparse.py workspace + skip_kv_gather patch — see Root Cause section below. This patch improves the threshold from ~20K to ~25K tokens but does not fully fix the issue.
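
The clamp in the first patch can be sketched in isolation. This is a minimal illustration, not the actual pa_mqa_logits.py code: the function name clamp_chunk_q is hypothetical, and deriving the tile count as heads // chunk_q is an assumption based on the failure chain described in item 1.

```python
def clamp_chunk_q(heads: int, chunk_q: int = 64) -> tuple[int, int]:
    """Return (chunk_q, tile_q_count), both guaranteed to be >= 1.

    Hypothetical helper; the real kernel launcher may derive these differently.
    """
    if heads < 1:
        raise ValueError("heads must be positive")
    if heads < chunk_q:
        # Fewer heads per rank than the default chunk size: shrink the chunk.
        chunk_q = heads
    # Assumed derivation of the tile count; max() guards against zero.
    tile_q_count = max(1, heads // chunk_q)
    return chunk_q, tile_q_count


# GLM at TP=4: 16 heads per rank against a default ChunkQ of 64.
print(clamp_chunk_q(16))
print(clamp_chunk_q(128))
```

Without the clamp, 16 // 64 evaluates to 0, which is the zero tile count that later triggers the division error.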

Reproduction

Minimal reproducer

"""
Reproducer: ROCM_AITER_MLA_SPARSE prefill corruption at large prompt sizes.
Run against a vLLM server serving GLM-5.1-FP8 on MI355X (gfx950).
"""
import requests, hashlib

VLLM_URL = "http://localhost:8005/v1/chat/completions"

def make_diverse_prompt(n_chars):
    """Generate non-repetitive text to avoid tokenizer compression."""
    parts = []
    for i in range(n_chars // 100 + 1):
        h = hashlib.md5(str(i).encode()).hexdigest()
        parts.append(
            f"Function process_{h[:8]}(data_{i}: List[Dict], "
            f"config_{h[8:16]}: Optional[str] = None) -> Tuple[int, str]: "
            f"Validates input parameter {h[16:24]} against schema version "
            f"{i % 50}.{i % 10}, applies transformation rule "
            f"#{i*7 % 1000} from module {h[24:32]}, then returns "
            f"status code {200 + i % 5} with message. "
        )
    return " ".join(parts)[:n_chars]

for target_chars in [20000, 40000, 60000, 80000, 100000]:
    prompt = make_diverse_prompt(target_chars)
    resp = requests.post(VLLM_URL, json={
        "model": "glm-5.1-fp8",
        "messages": [
            {"role": "system", "content": prompt},
            {"role": "user", "content": "What is 2+2? Answer in one sentence."},
        ],
        "max_tokens": 1000,
        "temperature": 0.7,
    }, timeout=300)
    r = resp.json()
    content = r["choices"][0]["message"].get("content") or ""
    reasoning = r["choices"][0]["message"].get("reasoning") or ""
    text = reasoning + content
    words = text.split()
    unique = len(set(w.lower() for w in words)) if words else 0
    ratio = unique / len(words) if words else 0
    pt = r["usage"]["prompt_tokens"]
    ct = r["usage"]["completion_tokens"]
    status = "OK" if ratio > 0.15 else "GARBAGE"
    print(f"prompt={pt:6d}  comp={ct:4d}  unique_ratio={ratio:.3f}  {status}")
    print(f"  first 80 chars: {text[:80]}")
    print()

Results

Without rocm_aiter_mla_sparse.py patch (upstream nightly as-is + pa_mqa_logits clamp only):

| Prompt tokens | Completion | Unique word ratio | Status |
|---|---|---|---|
| 5,861 | 131 | 0.625 | OK |
| 11,697 | 112 | 0.723 | OK |
| 17,523 | 1,000 | 0.263 | OK |
| 23,398 | 1,000 | 0.012 | GARBAGE |

Garbage output at 23K: AFC} 8^PU^P}^U^P>^P> ^P> ^P> ^P> ^P> ^P> ^P> ...

With rocm_aiter_mla_sparse.py workspace + skip_kv_gather patch:

| Prompt tokens | Completion | Unique word ratio | Status |
|---|---|---|---|
| 5,861 | 131 | 0.625 | OK |
| 11,697 | 112 | 0.723 | OK |
| 17,523 | 1,000 | 0.397 | OK |
| 23,398 | 1,000 | 0.383 | OK |
| 29,284 | 1,000 | 0.012 | GARBAGE |

Threshold improved from ~20K to ~25K, but still fails for prompts > ~25K tokens.

Control: small prompt + long generation (same server, same patches):

| Prompt tokens | Completion | Total context | Unique word ratio | Status |
|---|---|---|---|---|
| 42 | 5,617 | 5,659 | 0.422+ | OK throughout |

6,000 tokens of coherent computing history essay. Sliding-window unique-ratio never drops below 0.74 across the entire output. The decode path is fine — the corruption is entirely in the prefill path.
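
The sliding-window check mentioned above can be approximated as follows. This is a sketch under assumptions: the issue does not state the window size or the exact implementation, so the 50-word window and the function names here are hypothetical.

```python
def unique_ratio(words):
    """Fraction of distinct (case-insensitive) words; 0.0 for empty input."""
    return len({w.lower() for w in words}) / len(words) if words else 0.0


def min_sliding_unique_ratio(text, window=50):
    """Lowest unique-word ratio over any `window`-word slice of `text`."""
    words = text.split()
    if len(words) <= window:
        return unique_ratio(words)
    return min(
        unique_ratio(words[i : i + window])
        for i in range(len(words) - window + 1)
    )


# Coherent text stays high; a tail of repeated tokens drags the minimum down.
coherent = " ".join(f"word{i}" for i in range(200))
degenerate = coherent + " " + " ".join(["^P>"] * 200)
print(min_sliding_unique_ratio(coherent))
print(min_sliding_unique_ratio(degenerate))
```

A whole-output ratio can hide a late collapse behind a long coherent prefix, whereas the windowed minimum reports the worst stretch.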

What was ruled out

| Hypothesis | Test | Result |
|---|---|---|
| Chunked prefill boundary issue | --max-num-batched-tokens 32768 with --max-model-len 32768 (single chunk) | Still GARBAGE at 23K prompt tokens |
| Sub-chunk logits buffer overflow | VLLM_SPARSE_INDEXER_MAX_LOGITS_MB=8192 (eliminates sub-chunking) | Made it WORSE: GARBAGE at 23K (vs 25K with default 512 MB) |
| Decode kernel regression (#39303) | Confirmed aiter v0.1.10.post3 (post-PR #39509); small-prompt decode works to 5.6K+ tokens | Decode is fine; this is prefill |
| TRITON_MLA fallback | Forced use_sparse=False to select TRITON_MLA | ValueError: No valid attention backend... sparse not supported (no fallback exists) |

Root cause analysis

The upstream ROCm path vs CUDA reference

The file vllm/v1/attention/ops/rocm_aiter_mla_sparse.py is the ROCm-specific implementation of the sparse-MLA indexer. The CUDA reference lives in vllm/model_executor/layers/sparse_attn_indexer.py. Comparing the two reveals two bugs in the ROCm path, both related to the prefill sub-chunk handling:

Bug 1: Missing skip_kv_gather check (causes garbage on sub-chunk reuse)

The CUDA path (lines 113-130 of sparse_attn_indexer.py):

workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
    ((total_seq_lens, head_dim), fp8_dtype),
    ((total_seq_lens, 4), torch.uint8),
)
for chunk in prefill_metadata.chunk_specs:
    k_fp8 = k_fp8_full[: chunk.total_seq_lens]
    k_scale = k_scale_full[: chunk.total_seq_lens]
    if not chunk.skip_kv_gather:          # <-- KEY: skip on sub-chunks
        ops.cp_gather_indexer_k_quant_cache(...)

The upstream ROCm path (lines ~553-570):

for chunk in prefill_metadata.chunk_specs:
    k_fp8 = torch.empty(                  # <-- BUG: fresh buffer each iteration
        [chunk.total_seq_lens, head_dim],
        device=k.device, dtype=fp8_dtype,
    )
    k_scale = torch.empty(...)            # <-- BUG: fresh buffer, uninitialized
    ops.cp_gather_indexer_k_quant_cache(   # <-- BUG: always gathers (no skip check)
        kv_cache, k_fp8, k_scale,
        chunk.block_table, chunk.cu_seq_lens,
    )

When split_indexer_prefill_chunks produces sub-chunks with skip_kv_gather=True (because the KV data was already gathered by a prior sub-chunk), the ROCm path:

  1. Allocates a fresh torch.empty buffer (uninitialized memory)
  2. Always re-gathers instead of skipping
  3. The fresh buffer does NOT contain the previously gathered data

For the first sub-chunk this accidentally works (it gathers fresh). For subsequent sub-chunks marked skip_kv_gather=True, the CUDA path reuses the shared workspace (which still holds valid data from the first sub-chunk), but the ROCm path allocates garbage-filled memory and re-gathers into a differently-sized buffer.
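
A toy model, using plain Python lists instead of tensors, can illustrate why the skip guard and the persistent workspace belong together. It does not reproduce the ROCm kernel path. Chunk, gather, and UNINIT are hypothetical stand-ins, with None playing the role of uninitialized torch.empty memory.

```python
from dataclasses import dataclass

UNINIT = None  # stands in for whatever an uninitialized buffer contains


@dataclass
class Chunk:
    total_seq_lens: int
    skip_kv_gather: bool


def gather(kv_cache, buf, n):
    """Toy stand-in for the KV gather: copy n cache rows into buf in place."""
    buf[:n] = kv_cache[:n]


def shared_workspace_with_skip(kv_cache, chunks):
    """Reference pattern: one persistent buffer, gather only when required."""
    workspace = [UNINIT] * max(c.total_seq_lens for c in chunks)
    seen = []
    for c in chunks:
        if not c.skip_kv_gather:
            gather(kv_cache, workspace, c.total_seq_lens)
        seen.append(workspace[: c.total_seq_lens])
    return seen


def fresh_buffer_with_skip(kv_cache, chunks):
    """Skip guard on per-iteration buffers: skipped sub-chunks read garbage."""
    seen = []
    for c in chunks:
        buf = [UNINIT] * c.total_seq_lens
        if not c.skip_kv_gather:
            gather(kv_cache, buf, c.total_seq_lens)
        seen.append(buf)
    return seen


kv_cache = list(range(8))
chunks = [Chunk(8, False), Chunk(8, True)]
good = shared_workspace_with_skip(kv_cache, chunks)
bad = fresh_buffer_with_skip(kv_cache, chunks)
print(good[1] == kv_cache)  # second sub-chunk sees the first one's data
print(bad[1] == kv_cache)   # second sub-chunk sees only uninitialized slots
```

In this model, honouring skip_kv_gather is only safe when every sub-chunk slices the same long-lived buffer.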

Bug 2: Missing workspace reservation during profiling

The CUDA path reserves workspace during the profiling run:

if not isinstance(attn_metadata, dict):
    current_workspace_manager().get_simultaneous(
        ((total_seq_lens, head_dim), torch.float8_e4m3fn),
        ((total_seq_lens, 4), torch.uint8),
    )

The upstream ROCm path has no equivalent reservation, so the workspace manager's peak-memory accounting during profiling does not include the indexer buffers. At runtime, when the workspace allocates for real, it may compete with the KV cache for the same memory, potentially causing silent corruption at large sizes.
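
The accounting gap can be illustrated with a toy peak tracker. This is a sketch, not vLLM's workspace manager: ToyWorkspaceManager and profile are hypothetical, the toy takes bytes-per-element where the real get_simultaneous takes a dtype, and head_dim=128 is an assumed value chosen only for the arithmetic.

```python
from math import prod


class ToyWorkspaceManager:
    """Tracks the largest simultaneous request seen so far, in bytes."""

    def __init__(self):
        self.peak_bytes = 0

    def get_simultaneous(self, *specs):
        """Each spec is (shape, bytes_per_element); returns per-buffer sizes."""
        sizes = [prod(shape) * itemsize for shape, itemsize in specs]
        self.peak_bytes = max(self.peak_bytes, sum(sizes))
        return sizes


def profile(manager, total_seq_lens, head_dim, reserve):
    """Simulate a profiling pass that may or may not reserve indexer buffers."""
    if reserve:
        manager.get_simultaneous(
            ((total_seq_lens, head_dim), 1),  # FP8 keys: 1 byte per element
            ((total_seq_lens, 4), 1),         # uint8 scales: 4 bytes per row
        )
    return manager.peak_bytes


with_reservation = profile(ToyWorkspaceManager(), 32768, 128, reserve=True)
without_reservation = profile(ToyWorkspaceManager(), 32768, 128, reserve=False)
print(with_reservation)     # bytes accounted for during profiling
print(without_reservation)  # nothing recorded, so the budget is too optimistic
```

When profiling records nothing for these buffers, any memory budget derived from the profiled peak leaves no room for them at runtime.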

Why the patch helps but doesn't fully fix it

The rocm_aiter_mla_sparse_ops_patched.py patch (authored 2026-04-08) addresses both bugs above by mirroring the CUDA reference. This pushes the correctness threshold from ~20K to ~25K prompt tokens. However, the underlying rocm_fp8_mqa_logits kernel appears to have its own size-dependent correctness issue beyond ~20-25K total sequence length in the logits computation, which is NOT present in the CUDA path (which uses a different kernel).

The remaining corruption past ~25K is likely in one of:

  • aiter.ops.triton.attention.fp8_mqa_logits.rocm_fp8_mqa_logits — the FP8 MQA logits kernel called per sub-chunk
  • The radix-topk implementation used on ROCm (vs persistent_topk on CUDA)
  • An indexer metadata builder issue with total_seq_lens exceeding some internal buffer

Impact

  • GLM-5.1-FP8 is unusable for applications with prompts > ~25K tokens on MI355X. This includes all agentic coding assistants (system prompt + tools + history routinely exceeds 30K tokens), RAG pipelines, and long-document QA.
  • ROCM_AITER_MLA_SPARSE is the only backend for sparse-MLA models — there is no TRITON_MLA fallback. The backend selector at platforms/rocm.py:362 hardcodes return [AttentionBackendEnum.ROCM_AITER_MLA_SPARSE] when use_sparse=True.
  • Short-prompt use cases (< 20K tokens) and decode-only workloads are unaffected.

Suggested fixes

  1. Immediate (P0): Port the two CUDA-path fixes into rocm_aiter_mla_sparse.py:

    • Add skip_kv_gather check mirroring sparse_attn_indexer.py:123
    • Add workspace reservation during profiling mirroring sparse_attn_indexer.py:59
    • Use persistent shared workspace buffers instead of per-iteration torch.empty

    Diff available: compare the upstream rocm_aiter_mla_sparse.py against vllm/model_executor/layers/sparse_attn_indexer.py lines 55-130.

  2. Follow-up: Investigate the residual corruption past ~25K tokens after applying the workspace patch. This may be in rocm_fp8_mqa_logits itself or in the ROCm-specific topk path.

  3. Defense in depth: Add a TRITON_MLA fallback path for use_sparse=True models, even if slower, so there is a correct reference to test against. Currently the only backend is ROCM_AITER_MLA_SPARSE with no escape hatch.

Additional context

  • PR #39509 (revert aiter to v0.1.10.post3) fixed the decode regression from #39303. That fix is confirmed working — decode coherence past the 2048 boundary is verified on this nightly. This new bug is orthogonal.
  • The pa_mqa_logits.py ZeroDivisionError clamp (heads < ChunkQ at TP=4) is still required on all nightlies for GLM models and should be upstreamed separately.
  • The topk_fixed.py patch (moe_fused_gate 256-expert limit for GLM's 256 experts) is needed on aiter v0.1.12 but NOT on v0.1.10.post3 (different dispatch logic).

extent analysis

TL;DR

The most likely fix for the ROCM_AITER_MLA_SPARSE attention backend corruption issue is to port the two CUDA-path fixes into rocm_aiter_mla_sparse.py, including adding a skip_kv_gather check and workspace reservation during profiling.

Guidance

  1. Port CUDA-path fixes: Mirror the skip_kv_gather check and workspace reservation from sparse_attn_indexer.py in rocm_aiter_mla_sparse.py to address the corruption issue.
  2. Verify the fix: Run the minimal reproducer with the patched rocm_aiter_mla_sparse.py to ensure the corruption issue is resolved for prompts up to ~25K tokens.
  3. Investigate residual corruption: After applying the patch, investigate the residual corruption past ~25K tokens to determine if it's related to rocm_fp8_mqa_logits or the ROCm-specific topk path.
  4. Add TRITON_MLA fallback: Consider adding a TRITON_MLA fallback path for use_sparse=True models to provide a correct reference and escape hatch.

Example

# Example of adding skip_kv_gather check in rocm_aiter_mla_sparse.py
for chunk in prefill_metadata.chunk_specs:
    if not chunk.skip_kv_gather:
        ops.cp_gather_indexer_k_quant_cache(...)

Notes

  • The pa_mqa_logits.py ZeroDivisionError clamp is still required on all nightlies for GLM models and should be upstreamed separately.
  • The topk_fixed.py patch is not relevant to this issue.

Recommendation

Apply the workaround by porting the CUDA-path fixes into rocm_aiter_mla_sparse.py to address the corruption issue, and then investigate the residual corruption past ~25K tokens.
