vllm - ✅(Solved) Fix [Bug] flash_attn _get_sliding_window_configs asserts FlashAttentionImpl over all attention layers, breaks any non-FA backend [2 pull requests, 1 participants]

GitHub stats
vllm-project/vllm#39516, fetched 2026-04-11 06:13:06
Comments: 0 · Participants: 1 · Timeline: 0 · Reactions: 0

Root Cause

The assertion's intent (gather sliding window configs from FA layers) is correct; the implementation is wrong because it iterates the global layer set instead of restricting to layers in this builder's KV cache group.

Fix Action

Fix / Workaround

The same bug surfaces during the GDN/Mamba code path on hybrid models like Qwen3.5-35B-A3B. cc @vibhavagarwal5 — your perf branch vibhavagarwal5/vllm#7 inherits the unmodified flash_attn.py from main, so anyone testing it will hit this without a local patch.

Defensive (1 line): change the assert to a soft-skip:

for layer in layers.values():
    if not isinstance(layer.impl, FlashAttentionImpl):
        continue
    sliding_window_configs.add(layer.impl.sliding_window)

Strictly more permissive than current behavior, no risk of regressing FA-only models. This is what I patched locally to run my benchmarks (H100 throughput data here).

PR fix notes

PR #7: perf(tq): fuse MSE store ops + inline decode Q cast + cleanup

Description (problem / solution / changelog)

Summary

  • WHT rotation: Replace QR-decomposed random orthogonal matrices with Walsh-Hadamard Transform + random sign flips for key/query rotation. Drop-in replacement (same D×D matmul), orthonormal + self-inverse, enables future in-kernel butterfly fusion
  • Fused MSE store: Bucketize/centroid-gather/residual-norm fused into single Triton kernel (_tq_fused_store_mse), eliminating 4 PyTorch kernel launches per layer
  • In-kernel FP8 cast: FP8 key cast moved from host-side torch.float8_e4m3fn to in-kernel tl.float8e4nv/tl.float8e4b15, removing a separate kernel launch
  • Value quant dedup: Extracted shared _store_quantized_value Triton JIT helper, deduplicating ~60 lines between FP8 and MSE store kernels
  • Prefill .tolist() optimization: Single CPU-GPU sync instead of per-request .item() calls in prefill loop
  • CUDAGraph memory fix: Static NUM_KV_SPLITS=32 reduced estimated memory from 33 GiB → 8.7 GiB
  • Dead code cleanup: Removed unused loggers, kernel constexprs (NUM_Q_HEADS, PADDED_SLOT, MAX_NUM_BLOCKS, N_CENTROIDS), value_packed_size params, stale QR matrix buffer

Benchmark results (Qwen3-4B, 4× RTX PRO 6000 Blackwell, cudagraphs+compile)

Quality remains constant

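| Config | K cos | V cos | PPL | NIAH | GSM8K | Invalid |
| --- | --- | --- | --- | --- | --- | --- |
| baseline | | | 1.54 | 77/77 (100%) | 0.900 | 0.000 |
| turboquant_k8v4 | | | 1.59 | 77/77 (100%) | 0.860 | 0.000 |
| turboquant_4bit_nc | | | 1.53 | 77/77 (100%) | 0.840 | 0.000 |
| turboquant_k3v4_nc | | | 1.52 | 77/77 (100%) | 0.780 | 0.000 |
| turboquant_3bit_nc | | | 1.53 | 77/77 (100%) | 0.720 | 0.000 |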

Throughput (output tok/s)

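| Scenario | baseline | k8v4 | % base | t4nc | % base | k3v4nc | % base | t3nc | % base |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| short-decode (128→512) | 8977 | 7113 | 79% | 6397 | 71% | 6206 | 69% | 6114 | 68% |
| long-prefill (4096→128) | 850 | 811 | 95% | 766 | 90% | 745 | 88% | 730 | 86% |
| mixed (512→512) | 6618 | 5279 | 80% | 4829 | 73% | 4584 | 69% | 4491 | 68% |
| high-load (512→128, n=500) | 5633 | 4751 | 84% | 4456 | 79% | 4337 | 77% | 4240 | 75% |
| very-long-prefill (8192→64) | 233 | 234 | 100% | 224 | 96% | 220 | 94% | 216 | 93% |
| decode-heavy (64→1024) | 8304 | 6521 | 79% | 5887 | 71% | 5650 | 68% | 5430 | 65% |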

TPOT avg (ms) — lower is better

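| Scenario | baseline | k8v4 | t4nc | k3v4nc | t3nc |
| --- | --- | --- | --- | --- | --- |
| short-decode | 11.9 | 15.0 | 16.6 | 17.2 | 17.5 |
| long-prefill | 138.1 | 135.2 | 142.4 | 146.6 | 149.3 |
| mixed | 19.3 | 23.1 | 25.3 | 26.6 | 27.2 |
| high-load | 60.9 | 65.9 | 71.1 | 72.1 | 73.7 |
| very-long-prefill | 241.9 | 235.2 | 244.4 | 250.1 | 254.5 |
| decode-heavy | 12.8 | 16.4 | 18.0 | 18.7 | 19.5 |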

TTFT avg (ms) — lower is better

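| Scenario | baseline | k8v4 | t4nc | k3v4nc | t3nc |
| --- | --- | --- | --- | --- | --- |
| short-decode | 305 | 389 | 430 | 461 | 407 |
| long-prefill | 6095 | 6530 | 6690 | 6822 | 6753 |
| mixed | 825 | 1014 | 1034 | 1077 | 1054 |
| high-load | 1872 | 2124 | 2141 | 2188 | 2198 |
| very-long-prefill | 13372 | 13633 | 14159 | 14399 | 14539 |
| decode-heavy | 224 | 342 | 293 | 292 | 377 |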

Improvement vs pre-optimization baseline (0408 run)

Decode overhead dropped from ~3x baseline to ~1.3-1.5x — a ~55% reduction in the TQ-vs-baseline gap:

| Metric | Before (0408) | After (0410) | Improvement |
| --- | --- | --- | --- |
| k8v4 short-decode tok/s | 40% of baseline | 79% of baseline | +39pp |
| t4nc short-decode tok/s | 40% of baseline | 71% of baseline | +31pp |
| k8v4 TPOT overhead | 2.96x baseline | 1.26x baseline | -57% |
| k8v4 long-prefill tok/s | 72% of baseline | 95% of baseline | +23pp |
| k8v4 very-long-prefill tok/s | 86% of baseline | 100% of baseline | +14pp |
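
The k8v4 short-decode figures in this comparison can be re-derived from the throughput and TPOT tables. The sketch below is only a sanity check of that arithmetic: the 2.96x "before" value is taken as given from the 0408 run, and the other inputs are the short-decode cells for baseline and k8v4.

```python
# Inputs copied from the tables above (short-decode row).
baseline_tok_s = 8977
k8v4_tok_s = 7113
baseline_tpot_ms = 11.9
k8v4_tpot_ms = 15.0

# "Before" overhead reported for the 0408 run; taken as given, not re-measured.
tpot_overhead_before = 2.96

# Throughput as a share of baseline, in percent.
pct_of_baseline = 100 * k8v4_tok_s / baseline_tok_s

# TPOT overhead after the optimizations, as a multiple of baseline.
tpot_overhead_after = k8v4_tpot_ms / baseline_tpot_ms

# Relative drop in the overhead multiple.
overhead_reduction = (tpot_overhead_before - tpot_overhead_after) / tpot_overhead_before

print(f"k8v4 short-decode: {pct_of_baseline:.0f}% of baseline")
print(f"TPOT overhead: {tpot_overhead_after:.2f}x baseline")
print(f"Overhead reduction: {overhead_reduction:.0%}")
```

These reproduce the 79%, 1.26x, and -57% entries in the table.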

Key takeaways

  • k8v4 (FP8 keys + 4-bit values, ~2x compression): 79-100% of baseline throughput
  • t4nc (4-bit MSE + 4-bit values, ~3.8x compression): 71-96% of baseline
  • k3v4nc (3-bit MSE + 4-bit values, ~3.5x compression): 68-94% of baseline
  • k8v4 long-prefill TPOT is faster than baseline (135.2ms vs 138.1ms) — compressed cache reduces memory bandwidth
  • WHT rotation: No regression vs QR; consistent +0.5-2.5% improvement from structured Hadamard cache patterns

Test plan

  • Full perf benchmark (6 scenarios × 5 configs) — no regressions on baseline
  • All TQ configs produce correct output (k8v4, t4nc, k3v4nc, t3nc)
  • CUDAGraph capture verified (51 FULL + 51 PIECEWISE graphs)
  • WHT smoke test: coherent generation across all MSE configs
  • Quality benchmark (PPL/GSM8K) sanity check

🤖 Generated with Claude Code

Changed files

  • vllm/config/attention.py (modified, +5/-0)
  • vllm/model_executor/layers/attention/attention.py (modified, +3/-3)
  • vllm/model_executor/layers/quantization/turboquant/quantizer.py (modified, +18/-3)
  • vllm/v1/attention/backends/turboquant_attn.py (modified, +65/-19)
  • vllm/v1/attention/ops/triton_turboquant_decode.py (modified, +38/-62)
  • vllm/v1/attention/ops/triton_turboquant_store.py (modified, +239/-154)

PR #38479: [Attention Backend] TurboQuant: 2-bit KV cache compression with 4x capacity

Description (problem / solution / changelog)

Summary

TurboQuant adds online KV cache compression to vLLM's v1 attention backend using PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys and uniform quantization for values. All quantization happens at store time via fused Triton kernels — no offline calibration, model changes, or weight modifications required. Just set --kv-cache-dtype turboquant_k8v4.

Compression Presets (Qwen3-4B, head_dim=128)

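| Preset | Key | Value | Slot (bytes) | Compression | GSM8K | NIAH |
| --- | --- | --- | --- | --- | --- | --- |
| turboquant_k8v4 | FP8 (E4M3) | 4-bit uniform | 196 | 2.6x | 0.860 | 100% |
| turboquant_4bit_nc | 4-bit MSE + NC | 4-bit uniform + NC | 136 | 3.8x | 0.840 | 100% |
| turboquant_k3v4_nc | 3-bit MSE + NC | 4-bit uniform + NC | 120 | 4.3x | 0.780 | 100% |
| turboquant_3bit_nc | 3-bit MSE + NC | 3-bit uniform + NC | 104 | 4.9x | 0.720 | 100% |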

Baseline: GSM8K 0.900, NIAH 100%. Measured on Qwen/Qwen3-4B with 5-shot GSM8K (200q) and NIAH (512-32K, 77 probes).
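
The compression column is consistent with dividing an uncompressed slot by the preset's slot size. The sketch below assumes the uncompressed slot is one FP16 key vector plus one FP16 value vector at head_dim=128 (512 bytes); the PR text does not state this baseline explicitly, so treat it as an assumption that happens to match the table.

```python
HEAD_DIM = 128
FP16_BYTES = 2

# Assumed uncompressed slot: one key + one value vector in FP16.
baseline_slot_bytes = 2 * HEAD_DIM * FP16_BYTES  # 512

# Slot sizes copied from the presets table.
slot_bytes = {
    "turboquant_k8v4": 196,
    "turboquant_4bit_nc": 136,
    "turboquant_k3v4_nc": 120,
    "turboquant_3bit_nc": 104,
}

compression = {
    preset: round(baseline_slot_bytes / size, 1)
    for preset, size in slot_bytes.items()
}

for preset, ratio in compression.items():
    print(f"{preset}: {ratio}x")
```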

Performance (Qwen3-4B, 4x RTX PRO 6000 Blackwell, cudagraphs+compile)

Throughput (output tok/s)

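| Scenario | Baseline | k8v4 | % base | t4nc | % base | k3v4nc | % base | t3nc | % base |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| short-decode (128→512) | 8977 | 7113 | 79% | 6397 | 71% | 6206 | 69% | 6114 | 68% |
| long-prefill (4096→128) | 850 | 811 | 95% | 766 | 90% | 745 | 88% | 730 | 86% |
| mixed (512→512) | 6618 | 5279 | 80% | 4829 | 73% | 4584 | 69% | 4491 | 68% |
| high-load (512→128, n=500) | 5633 | 4751 | 84% | 4456 | 79% | 4337 | 77% | 4240 | 75% |
| very-long-prefill (8192→64) | 233 | 234 | 100% | 224 | 96% | 220 | 94% | 216 | 93% |
| decode-heavy (64→1024) | 8304 | 6521 | 79% | 5887 | 71% | 5650 | 68% | 5430 | 65% |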

TPOT (ms) — lower is better

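| Scenario | baseline | k8v4 | t4nc | k3v4nc | t3nc |
| --- | --- | --- | --- | --- | --- |
| short-decode | 11.9 | 15.0 | 16.6 | 17.2 | 17.5 |
| long-prefill | 138.1 | 135.2 | 142.4 | 146.6 | 149.3 |
| mixed | 19.3 | 23.1 | 25.3 | 26.6 | 27.2 |
| very-long-prefill | 241.9 | 235.2 | 244.4 | 250.1 | 254.5 |
| decode-heavy | 12.8 | 16.4 | 18.0 | 18.7 | 19.5 |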

TTFT (ms) — lower is better

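| Scenario | baseline | k8v4 | t4nc | k3v4nc | t3nc |
| --- | --- | --- | --- | --- | --- |
| short-decode | 305 | 389 | 430 | 461 | 407 |
| long-prefill | 6095 | 6530 | 6690 | 6822 | 6753 |
| mixed | 825 | 1014 | 1034 | 1077 | 1054 |
| decode-heavy | 224 | 342 | 293 | 292 | 377 |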

Key Takeaways

  • k8v4 (FP8 keys + 4-bit values, ~2.6x compression): 79-100% of baseline throughput across all scenarios
  • t4nc (4-bit MSE + NC, ~3.8x compression): 71-96% of baseline
  • k8v4 TPOT is faster than baseline on long sequences (135.2ms vs 138.1ms) — compressed cache reduces memory bandwidth pressure
  • Very-long-prefill at parity — 8K→64 shows 100% of baseline tok/s for k8v4

Technical Innovations

Walsh-Hadamard Transform (WHT) rotation — Replaced QR-decomposed random orthogonal matrices with WHT + random sign flips. Orthonormal, self-inverse (H = H^T = H^{-1}), enabling future in-kernel butterfly fusion. Same D×D matmul API, zero quality regression, consistent +0.5-2.5% improvement from structured Hadamard cache patterns. Continuation-prefill inversion is trivially H @ x (no transpose needed).
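
As an illustration of why this rotation is cheap to invert, here is a minimal pure-Python sketch of a normalized Walsh-Hadamard transform with random sign flips. It is not the PR's implementation (which the description says uses a D×D matmul); the function names, the butterfly formulation, and the choice to apply the signs before the transform are assumptions made for the example.

```python
import math
import random

def fwht(x):
    """Normalized fast Walsh-Hadamard transform (butterfly form).

    len(x) must be a power of two. With the 1/sqrt(n) scaling the
    transform is orthonormal and its own inverse.
    """
    n = len(x)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    y = list(x)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in y]

def rotate(x, signs):
    """Flip signs, then apply the transform (ordering is an assumption)."""
    return fwht([s * v for s, v in zip(signs, x)])

def unrotate(y, signs):
    """Invert rotate(): apply the same transform, then the same signs."""
    return [s * v for s, v in zip(signs, fwht(y))]

def l2_norm(v):
    return math.sqrt(sum(t * t for t in v))

random.seed(0)
D = 128  # matches head_dim=128 used in the benchmarks
signs = [random.choice((-1.0, 1.0)) for _ in range(D)]
x = [random.gauss(0.0, 1.0) for _ in range(D)]

y = rotate(x, signs)
x_back = unrotate(y, signs)

norm_error = abs(l2_norm(x) - l2_norm(y))
roundtrip_error = max(abs(a - b) for a, b in zip(x, x_back))
self_inverse_error = max(abs(a - b) for a, b in zip(x, fwht(fwht(x))))
print(norm_error, roundtrip_error, self_inverse_error)
```

All three errors are at floating-point noise level: the rotation preserves vector norms, and inverting it needs no separately stored inverse matrix.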

Fused MSE store kernel — Bucketize, centroid gather, residual norm, index packing, and value quantization fused into a single Triton kernel (_tq_fused_store_mse), eliminating 4 separate PyTorch kernel launches per layer. Result: +18-21% decode throughput, -10-12% prefill TTFT.

In-kernel FP8 cast — FP8 key cast moved from host-side torch.float8_e4m3fn to in-kernel tl.float8e4nv/tl.float8e4b15, removing a separate kernel launch. Auto-detects SM capability for Ampere vs Hopper FP8 formats.

Compact slot sizes — Slots are rounded to next even number instead of power-of-2, eliminating up to 47% padding waste (t4nc: 136B vs 256B). TQFullAttentionSpec properly overrides real_page_size_bytes with compact TQ slot bytes.
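
The padding figure can be checked against the slot sizes in the presets table. The helper names below are hypothetical, and the sketch only compares even alignment with power-of-two alignment; it does not reproduce how the PR computes the raw slot size.

```python
def round_up_even(n: int) -> int:
    """Round n up to the next even number."""
    return n + (n % 2)

def next_pow2(n: int) -> int:
    """Smallest power of two that is >= n (n >= 1)."""
    return 1 << (n - 1).bit_length()

# Compact slot sizes copied from the presets table (already even).
compact_slot_bytes = {"k8v4": 196, "t4nc": 136, "k3v4nc": 120, "t3nc": 104}

padding_waste = {}
for preset, size in compact_slot_bytes.items():
    padded = next_pow2(size)
    # Share of the power-of-two slot that would be padding.
    padding_waste[preset] = (padded - size) / padded
    print(f"{preset}: {size}B compact vs {padded}B pow2, "
          f"{padding_waste[preset]:.0%} padding avoided")
```

For t4nc this gives 136B versus 256B, about 47% of the power-of-two slot, which matches the figure quoted above.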

Shared value quant JIT helper — Extracted _store_quantized_value Triton JIT function, deduplicating ~60 lines between FP8 and MSE store kernels for both 3-bit and 4-bit value paths.

Prefill .tolist() optimization — Single CPU-GPU sync via .tolist() instead of per-request .item() calls in the prefill loop.

CUDAGraph memory fix — Static NUM_KV_SPLITS grid dimension (configurable, default 32) enables CUDAGraph capture. Estimated GPU memory reduced from 33 GiB → 8.7 GiB.

Stream overlap — KV store runs on a secondary CUDA stream so it can overlap with the next layer's forward pass (disabled during CUDAGraph capture).

Architecture

┌──────────────────────────────────────────────────────────────────┐
│  Store path (Triton)                                            │
│  K → WHT rotation → Lloyd-Max quantize → bit-pack ──┐          │
│  V → uniform quantize → bit-pack ────────────────────┤→ cache   │
│                                                      │          │
│  Decode path (Triton, split-KV)                      │          │
│  cache → unpack K → dequant → Q·K scores ──┐         │          │
│  cache → unpack V → dequant ──→ score·V ───┤→ output │          │
│                                            │         │          │
│  Prefill path (flash_attn_varlen_func)     │         │          │
│  Raw Q, K, V → flash attention → output    │         │          │
│  (continuation decode via TQ decode kernel)│         │          │
└──────────────────────────────────────────────────────────────────┘

Design Decisions

  • Compact even-aligned slots — slots rounded to next even number (not pow2), eliminating up to 47% memory waste. Hybrid mamba+attention models are out of scope for this PR.
  • Boundary layer protection — first/last N layers keep FP16 KV cache via kv_cache_dtype_skip_layers to protect embedding-adjacent representations. Also supports skipping "sliding_window" layers and arbitrary layer indices.
  • TQFullAttentionSpec — proper spec subclass that overrides real_page_size_bytes with TQ slot bytes, with correct merge semantics for uniform-spec models. Passes UniformTypeKVCacheSpecs.is_uniform_type() check as a FullAttentionSpec subclass.
  • No QJL — intentionally omitted per community consensus (5+ independent groups found it hurts attention quality by amplifying variance through softmax).
  • Norm correction (NC) — re-normalizes centroid vectors to unit norm before inverse rotation during dequant, fixing quantization-induced norm distortion (~0.8% PPL improvement at 4-bit).
  • Flash-attention prefill — uses flash_attn_varlen_func for memory-efficient O(N) prefill, with a continuation-decode threshold (128 tokens) routing small chunks directly through the TQ decode kernel.

Usage

# FP8 keys + 4-bit values (best quality/throughput trade-off)
vllm serve Qwen/Qwen3-4B --kv-cache-dtype turboquant_k8v4

# 4-bit MSE keys + 4-bit values + norm correction (3.8x compression)
vllm serve Qwen/Qwen3-4B --kv-cache-dtype turboquant_4bit_nc

# Maximum compression (4.9x)
vllm serve Qwen/Qwen3-4B --kv-cache-dtype turboquant_3bit_nc

# Skip specific layers (boundary protection)
vllm serve Qwen/Qwen3-4B --kv-cache-dtype turboquant_k8v4 \
  --kv-cache-dtype-skip-layers 0,1,34,35

Scope

Supports full-attention and uniform sliding-window transformer models. Hybrid architectures (mamba+attention, interleaved SWA) are planned for a follow-up PR.

Test Plan

  • Full perf benchmark (6 scenarios × 5 configs) — no regressions on baseline
  • All TQ configs produce correct output (k8v4, t4nc, k3v4nc, t3nc)
  • CUDAGraph capture verified (51 FULL + 51 PIECEWISE graphs)
  • WHT rotation: coherent generation across all MSE configs
  • Quality benchmark: GSM8K + NIAH across all presets
  • Mixed batch (decode+prefill) correct routing
  • LM Eval harness integration test

Changed files

  • .buildkite/test_areas/lm_eval.yaml (modified, +10/-0)
  • tests/evals/gsm8k/configs/Qwen3-4B-TQ-k3v4nc.yaml (added, +5/-0)
  • tests/evals/gsm8k/configs/Qwen3-4B-TQ-k8v4.yaml (added, +5/-0)
  • tests/evals/gsm8k/configs/Qwen3-4B-TQ-t3nc.yaml (added, +5/-0)
  • tests/evals/gsm8k/configs/Qwen3-4B-TQ-t4nc.yaml (added, +5/-0)
  • tests/evals/gsm8k/configs/models-turboquant.txt (added, +4/-0)
  • tests/quantization/test_turboquant.py (added, +328/-0)
  • vllm/config/attention.py (modified, +5/-0)
  • vllm/config/cache.py (modified, +4/-0)
  • vllm/engine/arg_utils.py (modified, +18/-0)
  • vllm/model_executor/layers/attention/attention.py (modified, +48/-0)
  • vllm/model_executor/layers/quantization/turboquant/__init__.py (added, +13/-0)
  • vllm/model_executor/layers/quantization/turboquant/centroids.py (added, +89/-0)
  • vllm/model_executor/layers/quantization/turboquant/config.py (added, +185/-0)
  • vllm/model_executor/layers/quantization/turboquant/quantizer.py (added, +39/-0)
  • vllm/platforms/cuda.py (modified, +5/-0)
  • vllm/utils/torch_utils.py (modified, +4/-0)
  • vllm/v1/attention/backends/registry.py (modified, +4/-0)
  • vllm/v1/attention/backends/turboquant_attn.py (added, +775/-0)
  • vllm/v1/attention/ops/triton_turboquant_decode.py (added, +546/-0)
  • vllm/v1/attention/ops/triton_turboquant_store.py (added, +381/-0)
  • vllm/v1/core/single_type_kv_cache_manager.py (modified, +4/-2)
  • vllm/v1/kv_cache_interface.py (modified, +27/-0)


Your current environment

  • vLLM nightly 0.19.1rc1.dev188+g8d0f908b9
  • PyTorch 2.11.0+cu130
  • H100 80GB, single GPU, TP=1
  • Qwen3-4B served with --kv-cache-dtype turboquant_3bit_nc (any non-FA backend reproduces — TQ, mamba, GDN, lightning attention)

🐛 Describe the bug

vllm/v1/attention/backends/flash_attn.py:_get_sliding_window_configs iterates all Attention layers in the vllm config and asserts that every layer's impl is a FlashAttentionImpl:

def _get_sliding_window_configs(
    vllm_config: VllmConfig,
) -> set[tuple[int, int] | None]:
    """Get the set of all sliding window configs used in the model."""
    sliding_window_configs: set[tuple[int, int] | None] = set()
    layers = get_layers_from_vllm_config(vllm_config, Attention)
    for layer in layers.values():
        assert isinstance(layer.impl, FlashAttentionImpl)  # ← fires on TQ/mamba/GDN/lightning
        sliding_window_configs.add(layer.impl.sliding_window)
    return sliding_window_configs

This function is called from FlashAttentionMetadataBuilder.build() (flash_attn.py:405). When the model has at least one FlashAttention layer and at least one non-FA layer, the assertion fires and engine init crashes with:

File "/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/flash_attn.py", line 255, in _get_sliding_window_configs
    assert isinstance(layer.impl, FlashAttentionImpl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

The assertion's intent (gather sliding window configs from FA layers) is correct; the implementation is wrong because it iterates the global layer set instead of restricting to layers in this builder's KV cache group.

Reproduction

Easiest repro is via the in-flight TurboQuant attention backend (#38479) on its current head, but any non-FA backend that coexists with FA layers triggers it. Steps that reproduced for me:

  1. pip install --pre vllm --extra-index-url https://wheels.vllm.ai/nightly
  2. Install or overlay any non-FA backend that registers an Attention.impl of a non-FlashAttentionImpl type
  3. vllm serve Qwen/Qwen3-4B --kv-cache-dtype turboquant_3bit_nc --port 8000 --tensor-parallel-size 1 --max-model-len 8192 --gpu-memory-utilization 0.85
  4. Engine init crashes during the first metadata builder construction

The same bug surfaces during the GDN/Mamba code path on hybrid models like Qwen3.5-35B-A3B. cc @vibhavagarwal5 — your perf branch vibhavagarwal5/vllm#7 inherits the unmodified flash_attn.py from main, so anyone testing it will hit this without a local patch.

Suggested fix

Two options:

Defensive (1 line): change the assert to a soft-skip:

for layer in layers.values():
    if not isinstance(layer.impl, FlashAttentionImpl):
        continue
    sliding_window_configs.add(layer.impl.sliding_window)

Strictly more permissive than current behavior, no risk of regressing FA-only models. This is what I patched locally to run my benchmarks (H100 throughput data here).
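
The difference between the assert and the soft-skip can be reproduced without vLLM. In the sketch below every class is a stand-in written for this example (none is imported from vLLM), and the layer names and the `(4095, 0)` window are made-up values.

```python
class FlashAttentionImpl:
    """Stand-in for vLLM's FlashAttentionImpl (stub for illustration)."""
    def __init__(self, sliding_window=None):
        self.sliding_window = sliding_window

class OtherImpl:
    """Stand-in for any non-FA impl (TQ, mamba, GDN, lightning)."""

class Layer:
    """Stand-in for an Attention layer that carries an impl."""
    def __init__(self, impl):
        self.impl = impl

def configs_with_assert(layers):
    """Mirrors the current behavior: assert on every layer."""
    configs = set()
    for layer in layers.values():
        assert isinstance(layer.impl, FlashAttentionImpl)
        configs.add(layer.impl.sliding_window)
    return configs

def configs_with_soft_skip(layers):
    """Mirrors the defensive fix: ignore non-FA layers."""
    configs = set()
    for layer in layers.values():
        if not isinstance(layer.impl, FlashAttentionImpl):
            continue
        configs.add(layer.impl.sliding_window)
    return configs

# A mixed model: two FA layers and one non-FA layer.
layers = {
    "layers.0.attn": Layer(FlashAttentionImpl(sliding_window=None)),
    "layers.1.attn": Layer(FlashAttentionImpl(sliding_window=(4095, 0))),
    "layers.2.attn": Layer(OtherImpl()),
}

try:
    configs_with_assert(layers)
    assert_fired = False
except AssertionError:
    assert_fired = True

soft_skip_result = configs_with_soft_skip(layers)
print("assert fired:", assert_fired)
print("soft-skip configs:", soft_skip_result)
```

On an FA-only layer set the two functions return the same result, which is why the soft-skip cannot regress FA-only models.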

Scoped (correct): restrict the layer iteration to layers in this builder's KV cache group, not the global set. Requires tracing the call site (FlashAttentionMetadataBuilder.build) to figure out which group the builder is responsible for. Cleaner but more invasive — would touch the call site too.
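
A rough shape for the scoped variant, again with stand-in classes rather than vLLM imports. The `group_layer_names` parameter is hypothetical: it stands for whatever the builder knows about its own KV cache group, and how that is actually exposed at the call site is exactly the open design question.

```python
class FlashAttentionImpl:
    """Stand-in for vLLM's FlashAttentionImpl (stub for illustration)."""
    def __init__(self, sliding_window=None):
        self.sliding_window = sliding_window

class OtherImpl:
    """Stand-in for a non-FA impl."""

class Layer:
    def __init__(self, impl):
        self.impl = impl

def sliding_window_configs_for_group(layers, group_layer_names):
    """Collect sliding window configs only from the builder's own layers.

    `group_layer_names` is a hypothetical parameter naming the layers in
    this builder's KV cache group.
    """
    configs = set()
    for name in group_layer_names:
        impl = layers[name].impl
        # Inside the builder's own group the assert is a genuine invariant.
        assert isinstance(impl, FlashAttentionImpl), f"{name} is not FA"
        configs.add(impl.sliding_window)
    return configs

all_layers = {
    "layers.0.attn": Layer(FlashAttentionImpl(sliding_window=None)),
    "layers.1.attn": Layer(FlashAttentionImpl(sliding_window=(4095, 0))),
    "layers.2.attn": Layer(OtherImpl()),
}
fa_group = ["layers.0.attn", "layers.1.attn"]

scoped_result = sliding_window_configs_for_group(all_layers, fa_group)
print(scoped_result)
```

Unlike the soft-skip, this keeps the assertion meaningful: a non-FA layer that ends up in an FA builder's group still fails loudly.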

Happy to send the defensive PR if a maintainer wants the fast fix; the scoped fix probably needs a design call from someone who touched the per-backend layer-set conventions in #35431 and knows whether this function is meant to be global by design or was just never updated for the multi-backend hybrid manager. cc @LucasWilkinson @MatthewBonanni as the #35431 co-authors.

Local workaround until fixed

sed -i 's|assert isinstance(layer.impl, FlashAttentionImpl)|if not isinstance(layer.impl, FlashAttentionImpl): continue|' \
  $(python -c "import vllm,os; print(os.path.dirname(vllm.__file__))")/v1/attention/backends/flash_attn.py
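
The same substitution can be expressed in Python, which makes it easy to check that the patched line still parses. This sketch works on a string only; `SAMPLE` is a made-up excerpt, and reading or writing the installed flash_attn.py (ideally with a backup) is left to the reader.

```python
OLD = "assert isinstance(layer.impl, FlashAttentionImpl)"
NEW = "if not isinstance(layer.impl, FlashAttentionImpl): continue"

def apply_soft_skip(source: str) -> str:
    """Replace the assert with the soft-skip; a no-op if already patched."""
    return source.replace(OLD, NEW)

# Made-up excerpt that mimics the loop in _get_sliding_window_configs.
SAMPLE = '''
def collect(layers):
    out = set()
    for layer in layers.values():
        assert isinstance(layer.impl, FlashAttentionImpl)
        out.add(layer.impl.sliding_window)
    return out
'''

patched = apply_soft_skip(SAMPLE)

# compile() only checks syntax, so undefined names are not a problem here.
compile(patched, "<patched>", "exec")
print(patched)
```

Because the replacement text does not contain the original assert, running the patch twice leaves the source unchanged.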

Before submitting

  • Searched existing issues — didn't find a match. Closest is the broader hybrid-backend issue family, but no specific report on this assertion.

extent analysis

TL;DR

The most likely fix is to change the assert statement in _get_sliding_window_configs to a soft-skip, allowing the function to continue iterating over layers without crashing when encountering non-FlashAttentionImpl layers.

Guidance

  • Identify the layers that are causing the assertion to fail by checking the impl type of each layer in the layers dictionary.
  • Consider implementing a defensive fix by changing the assert statement to a conditional statement that skips layers with non-FlashAttentionImpl implementations.
  • Alternatively, investigate the possibility of restricting the layer iteration to layers in the builder's KV cache group, which may require tracing the call site and understanding the per-backend layer-set conventions.
  • Verify the fix by running the vllm serve command with the modified code and checking that the engine initialization completes successfully.
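
For the first step in the list above, grouping layer names by impl type shows at a glance which layers would trip the assert. The classes here are stand-ins for this example; with a real model the mapping would come from `get_layers_from_vllm_config(vllm_config, Attention)` as in the function quoted in the issue.

```python
from collections import defaultdict

class FlashAttentionImpl:
    """Stand-in for vLLM's FlashAttentionImpl (stub for illustration)."""

class TurboQuantImpl:
    """Hypothetical stand-in for a non-FA impl."""

class Layer:
    def __init__(self, impl):
        self.impl = impl

def group_layers_by_impl(layers):
    """Map impl class name -> list of layer names using that impl."""
    groups = defaultdict(list)
    for name, layer in layers.items():
        groups[type(layer.impl).__name__].append(name)
    return dict(groups)

layers = {
    "layers.0.attn": Layer(FlashAttentionImpl()),
    "layers.1.attn": Layer(TurboQuantImpl()),
    "layers.2.attn": Layer(TurboQuantImpl()),
}

groups = group_layers_by_impl(layers)
offending = {k: v for k, v in groups.items() if k != "FlashAttentionImpl"}
print(groups)
print("layers that would trip the assert:", offending)
```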

Example

for layer in layers.values():
    if not isinstance(layer.impl, FlashAttentionImpl):
        continue
    sliding_window_configs.add(layer.impl.sliding_window)

Notes

The defensive fix is a more straightforward solution, but it may not address the underlying issue of why the function is iterating over all layers instead of just the ones in the builder's KV cache group. The scoped fix requires a deeper understanding of the codebase and the intentions of the original authors.

Recommendation

Apply the defensive workaround: replace the assert with a soft-skip. It is strictly more permissive than the current behavior, touches one line, and can serve as a temporary fix until the scoped fix has been designed and tested.
