vllm - ✅(Solved) Fix [Bug]: KDA chunked prefill uses wrong recurrent state layout and breaks Kimi-linear long-context retrieval [1 pull requests, 2 comments, 2 participants]

GitHub stats
vllm-project/vllm#41292 · Fetched 2026-04-30 06:18:58
Comments: 2 · Participants: 2 · Timeline events: 15 · Reactions: 0
Timeline (top): mentioned ×4, subscribed ×4, commented ×2, labeled ×2

PR fix notes

PR #33291: [PERF] Change GDN Attention State Layout from [N, HV, K, V] to [N, HV, V, K]

Description (problem / solution / changelog)

Summary

This PR changes the recurrent state memory layout in GDN (Gated Delta Net) attention from [N, HV, K, V] to [N, HV, V, K] for improved memory access patterns and throughput.

Besides the speedup, this also makes it possible to use FI's (FlashInfer's) GDN kernels.
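To make the layout change concrete, here is a minimal pure-Python sketch using toy nested lists rather than GPU tensors. It is not code from the PR; the helper names (`transpose_last_two`, `read_out_kv`, `read_out_vk`) are hypothetical. It shows that the two layouts hold the same numbers, and that reading a [V, K] buffer as if it were [K, V] gives a different result without any shape error when K and V have the same size.

```python
def transpose_last_two(state):
    """Convert nested lists [N][HV][K][V] to [N][HV][V][K] (or back)."""
    return [[[list(col) for col in zip(*mat)] for mat in per_seq] for per_seq in state]

def read_out_kv(q, mat_kv):
    """o[v] = sum_k q[k] * S[k][v], for a state stored as [K][V]."""
    K, V = len(mat_kv), len(mat_kv[0])
    return [sum(q[k] * mat_kv[k][v] for k in range(K)) for v in range(V)]

def read_out_vk(q, mat_vk):
    """The same contraction over K, for a state stored as [V][K]."""
    return [sum(qk * s for qk, s in zip(q, row)) for row in mat_vk]

# One sequence, one head, K = V = 2, so both layouts have identical shapes.
state_kv = [[[[1.0, 2.0], [3.0, 4.0]]]]
state_vk = transpose_last_two(state_kv)
q = [1.0, 10.0]

correct = read_out_kv(q, state_kv[0][0])   # [K][V] buffer, [K][V] reader
same = read_out_vk(q, state_vk[0][0])      # [V][K] buffer, [V][K] reader
wrong = read_out_kv(q, state_vk[0][0])     # [V][K] buffer, [K][V] reader

print(correct, same, wrong)
```

Writer and reader only have to agree on the layout. A mismatch between them is silent whenever the two trailing dimensions are equal, which is why a numerical reference check is needed to catch it.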

Performance Results

Model: nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4 (TP=2)

Server:

VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4 \
    -tp 2 --enable-expert-parallel --async-scheduling --no-enable-prefix-caching \
    --compilation_config.max_cudagraph_capture_size 2048

Benchmark:

vllm bench serve --backend vllm --model nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4 \
--endpoint /v1/completions --dataset-name random --random-input 32 --random-output 1000 \
--max-concurrency $CONC --num-prompt $CONC --ignore-eos

| Batch Size | Baseline (tok/s) | With PR (tok/s) | Delta |
|---|---|---|---|
| 1 | 199.58 | 201.37 | +0.9% |
| 16 | 2,251.70 | 2,225.81 | -1.2% |
| 64 | 6,148.08 | 6,088.84 | -1.0% |
| 256 | 14,420.40 | 14,620.51 | +1.4% |
| 1024 | 23,245 | 24,350 | +4.8% |

Correctness Verification (lm_eval)

Task: GSM8K (5-shot) Model: nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4

Server:

VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4 \
-tp 2 --enable-expert-parallel --async-scheduling --no-enable-prefix-caching \
--compilation_config.max_cudagraph_capture_size 2048

Evaluation:

lm_eval --model local-chat-completions \
    --model_args model=nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=250 \
    --tasks gsm8k --apply_chat_template --num_fewshot 5 --output_path ./eval_results --log_samples

| Metric | Baseline | With PR | Delta |
|---|---|---|---|
| exact_match (flexible-extract) | 0.7703 | 0.7718 | +0.0015 |
| exact_match (strict-match) | 0.6406 | 0.6368 | -0.0038 |

With speculative decoding.

Unfortunately, there is a problem when speculative decoding is combined with CUDA graphs, so this run was done without CUDA graphs. It also used local-completions instead of the local-chat-completions used above, which produces better accuracy.

Server:

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 -tp 4 \
    --enable-expert-parallel --async-scheduling --no-enable-prefix-caching \
    --compilation_config.cudagraph_mode NONE \
    --speculative_config.method qwen3_next_mtp \
    --speculative_config.num_speculative_tokens 3

Evaluation:

lm_eval --model local-completions --tasks gsm8k   --model_args base_url=http://localhost:8000/v1/completions,model=Qwen/Qwen3-Next-80B-A3B-Instruct-FP8,num_concurrent=109;

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8537 | ±0.0097 |
| | | strict-match | 5 | exact_match | 0.8143 | ±0.0107 |

The result is the same as the baseline.

Changed files

  • vllm/model_executor/layers/fla/ops/chunk.py (modified, +4/-4)
  • vllm/model_executor/layers/fla/ops/chunk_delta_h.py (modified, +35/-35)
  • vllm/model_executor/layers/fla/ops/chunk_o.py (modified, +4/-4)
  • vllm/model_executor/layers/fla/ops/fused_recurrent.py (modified, +16/-16)
  • vllm/model_executor/layers/fla/ops/kda.py (modified, +1/-1)
  • vllm/model_executor/layers/mamba/mamba_utils.py (modified, +1/-1)

Your current environment

<details> <summary>The output of <code>python collect_env.py</code></summary>

NVIDIA:

Collecting environment information...
OS                           : Ubuntu 22.04.5 LTS (x86_64)
GCC version                  : 11.4.0
Libc version                 : glibc-2.35
PyTorch version              : 2.11.0+cu130
CUDA used to build PyTorch   : 13.0
Python version               : 3.12.13
Python platform              : Linux-5.15.0-173-generic-x86_64-with-glibc2.35
Is CUDA available            : True
CUDA runtime version         : 13.0.88
GPU models and configuration :
GPU 0-7: NVIDIA H100 80GB HBM3
Nvidia driver version        : 595.45.04
CPU                          : Intel(R) Xeon(R) Gold 6430
CPU(s)                       : 128
NUMA node(s)                 : 2
[pip3] torch==2.11.0+cu130
[pip3] torchvision==0.26.0+cu130
[pip3] torchaudio==2.11.0+cu130
[pip3] transformers==5.6.2
[pip3] triton==3.6.0
[pip3] flashinfer-python==0.6.8.post1
[pip3] numpy==2.2.6
vLLM Version                 : 0.20.0
vLLM Build Flags             : CUDA Archs: 7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX; ROCm: Disabled; XPU: Disabled
GPU Topology                 : all 8 H100 GPUs connected via NV18
Environment:
NVIDIA_VISIBLE_DEVICES=all
CUDA_VERSION=13.0.2
VLLM_USAGE_SOURCE=production-docker-image
VLLM_ENABLE_CUDA_COMPATIBILITY=0
TORCH_CUDA_ARCH_LIST=7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX

AMD:

Collecting environment information...
==============================
        System Info
==============================
OS                           : Ubuntu 22.04.5 LTS (x86_64)
GCC version                  : (Ubuntu 11.4.0-1ubuntu1~22.04.3) 11.4.0
Clang version                : 22.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-7.2.1 26084 f58b06dce1f9c15707c5f808fd002e18c2accf7e)
CMake version                : version 3.31.10
Libc version                 : glibc-2.35
==============================
       PyTorch Info
==============================
PyTorch version              : 2.10.0+git8514f05
Is debug build               : False
CUDA used to build PyTorch   : N/A
ROCM used to build PyTorch   : 7.2.53211
XPU used to build PyTorch    : N/A
==============================
      Python Environment
==============================
Python version               : 3.12.13 (main, Mar  4 2026, 09:23:07) [GCC 11.4.0] (64-bit runtime)
Python platform              : Linux-6.8.0-110-generic-x86_64-with-glibc2.35
==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : Could not collect
CUDA_MODULE_LOADING set to   :
GPU models and configuration :  (gfx950:sramecc+:xnack-)
Nvidia driver version        : Could not collect
cuDNN version                : Could not collect
HIP runtime version          : 7.2.53211
MIOpen runtime version       : 3.5.1
Is XNNPACK available         : True
==============================
          CPU Info
==============================
Architecture                 : x86_64
CPU(s)                       : 256
Model name                   : AMD EPYC 9575F 64-Core Processor
Socket(s)                    : 2
Core(s) per socket           : 64
Thread(s) per core           : 2
NUMA node(s)                 : 2
==============================
Versions of relevant libraries
==============================
[pip3] conch-triton-kernels==1.2.1
[pip3] numpy==2.1.3
[pip3] onnx==1.19.0
[pip3] torch==2.10.0+git8514f05
[pip3] torchaudio==2.9.0+eaa9e4e
[pip3] torchvision==0.24.1+d801a34
[pip3] transformers==5.6.2
[pip3] triton==3.6.0
[pip3] triton_kernels==1.0.0
[conda] Could not collect
==============================
         vLLM Info
==============================
ROCM Version                 : 7.2.53211-e1a6bc5663
vLLM Version                 : 0.20.0
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled; XPU: Disabled
GPU Topology                 : 8 GPUs detected; all GPU pairs connected via XGMI, 1 hop.
NUMA Affinity                : GPU0-3 on NUMA node 0; GPU4-7 on NUMA node 1.
==============================
     Environment Variables
==============================
PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151
LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_root
</details>

🐛 Describe the bug

Kimi Delta Attention (KDA) chunked prefill appears to use an incorrect recurrent state layout after https://github.com/vllm-project/vllm/pull/33291 changed GDN/KDA state layout from [N, HV, K, V] to [N, HV, V, K].

This causes Kimi-linear long-context retrieval failures. Short prompts can remain coherent, but long prefill corrupts KDA state and causes the model to retrieve nearby filler text or start generic responses instead of the target phrase.

I tested on both NVIDIA and AMD GPUs, and the behavior is the same on both.

End-to-end symptom

Using moonshotai/Kimi-Linear-48B-A3B-Instruct, a needle retrieval test should output:

COBALT PENGUIN WALKS AT DAWN

On v0.15.1 the output is correct.

Bad versions (v0.16.0+, including v0.20.0) output generic/filler text such as:

The exact SECRET_PHRASE is: **FILLER 5793: ignore this line.**

First-token logprobs also show the difference before decode:

Good v0.15.1: top token CO
Bad v0.16.0+: top token The
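The issue does not include the exact prompt or request used for the needle test, so the following is only a sketch of how such a check could be set up. The prompt layout, the closing question, and both helper names are hypothetical. The payload fields (`prompt`, `max_tokens`, `temperature`, `logprobs`) are standard OpenAI-style completions parameters, and the response shape assumes the legacy completions format with a `top_logprobs` list. Nothing is sent over the network here.

```python
def build_needle_prompt(needle, n_filler=2000, needle_at=1000):
    """Hypothetical prompt: one needle line buried among numbered filler lines."""
    lines = [f"FILLER {i}: ignore this line." for i in range(n_filler)]
    lines.insert(needle_at, f"SECRET_PHRASE: {needle}")
    lines.append("What is the exact SECRET_PHRASE? Reply with the phrase only.")
    return "\n".join(lines)

def first_top_token(response):
    """Most likely first generated token, assuming legacy completions-style logprobs."""
    top = response["choices"][0]["logprobs"]["top_logprobs"][0]
    return max(top, key=top.get)

prompt = build_needle_prompt("COBALT PENGUIN WALKS AT DAWN")

# Request body for an OpenAI-compatible /v1/completions endpoint (not sent here).
payload = {
    "model": "moonshotai/Kimi-Linear-48B-A3B-Instruct",
    "prompt": prompt,
    "max_tokens": 16,
    "temperature": 0,
    "logprobs": 5,
}

# Hand-written stand-in for a response, only to exercise the helper.
# The logprob values are made up and are not measurements from the issue.
stand_in = {"choices": [{"logprobs": {"top_logprobs": [{"CO": -0.1, "The": -2.5}]}}]}
top = first_top_token(stand_in)
print(top)
```

Comparing only the first token's top candidate keeps the check independent of sampling and decode behavior, which matches the observation that the divergence already appears during prefill.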

Layer instrumentation showed the first material divergence happens during prefill in layer 0 after KimiDeltaAttention; layer 0 norm1 is identical before KDA.

Minimal kernel-level test

I created a minimal regression test here: https://github.com/yudigege86/vllm/blob/fix-kimi-kda-state-layout/tests/kernels/test_kda.py

The test:

  • Creates deterministic Kimi-like KDA tensors.
  • Runs chunk_kda.
  • Computes expected output using existing KDA helper kernels up to w/u/kg, then performs the recurrent state update and final output contraction in PyTorch with the documented [K, V] state layout.
  • Compares chunk_kda output/state against this computed reference.
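The idea behind such a reference check can be sketched in pure Python on toy sizes. This is not the linked test and not the chunk_kda kernel. It assumes a naive per-token gated delta-rule recurrence on a [K][V] state (per-channel decay along K, then a rank-1 delta update, then a contraction over K), and all function names are hypothetical. It compares a single pass over the prompt with a two-chunk pass, once with the state handed over unchanged and once with the state transposed at the chunk boundary, as one way a layout mismatch could show up.

```python
import math

def kda_step(S, q, k, v, g, beta):
    """One naive gated delta-rule step on a [K][V] state. Returns (new_S, o)."""
    K, V = len(S), len(S[0])
    # Per-channel decay along K: row i is scaled by exp(g[i]).
    S = [[math.exp(g[i]) * S[i][j] for j in range(V)] for i in range(K)]
    # What the decayed state predicts for key k: pred[j] = sum_i k[i] * S[i][j].
    pred = [sum(k[i] * S[i][j] for i in range(K)) for j in range(V)]
    # Rank-1 delta update: S += beta * k (v - pred)^T.
    S = [[S[i][j] + beta * k[i] * (v[j] - pred[j]) for j in range(V)] for i in range(K)]
    # Output contraction over K: o[j] = sum_i q[i] * S[i][j].
    o = [sum(q[i] * S[i][j] for i in range(K)) for j in range(V)]
    return S, o

def run(tokens, S):
    """Feed tokens through kda_step, carrying the state. Returns (final_S, outputs)."""
    outs = []
    for q, k, v, g, beta in tokens:
        S, o = kda_step(S, q, k, v, g, beta)
        outs.append(o)
    return S, outs

def transpose(S):
    return [list(col) for col in zip(*S)]

# Two toy tokens as (q, k, v, g, beta), with K = V = 2.
tokens = [
    ([1.0, 0.0], [1.0, 0.0], [1.0, 2.0], [0.0, 0.0], 1.0),
    ([1.0, 1.0], [0.0, 1.0], [3.0, 4.0], [-0.5, 0.0], 1.0),
]
zero = [[0.0, 0.0], [0.0, 0.0]]

_, full = run(tokens, zero)                          # whole prompt in one pass

S_mid, first = run(tokens[:1], zero)                 # chunk 1
_, second = run(tokens[1:], S_mid)                   # chunk 2, state handed over as-is
_, second_bad = run(tokens[1:], transpose(S_mid))    # chunk 2, state handed over transposed

chunked = first + second
chunked_bad = first + second_bad
print(chunked == full, chunked_bad == full)
```

The first chunk is unaffected in this toy case because it starts from a zero state. Only tokens that consume a carried-over state diverge, which is consistent with short prompts staying coherent while long, chunked prefill does not.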

I reverted https://github.com/vllm-project/vllm/pull/33291 in https://github.com/vllm-project/vllm/compare/main...yudigege86:vllm:fix-kimi-kda-state-layout and both the needle test and the kernel-level test passed.

extent analysis

TL;DR

Revert the change in state layout from [N, HV, K, V] to [N, HV, V, K] introduced in https://github.com/vllm-project/vllm/pull/33291 to fix the Kimi Delta Attention (KDA) chunked prefill issue.

Guidance

  • Identify the specific commit or version where the state layout change was introduced and revert it to the previous layout.
  • Run the minimal kernel-level test provided in https://github.com/yudigege86/vllm/blob/fix-kimi-kda-state-layout/tests/kernels/test_kda.py to verify the fix.
  • Test the end-to-end symptom using the moonshotai/Kimi-Linear-48B-A3B-Instruct model and the needle retrieval test to ensure the output is correct.
  • Verify that the first-token logprobs match the good version (v0.15.1): the top first token should be CO, not The.

Example

No code snippet is provided as the issue is related to a specific change in the state layout, and the fix involves reverting that change.

Notes

The issue is specific to the Kimi Delta Attention (KDA) chunked prefill and is caused by the change in state layout. The fix involves reverting that change, and the provided tests can be used to verify the fix.

Recommendation

Apply the workaround by reverting the change in state layout to [N, HV, K, V] as it has been verified to fix the issue through the provided tests.
