vllm - ✅(Solved) Fix [Bug]: `aiter.ops.triton.attention.pa_mqa_logits.deepgemm_fp8_paged_mqa_logits_stage1` returns random topk for `context_len > 2048` on ROCm (gfx950), breaks GLM-5.1-FP8 decode [2 pull requests, 4 comments, 3 participants]

vllm-project/vllm#39303, fetched 2026-04-09 07:52:03


PR fix notes

PR #39326: fix: Fallback to torch for context_len > 2048 to bypass aiter kernel bug

Description (problem / solution / changelog)

Summary

Fixes vllm-project/vllm#39303

The aiter deepgemm_fp8_paged_mqa_logits_stage1 kernel returns random results when context_len > 2048. Added a check to use the torch fallback implementation when max context_len exceeds 2048.

This is a workaround until the aiter team fixes the kernel bug. The torch implementation is slower but produces correct results.

Changes

  • Modified rocm_fp8_paged_mqa_logits() in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py to check max(context_lens) and fall back to the torch implementation when it exceeds 2048
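
The change described above amounts to a length check in front of the kernel call. A minimal sketch of that dispatch, assuming two interchangeable callables: `pick_logits_impl`, `fast_kernel`, `reference_impl`, and `AITER_SAFE_CONTEXT_LEN` are hypothetical names used for illustration and are not the code in PR #39326.

```python
from typing import Callable, Sequence

# Assumption: 2048 is the largest context length for which the issue
# reports correct results from the aiter kernel.
AITER_SAFE_CONTEXT_LEN = 2048


def pick_logits_impl(
    context_lens: Sequence[int],
    fast_kernel: Callable,
    reference_impl: Callable,
    limit: int = AITER_SAFE_CONTEXT_LEN,
) -> Callable:
    """Return the fast kernel only if every request fits within `limit`."""
    if not context_lens:
        return fast_kernel
    return fast_kernel if max(context_lens) <= limit else reference_impl


if __name__ == "__main__":
    fast = lambda: "aiter"  # stand-in for the aiter kernel wrapper
    slow = lambda: "torch"  # stand-in for the torch reference
    print(pick_logits_impl([512, 2048], fast, slow)())
    print(pick_logits_impl([512, 2049], fast, slow)())
```

Because the check uses the batch maximum, a single long request moves the whole batch onto the slow path, which the issue measures at roughly 1 s per call at ctx=4096.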

Testing

This fix was verified by reviewing the code logic. The fallback uses the existing fp8_paged_mqa_logits_torch function, which is a validated reference implementation.

Co-authored-by: Claude

Changed files

  • vllm/envs.py (modified, +77/-74)
  • vllm/v1/attention/ops/rocm_aiter_mla_sparse.py (modified, +29/-24)

PR #39509: [ROCm] [AITER] Revert AITER version to v0.1.10.post3

Description (problem / solution / changelog)

Purpose

The AITER v0.1.12 tag is moving (https://github.com/ROCm/aiter/issues/2691).

Moreover, there are many known issues with the initial commit of v0.1.12:

  1. DeepSeek blockscaled gemm RuntimeError: This GEMM is not supported! https://github.com/vllm-project/vllm/issues/39485

  2. https://github.com/vllm-project/vllm/issues/39303
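
When comparing a working and a broken image, it can help to record which builds are actually installed. A small sketch using only the standard library; the distribution names `aiter` and `vllm` are assumptions about how the packages are published, and because the PR notes that a tag can move, a version string alone may not pin down the exact commit.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Assumption: the distributions are named "aiter" and "vllm".
    for name in ("aiter", "vllm"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```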


Changed files

  • docker/Dockerfile.rocm_base (modified, +1/-1)


Summary

Investigations done with Claude Code on an MI355 node.

On the rocm/vllm-dev:nightly image bundled with vLLM 0.19.1rc1.dev83+g83d09d36b, the aiter sparse-MLA paged decode kernel deepgemm_fp8_paged_mqa_logits_stage1 produces correct top‑k indices for context_len <= 2048 but essentially random results (≈0.3 % topk-set match against its own bf16 torch reference) the moment context_len crosses 2048. The transition is exact: ctx=2048 → 100 %, ctx=2049 → ~0.5 %.

This is invoked on every decode step through vllm/v1/attention/ops/rocm_aiter_mla_sparse.py::rocm_fp8_paged_mqa_logits, so any model with the glm_moe_dsa architecture (GLM-5, GLM-5-FP8, GLM-5.1-FP8) collapses to token salad as soon as its running (prefill + decode) context exceeds ~2048 tokens. Short-context generation looks coherent; then it degrades to repeated punctuation, random code fragments, or non-whitespace Unicode. On an older nightly (vLLM 0.18.2rc1.dev55+g9bd723110, image id d9863fe6ff74, 2026-04-03), the same model ran correctly on long context, so this is a regression within the 2026-04-03 → 2026-04-08 window.

Environment

Image: docker.io/rocm/vllm-dev:nightly
Image id: 45a197521fa0 (pulled 2026-04-08)
vLLM: 0.19.1rc1.dev83+g83d09d36b
Model: zai-org/GLM-5.1-FP8 (glm_moe_dsa, 754 B, FP8 E4M3)
Hardware: 8× AMD Instinct MI355X OAM (gfx950:sramecc+:xnack-)
Runtime: Podman 4.7.1, ROCm
Relevant env: VLLM_ROCM_USE_AITER=1
Serve flags: --tensor-parallel-size 4 --trust-remote-code --dtype auto --gpu-memory-utilization 0.85 --enable-auto-tool-choice --tool-call-parser glm47 --reasoning-parser glm45 (chunked prefill on, default max_num_batched_tokens=8192)

TP=4 is forced by a separate constraint (aiter MLA ASM kernels only exist for gqa ∈ {1,16,32,64,128} on gfx950; GLM-5.1 has num_attention_heads=64 so TP=4 → gqa=16). TP=8 crashes at graph capture and TP≤2 is memory-infeasible. Attempts to change tensor parallelism were therefore not a workaround.
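
The tensor-parallel constraint above is plain integer arithmetic. A short sketch, assuming (as the "TP=4 → gqa=16" remark implies) that gqa here means attention heads per rank; the supported set and head count are the values quoted in the report, and the helper names are hypothetical.

```python
# Values quoted in the report: head counts supported by the aiter MLA ASM
# kernels on gfx950, and GLM-5.1's num_attention_heads.
SUPPORTED_GQA = {1, 16, 32, 64, 128}
NUM_ATTENTION_HEADS = 64


def heads_per_rank(num_heads: int, tp_size: int) -> int:
    """Attention heads handled by each tensor-parallel rank."""
    if tp_size <= 0 or num_heads % tp_size != 0:
        raise ValueError(f"{num_heads} heads cannot be split over TP={tp_size}")
    return num_heads // tp_size


def kernel_available(tp_size: int) -> bool:
    """True if the per-rank head count is in the supported set."""
    return heads_per_rank(NUM_ATTENTION_HEADS, tp_size) in SUPPORTED_GQA


if __name__ == "__main__":
    for tp in (1, 2, 4, 8):
        print(f"TP={tp}: gqa={heads_per_rank(NUM_ATTENTION_HEADS, tp)}, "
              f"supported={kernel_available(tp)}")
```

TP=8 gives 8 heads per rank, which is outside the supported set and is consistent with the reported crash at graph capture. TP=1 and TP=2 map to supported values but are ruled out by memory according to the report.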

Minimal reproduction (runs inside the running container)

The reproducer calls the kernel with synthetic random inputs, compares its output against the bf16 torch reference that lives in the same file (fp8_paged_mqa_logits_torch in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py), and reports the per-row top‑k set-match ratio. No GLM weights, no server, no vLLM engine required.

# repro_pa_mqa_ctx2048.py — run inside rocm/vllm-dev:nightly (image 45a197521fa0)
import torch
from aiter.ops.triton.attention.pa_mqa_logits import (
    deepgemm_fp8_paged_mqa_logits_stage1,
)


def torch_ref(q, kv_cache, weights, context_lens, block_tables, max_model_len):
    """Bf16 reference. Matches `fp8_paged_mqa_logits_torch` in
    vllm/v1/attention/ops/rocm_aiter_mla_sparse.py."""
    bs, next_n, heads, hd = q.size()
    fp8 = torch.float8_e4m3fn
    kv_cache_v, scale = kv_cache[..., :hd], kv_cache[..., hd:]
    scale = scale.contiguous().view(torch.float32)
    kv_cache_v = kv_cache_v.view(fp8).float() * scale
    _, block_size, _, _ = kv_cache_v.size()
    logits = torch.full(
        [bs * next_n, max_model_len], float("-inf"),
        device=q.device, dtype=torch.float32,
    )
    q_f = q.float()
    for i in range(bs):
        ctx_len = context_lens[i].item()
        q_offsets = torch.arange(ctx_len - next_n, ctx_len, device="cuda")
        weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous()
        for blk in range((ctx_len + block_size - 1) // block_size):
            blk_id = block_tables[i][blk]
            qx, kx = q_f[i], kv_cache_v[blk_id]
            k_offsets = torch.arange(blk * block_size, (blk + 1) * block_size, device="cuda")
            mask = (k_offsets[None, :] < ctx_len) & (k_offsets[None, :] <= q_offsets[:, None])
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype),
                float("-inf"),
            )
            s = (torch.relu(s) * weight_slice[..., None]).sum(dim=0)
            logits[i * next_n:(i + 1) * next_n, blk * block_size:(blk + 1) * block_size] = \
                torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits


def test(ctx_len):
    torch.manual_seed(42)
    bs, next_n, heads, hd = 1, 1, 32, 128
    block_size = 1  # DeepseekV32IndexerCache uses block_size=1
    max_blocks = ctx_len + 16
    max_model_len = 202752  # GLM-5.1-FP8 config value
    kv_cache = torch.randint(
        0, 256, (max_blocks, block_size, 1, hd + 4),
        dtype=torch.uint8, device="cuda",
    )
    q = torch.randn(bs, next_n, heads, hd, device="cuda").to(torch.float8_e4m3fn)
    weights = torch.randn(bs * next_n, heads, device="cuda")
    context_lens = torch.full((bs,), ctx_len, dtype=torch.int32, device="cuda")
    block_tables = torch.arange(
        max_blocks, dtype=torch.int32, device="cuda",
    ).unsqueeze(0).expand(bs, -1).contiguous()
    out = torch.full(
        (heads, bs * next_n, max_model_len),
        float("-inf"), device="cuda", dtype=torch.float32,
    )

    deepgemm_fp8_paged_mqa_logits_stage1(
        q, kv_cache, weights, out, context_lens, block_tables, max_model_len,
    )
    a = out.sum(dim=0)
    r = torch_ref(q, kv_cache, weights, context_lens, block_tables, max_model_len)

    a2 = torch.where(torch.isfinite(a), a, torch.full_like(a, -1e30))
    r2 = torch.where(torch.isfinite(r), r, torch.full_like(r, -1e30))
    k = min(2048, ctx_len)
    a_set = torch.topk(a2, k=k, dim=1).indices.sort(dim=1).values
    r_set = torch.topk(r2, k=k, dim=1).indices.sort(dim=1).values
    match = (a_set == r_set).float().mean(dim=1)
    print(f"ctx={ctx_len:5d}  topk_set_match={match.mean().item():.4f}")


for ctx in [1024, 2048, 2049, 3000, 4096, 8192]:
    test(ctx)

Run it in the container:

podman exec <container> python repro_pa_mqa_ctx2048.py

Expected output

All topk_set_match values near 1.0 (bf16 reference and FP8 kernel should disagree only at precision boundaries, not at entire topk sets).

Actual output

ctx= 1024  topk_set_match=1.0000
ctx= 2048  topk_set_match=1.0000
ctx= 2049  topk_set_match=0.0049
ctx= 3000  topk_set_match=0.0015
ctx= 4096  topk_set_match=0.0039
ctx= 8192  topk_set_match=0.0029

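At ctx >= 2049 the kernel's selection no longer lines up with the reference. Note that the metric compares the two sorted index lists position by position, so the 0.2–0.5 % figure is a position-wise agreement rate rather than a true set overlap; a uniformly random selection within the context would overlap the reference set by about 2048 / ctx.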
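
The reproducer's metric, `(a_set == r_set).float().mean()`, compares sorted index lists element by element, which can report near-zero agreement even when the two sets differ in a single index. The sketch below contrasts that with a true set overlap on plain Python lists; both helper names are hypothetical, and tensor indices would need `.tolist()` first.

```python
from typing import Sequence


def positionwise_match(a_idx: Sequence[int], r_idx: Sequence[int]) -> float:
    """Fraction of positions where the two sorted index lists agree."""
    a_sorted, r_sorted = sorted(a_idx), sorted(r_idx)
    if len(a_sorted) != len(r_sorted):
        raise ValueError("index lists must have the same length")
    hits = sum(x == y for x, y in zip(a_sorted, r_sorted))
    return hits / len(r_sorted)


def set_overlap(a_idx: Sequence[int], r_idx: Sequence[int]) -> float:
    """Fraction of reference indices that also appear in the candidate."""
    ref = set(r_idx)
    return len(ref & set(a_idx)) / len(ref)


if __name__ == "__main__":
    ref = list(range(1, 9))   # reference picks indices 1..8
    cand = list(range(0, 8))  # candidate picks 0..7, one index differs
    print(positionwise_match(cand, ref))  # 0.0
    print(set_overlap(cand, ref))         # 0.875
```

A related point to check: with k = min(2048, ctx), k equals ctx whenever ctx <= 2048, so top-k returns every valid position and the match is 1.0 regardless of the logit values, as long as both outputs are finite at the same positions. Re-running the short cases with a k strictly below ctx would show whether the kernel is actually correct there. Neither observation contradicts the report; they only affect how much the numbers by themselves prove.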

The absolute values it writes are also implausible. Running the same reproducer with

finite = torch.isfinite(a) & torch.isfinite(r)
print("max_abs_diff:", (a[finite] - r[finite]).abs().max().item())

at ctx=4096 gives max_abs_diff ≈ 4.9e34, i.e. the kernel is not "noisy but close", it is writing garbage floats into out_qk for the past‑2048 positions.
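
One caveat when reading the magnitude: the reproducer fills kv_cache with uniformly random bytes and reinterprets the last four bytes of each row as a float32 scale. Random bit patterns cover the whole float32 range, and patterns with an all-ones exponent (1 in 256) are inf or NaN, so very large values can come from the synthetic inputs as well as from the kernel. A small standard-library sketch of that reinterpretation:

```python
import math
import struct


def bytes_to_float32(raw: bytes) -> float:
    """Reinterpret 4 little-endian bytes as an IEEE 754 float32."""
    (value,) = struct.unpack("<f", raw)
    return value


if __name__ == "__main__":
    samples = {
        "0x3F800000": bytes.fromhex("0000803f"),  # 1.0
        "0x7F000000": bytes.fromhex("0000007f"),  # 2**127, about 1.7e38
        "0x7F800000": bytes.fromhex("0000807f"),  # +inf
        "0x7FC00000": bytes.fromhex("0000c07f"),  # NaN
    }
    for name, raw in samples.items():
        value = bytes_to_float32(raw)
        kind = "finite" if math.isfinite(value) else "non-finite"
        print(name, value, kind)
```

The FP8 payload has the same property: in float8_e4m3fn the byte patterns 0x7F and 0xFF encode NaN. Repeating the comparison with scales drawn from a bounded positive range and payload bytes that avoid the NaN encodings would separate input effects from kernel effects.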

Downstream impact

The broken kernel is on the decode path of SparseAttnIndexer.forward_hip → rocm_aiter_sparse_attn_indexer → rocm_fp8_paged_mqa_logits in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py (lines ~600–680 on this nightly). Each generated token of a glm_moe_dsa model calls this once per sparse‑indexer layer. Once the running context grows past 2048, the topk picks ~random positions, the sparse MLA attention attends to those positions, and generation collapses. Concretely, serving zai-org/GLM-5.1-FP8:

  • Prompt ≈ 1 k tokens → coherent reasoning and answer.
  • Prompt ≈ 2 k tokens → first few tokens coherent, then token salad (1.1.1.0px solid #endregion..., mixed scripts).
  • Prompt ≈ 4 k–60 k tokens → garbage from the first generated token.

The threshold depends only on how fast the decode context crosses 2048, not on prompt content or prompt shape. Chunked prefill scheduling (max_num_batched_tokens=8192) does not affect it.

The same model on the prior nightly (image id d9863fe6ff74, vLLM 0.18.2rc1.dev55+g9bd723110, pulled 2026-04-03) generated coherent long-context output under the same deploy flags, so this is a regression introduced between the 2026-04-03 and 2026-04-08 rocm/vllm-dev:nightly image rebuilds.

Observations while narrowing it down

  • Bug is on N (key / context length), not M (query length). M=4096, N=2048 → 100 % match. M=128, N=4096 → broken. Chunking the query dimension does not help.
  • Bug is in the decode kernel, not in the prefill indexer kernel. A parallel reproducer against aiter.ops.triton.attention.fp8_mqa_logits.fp8_mqa_logits (the prefill path) at num_heads=32, head_dim=128 produces max_abs_diff ≈ 0.5 across the full tested range up to N=8192 — normal FP8-vs-bf16 precision noise. Only the paged (decode) kernel misbehaves.
  • Bug is not a ChunkQ issue caused by GLM-5.1's index_n_heads=32 < default ChunkQ=64. Padding q heads from 32 to 64 and calling with ChunkQ=64 (the DeepSeek‑V3.2 tested configuration) reproduces the identical failure past 2048. The wrapper patch required to avoid the heads // ChunkQ == 0 ZeroDivisionError during CUDA‑graph capture
    if heads < ChunkQ:
        ChunkQ = heads
    TileQCount = max(1, batch_size * next_n * (heads // ChunkQ))
    is orthogonal — the correctness regression happens with or without it.
  • block_size > 1 is not a workaround. The indexer cache (DeepseekV32IndexerCache) uses block_size=1 and the kernel only has working template instantiations for that case on this image; passing block_size ∈ {16, 64, 256} raises hipErrorIllegalAddress (700).
  • fp8_paged_mqa_logits_torch (the torch fallback in the same file) is numerically correct but takes ~1 s per call at ctx=4096, i.e. ≈ 60 s per generated token across all indexer layers — not viable as a production shim.
  • index_n_heads, index_head_dim, index_topk, max_position_embeddings are identical between zai-org/GLM-5-FP8 and zai-org/GLM-5.1-FP8 (32 / 128 / 2048 / 202752), so this is not a GLM-5.1-specific config issue — the older GLM-5-FP8 would hit the same kernel regression on this image.
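
The ChunkQ clamp quoted in the third bullet can be checked with integer arithmetic alone. The sketch below mirrors the quoted lines in a hypothetical `tile_q_count` helper and shows why the unclamped value breaks the later `TotalCuCount // TileQCount` division when heads=32 and ChunkQ=64; it is an illustration, not the aiter wrapper itself.

```python
def tile_q_count(batch_size: int, next_n: int, heads: int, chunk_q: int,
                 clamp: bool) -> int:
    """TileQCount with or without the clamp quoted in the report."""
    if clamp and heads < chunk_q:
        chunk_q = heads
    count = batch_size * next_n * (heads // chunk_q)
    return max(1, count) if clamp else count


if __name__ == "__main__":
    total_cu_count = 80  # taken from the report's "80 CUs" comment
    unclamped = tile_q_count(1, 1, heads=32, chunk_q=64, clamp=False)
    print("unclamped TileQCount:", unclamped)
    try:
        total_cu_count // unclamped
    except ZeroDivisionError as exc:
        print("unclamped division fails:", exc)
    clamped = tile_q_count(1, 1, heads=32, chunk_q=64, clamp=True)
    print("clamped TileQCount:", clamped, "->", total_cu_count // clamped)
```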

Where I think the bug lives

The kernel in question is /usr/local/lib/python3.12/dist-packages/aiter/ops/triton/_triton_kernels/attention/pa_mqa_logits.py::_deepgemm_fp8_paged_mqa_logits_stage1, dispatched from the wrapper at /usr/local/lib/python3.12/dist-packages/aiter/ops/triton/attention/pa_mqa_logits.py::deepgemm_fp8_paged_mqa_logits_stage1.

Since M does not matter and only N > 2048 is affected, the regression is almost certainly in how the kernel walks the KV / block_table dimension — either an off-by-one in the per‑split KV range computation, a missing mask on a tail load past the 2048‑position SplitKV boundary, or stale shared‑memory / register state carried across SplitKV shards. Given that SplitKV is computed as

TileQCount = batch_size * next_n * (heads // ChunkQ)                 # = 1 for decode
SplitKV    = (max(1, TotalCuCount // TileQCount) + 4) // 5 * 5 * WavePerEU   # ≈ 160 on 80 CUs

the kernel is running ~160 KV‑dimension shards per query even for decode, and with block_size=1 each shard covers a very small KV span. The threshold at exactly 2048 strongly suggests a per‑shard shared buffer sized for the previous worst-case workload (DeepSeek‑V3.2 with index_topk=2048) being reused as a hard N upper bound.
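
To make the shard arithmetic concrete, the sketch below evaluates the quoted SplitKV expression and an approximate per-shard KV span. `WavePerEU = 2` is an assumption chosen because it reproduces the "≈ 160 on 80 CUs" comment, and the even split of the context across shards is also an assumption; how the kernel actually partitions the KV range is not shown in the report.

```python
def split_kv(total_cu_count: int, tile_q_count: int, wave_per_eu: int) -> int:
    """SplitKV as quoted: CU share per query tile, rounded up to a
    multiple of 5, then multiplied by WavePerEU."""
    per_tile = max(1, total_cu_count // tile_q_count)
    return (per_tile + 4) // 5 * 5 * wave_per_eu


def span_per_shard(ctx_len: int, shards: int) -> int:
    """KV positions per shard under an even split, rounded up."""
    return -(-ctx_len // shards)


if __name__ == "__main__":
    shards = split_kv(total_cu_count=80, tile_q_count=1, wave_per_eu=2)
    for ctx in (2048, 2049, 4096, 8192):
        print(f"ctx={ctx:5d} shards={shards} "
              f"span={span_per_shard(ctx, shards)}")
```

Under these assumptions each shard covers only 13 positions at ctx=2048 and 52 at ctx=8192, which matches the report's remark that every shard spans a very small KV range.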

Ask

  1. Can the aiter team confirm / bisect this against the aiter changes that landed between the 2026-04-03 and 2026-04-08 rocm/vllm-dev:nightly rebuilds?
  2. Pending a proper fix, can rocm_fp8_paged_mqa_logits in vllm/v1/attention/ops/rocm_aiter_mla_sparse.py gate the aiter path on max_model_len <= 2048 and fall back to fp8_paged_mqa_logits_torch otherwise? That will be slow but at least it will produce correct output for users who need to run GLM-5 family models on MI300X / MI350X in the interim.
  3. Is there a tag of the previous aiter build that we can pin rocm/vllm-dev:nightly against to get a known-good GLM-5/5.1 deploy while the kernel is under investigation?

Extent analysis

TL;DR

The most likely fix is to update the deepgemm_fp8_paged_mqa_logits_stage1 kernel to correctly handle context lengths greater than 2048.

Guidance

  • Investigate the deepgemm_fp8_paged_mqa_logits_stage1 kernel for off-by-one errors or missing masks in the KV dimension computation.
  • Verify that the kernel's shared buffer sizing is not causing issues with context lengths greater than 2048.
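  • Consider a temporary workaround that gates the aiter path on the largest context length in the batch (max(context_lens) <= 2048, as PR #39326 does) and falls back to fp8_paged_mqa_logits_torch otherwise. Gating on max_model_len would disable the aiter path entirely for GLM-5.1, whose configured value is 202752.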
  • Check the aiter changes between the 2026-04-03 and 2026-04-08 rocm/vllm-dev:nightly rebuilds to identify the potential cause of the regression.

Example

No code example is provided as the issue is specific to the deepgemm_fp8_paged_mqa_logits_stage1 kernel and requires a detailed investigation of the kernel's implementation.

Notes

The issue is specific to the rocm/vllm-dev:nightly image with vLLM 0.19.1rc1.dev83+g83d09d36b and may not be applicable to other environments or versions.

Recommendation

Pending a proper fix for the deepgemm_fp8_paged_mqa_logits_stage1 kernel, either pin AITER to a known-good version (PR #39509 reverts to v0.1.10.post3) or gate the aiter path on context length and fall back to fp8_paged_mqa_logits_torch above 2048 (PR #39326). The fallback is correct but slow (about 1 s per call at ctx=4096 in the report), so it is an interim measure for users who need to run GLM-5 family models on MI300X / MI350X.
