vllm - ✅(Solved) Fix [RFC]: Extensible Per-Token Quantized KV Cache Scale Infrastructure [4 pull requests, 3 comments, 2 participants]

GitHub stats
vllm-project/vllm#37319 · Fetched 2026-04-08 00:53:32
Comments: 3 · Participants: 2 · Timeline: 16 · Reactions: 1
Timeline (top): mentioned ×5, subscribed ×5, commented ×3, closed ×1

Root Cause

vLLM already supports FP8 quantization for the KV cache. This path reduces the memory footprint and allows keeping more tokens in cache, but it relies on per-tensor scales that either come from the checkpoint or are calculated during calibration. The quality cost is bounded and easy to reason about because the scheme already exists and its assumptions are well known.

Fix Action

Fix / Workaround

1. Enum dispatch instead of string matching

One of the foundational changes is moving away from the kv_cache_dtype.startswith("fp8") checks scattered throughout the code. Instead, a KVQuantMode enum is introduced in kv_cache_interface.py with three active modes: NONE, FP8, and PER_TOKEN.

The kernel launches one instance per token, computes a global absmax over heads and head_size, derives scale = max(absmax / QUANT_MAX, 1e-6), saves that scale in k_scale_cache[block, slot], and quantizes using tl.clamp(val / scale, QUANT_MIN, QUANT_MAX). The kernel is parameterized using QUANT_MAX and QUANT_MIN, leaving the dispatch ready for other 1-byte formats in the future.

Triton read path

In TritonAttentionImpl.forward(), the previous FP8 logic is reorganized as an explicit three-way dispatch based on self.kv_quant_mode (PER_TOKEN, FP8, or NONE).

PR fix notes

PR #36893: [Feature] Kvcache Int8 per-token scale on TRITON_ATTN continue of #34327 thanks EricccYang

Description (problem / solution / changelog)

FIX https://github.com/vllm-project/vllm/issues/37319

This PR adds INT8 per-token KV cache quantization to vLLM for the Triton attention backend.

Unlike the existing FP8 path, this mode uses dynamic per-token scales computed at cache-write time. In practice, this reduces KV cache memory roughly by half compared to bf16/fp16, offering a "plug and play" zero-config option without requiring pre-calibrated checkpoint scales.

This change introduces a shared INT8 KV cache contract for attention backends: dispatch via the KVQuantMode enum, declare auxiliary buffers through AttentionSpec.auxiliary_buffer_specs, allocate them upfront in the model runner, bind them via AttentionImpl.bind_auxiliary_buffers(), and account for their memory accurately. Triton is the first backend implementing this contract.

What changed

  • Added int8_per_token as a supported kv_cache_dtype.
  • Introduced the KVQuantMode enum so backends dispatch on semantic quantization mode instead of string matching prefixes.
  • Replaced lazy allocation with explicit upfront allocation of scale caches via _allocate_auxiliary_buffers() and bind_auxiliary_buffers().
  • Added a dedicated Triton INT8 reshape-and-cache kernel for per-token dynamic quantization.
  • Extended Triton unified attention to perform an explicit 3-way dispatch and read per-token INT8 scale caches during dequantization.
  • Accounted for the explicit scale-cache memory through auxiliary_buffer_specs in the KV block-count calculation.

Design notes

The main idea is to offer a memory-efficient KV cache option that requires no additional configuration. The user only needs to set kv_cache_dtype="int8_per_token".

Instead of extending the existing FP8 kernel, INT8 uses a separate reshape kernel because it requires dynamic scale computation. The kernel computes an absmax across all heads for a given token, deriving a single scale. This shifts the auxiliary memory footprint to strictly per-token (rather than per-token per-head), optimizing the buffer size while maintaining bounded accuracy costs.

Backend Extensibility

While Triton is the first target backend in this PR, the redesign lays the groundwork for future quantization modes and other backends without requiring core architectural rewrites.

The new KVQuantMode enum already declares PER_TOKEN_GROUP and NVFP4 as future modes. Thanks to the explicit auxiliary_buffer_specs abstraction, supporting these will be straightforward:

  • PER_TOKEN_GROUP can easily declare and allocate two float32 buffers with shape (block_size, num_kv_heads, ceil(head_size / group_size)).
  • NVFP4 can request its required block-scales and global-scales through the same unified interface.

Triton implementation

When the user sets --kv-cache-dtype int8_per_token, process_weights_after_loading() forces checkpoint scales to 1.0 and deletes them, returning early because actual scales are dynamically generated.

For Triton, the write path uses _reshape_cache_per_token. It launches one instance per token, computes a global absmax over heads, derives scale = max(absmax / QUANT_MAX, 1e-6), and writes the quantized data alongside the float32 scales into the auxiliary scale buffers.

On the read path, forward() performs a dispatch based on self.kv_quant_mode. For PER_TOKEN, it passes k_scale_cache and v_scale_cache into unified_attention(), and the Triton attention kernels fetch and apply these per-token scales during the dot product accumulation.

Memory accounting

Until now, KV cache auxiliary memory was treated too flatly. Now, KVCacheSpec exposes auxiliary_buffer_specs.

For PER_TOKEN, it explicitly declares two float32 buffers with shape (block_size,). kv_cache_utils.py is updated to include this specific memory footprint when deciding how many pages fit into the available budget, ensuring the block count reflects the actual footprint of both the quantized tensor and the explicit float32 buffers.
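
The "roughly half" figure and the cost of the scale buffers can be checked with a little arithmetic. The sketch below is illustrative only: the helper name and the example shape (8 KV heads, head size 128) are assumptions, not values taken from the PR.

```python
def kv_bytes_per_token(num_kv_heads: int, head_size: int, elem_bytes: int,
                       scale_bytes: int = 0) -> int:
    """Bytes for K and V of one token in one layer, plus any per-token scales."""
    return 2 * num_kv_heads * head_size * elem_bytes + scale_bytes


# Assumed example shape: 8 KV heads, head size 128.
fp16 = kv_bytes_per_token(8, 128, elem_bytes=2)
# int8 payload plus one float32 K scale and one float32 V scale per token.
int8 = kv_bytes_per_token(8, 128, elem_bytes=1, scale_bytes=2 * 4)
ratio = int8 / fp16
```

With this shape the per-token scales add 8 bytes to a 2048-byte int8 payload, so the ratio stays very close to one half.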

Benchmark Results

Hardware: AMD RX 7900 XTX (gfx1100, no FP8 support)
Backend: vLLM 0.17.1rc1 + TRITON_ATTN + ROCm
Task: ChartQA — 2500 samples, 0-shot (lm-evaluation-harness)
Model weights: GPTQ W4A16-G32 for both models

| Model | KV Cache | anywhere_accuracy | exact_match | relaxed_accuracy |
|---|---|---|---|---|
| Qwen3.5-27B (dense) | FP16 | 0.8844 ± 0.0064 | 0.6232 ± 0.0097 | 0.8576 ± 0.0070 |
| Qwen3.5-27B (dense) | Int8 per-token | 0.8816 ± 0.0065 | 0.6184 ± 0.0097 | 0.8532 ± 0.0071 |
| Qwen3.5-35B-A3B (MoE) | FP16 | 0.6888 ± 0.0093 | 0.4264 ± 0.0099 | 0.6272 ± 0.0097 |
| Qwen3.5-35B-A3B (MoE) | Int8 per-token | 0.6836 ± 0.0093 | 0.4108 ± 0.0098 | 0.6172 ± 0.0097 |

Qwen 3.5 A3B — fp16 vs int8 kvcache vllm bench benchmark

35B model

| Metric | FP16 | INT8 | Delta |
|---|---|---|---|
| Throughput | | | |
| Requests/s | 4.70 | 4.80 | +2.1% ✓ |
| Output tokens/s | 601.63 | 614.76 | +2.2% ✓ |
| Peak output tokens/s | 1100.00 | 1100.00 | = |
| Total tokens/s | 3008.16 | 3073.82 | +2.2% ✓ |
| Benchmark duration (s) | 21.28 | 20.82 | -2.2% ✓ |
| Time to First Token (TTFT) | | | |
| Mean TTFT (ms) | 4897.01 | 4906.67 | +0.2% ~ |
| Median TTFT (ms) | 4771.42 | 4783.90 | +0.3% ~ |
| P99 TTFT (ms) | 9761.99 | 9768.42 | +0.1% ~ |
| Time per Output Token (TPOT) | | | |
| Mean TPOT (ms) | 123.18 | 119.64 | -2.9% ✓ |
| Median TPOT (ms) | 124.57 | 120.94 | -2.9% ✓ |
| P99 TPOT (ms) | 150.46 | 147.20 | -2.2% ✓ |
| Inter-token Latency (ITL) | | | |
| Mean ITL (ms) | 123.18 | 119.64 | -2.9% ✓ |
| Median ITL (ms) | 97.02 | 92.53 | -4.6% ✓ |
| P99 ITL (ms) | 409.39 | 406.12 | -0.8% ~ |

27B model

| Metric | FP16 | INT8 | Delta |
|---|---|---|---|
| Throughput | | | |
| Requests/s | 1.59 | 1.59 | = |
| Output tokens/s | 203.16 | 204.12 | +0.5% ✓ |
| Peak output tokens/s | 496.00 | 496.00 | = |
| Total tokens/s | 1015.82 | 1020.59 | +0.5% ✓ |
| Benchmark duration (s) | 63.00 | 62.71 | -0.5% ✓ |
| Time to First Token (TTFT) | | | |
| Mean TTFT (ms) | 24291.17 | 24183.78 | -0.4% ~ |
| Median TTFT (ms) | 18213.33 | 18141.38 | -0.4% ~ |
| P99 TTFT (ms) | 51138.64 | 50902.45 | -0.5% ~ |
| Time per Output Token (TPOT) | | | |
| Mean TPOT (ms) | 218.50 | 217.44 | -0.5% ~ |
| Median TPOT (ms) | 260.29 | 259.00 | -0.5% ~ |
| P99 TPOT (ms) | 280.45 | 279.14 | -0.5% ~ |
| Inter-token Latency (ITL) | | | |
| Mean ITL (ms) | 218.50 | 217.44 | -0.5% ~ |
| Median ITL (ms) | 134.38 | 133.51 | -0.6% ~ |
| P99 ITL (ms) | 1426.80 | 1419.88 | -0.5% ~ |

Analysis:

On the dense model the accuracy delta is essentially noise — less than 0.005 across all metrics, well within the statistical margin (stderr ~0.007). In practice you wouldn't notice the difference.

The MoE model is a bit more sensitive to quantization (deltas of roughly 0.005 to 0.016, largest on exact_match), which makes sense given that expert routing produces more heterogeneous KV distributions. Still very much within acceptable bounds for a 50% memory saving, and notably better than what you'd expect from a per-tensor INT8 scheme.

Both models required no calibration — scales are computed dynamically at runtime with no offline prep.
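
The deltas discussed in this analysis can be recomputed directly from the ChartQA table. The sketch below copies the reported means and standard errors. Combining the two standard errors in quadrature is an assumption that treats the FP16 and INT8 runs as independent, which is only a rough yardstick because both runs score the same samples.

```python
from math import sqrt

# (fp16 mean, fp16 stderr, int8 mean, int8 stderr), copied from the table.
results = {
    ("27B dense", "anywhere_accuracy"): (0.8844, 0.0064, 0.8816, 0.0065),
    ("27B dense", "exact_match"): (0.6232, 0.0097, 0.6184, 0.0097),
    ("27B dense", "relaxed_accuracy"): (0.8576, 0.0070, 0.8532, 0.0071),
    ("35B-A3B MoE", "anywhere_accuracy"): (0.6888, 0.0093, 0.6836, 0.0093),
    ("35B-A3B MoE", "exact_match"): (0.4264, 0.0099, 0.4108, 0.0098),
    ("35B-A3B MoE", "relaxed_accuracy"): (0.6272, 0.0097, 0.6172, 0.0097),
}


def delta_and_combined_stderr(fp16, fp16_se, int8, int8_se):
    """Accuracy drop (fp16 - int8) and the quadrature sum of both stderrs."""
    return fp16 - int8, sqrt(fp16_se ** 2 + int8_se ** 2)


summary = {key: delta_and_combined_stderr(*row) for key, row in results.items()}
dense_deltas = [d for (model, _), (d, _) in summary.items() if model == "27B dense"]
moe_deltas = [d for (model, _), (d, _) in summary.items() if model == "35B-A3B MoE"]
```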

Changed files

  • docs/design/attention_backends.md (modified, +1/-1)
  • tests/models/quantization/test_int8_kv_cache.py (added, +91/-0)
  • tests/quantization/test_int8_kv_cache.py (added, +732/-0)
  • vllm/config/cache.py (modified, +11/-2)
  • vllm/model_executor/layers/attention/attention.py (modified, +4/-0)
  • vllm/model_executor/layers/attention/chunked_local_attention.py (modified, +2/-0)
  • vllm/model_executor/layers/attention/cross_attention.py (modified, +6/-1)
  • vllm/model_executor/layers/attention/static_sink_attention.py (modified, +2/-0)
  • vllm/model_executor/layers/quantization/kv_cache.py (modified, +18/-1)
  • vllm/utils/torch_utils.py (modified, +1/-0)
  • vllm/v1/attention/backend.py (modified, +24/-2)
  • vllm/v1/attention/backends/triton_attn.py (modified, +64/-17)
  • vllm/v1/attention/ops/triton_reshape_and_cache_flash.py (modified, +208/-1)
  • vllm/v1/attention/ops/triton_unified_attention.py (modified, +135/-34)
  • vllm/v1/core/kv_cache_utils.py (modified, +16/-5)
  • vllm/v1/kv_cache_interface.py (modified, +117/-1)
  • vllm/v1/worker/gpu_model_runner.py (modified, +67/-0)

PR #5: [Feature] KV cache per-token-head INT4 quantization support

Description (problem / solution / changelog)

Summary

This PR adds int4_per_token_head KV cache dtype, extending the per-token-head quantization design from #38378.

  • packed int4 layout (2 channels per byte) with fp16 scales grouped by 32 channels, stored inline in the KV cache tensor
  • Hadamard rotation on query and key before caching to reduce quantization error for Gaussian-like activations
  • AttentionSpec.page_size_bytes / real_page_size_bytes updated to account for the int4 packed size and fp16 scale overhead
  • fix profiling-time padded-page reshape to use as_strided instead of .view() when page sizes are padded
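
To make the packed layout concrete, here is a minimal sketch of storing two int4 channels per byte and of the resulting per-head byte count with one fp16 scale per 32 channels. The nibble order (low nibble first), the signed range (-8..7), and all function names are assumptions for illustration; the PR description does not spell them out.

```python
def pack_int4_pair(lo: int, hi: int) -> int:
    """Pack two signed 4-bit values into one byte, low nibble first (assumed)."""
    if not (-8 <= lo <= 7 and -8 <= hi <= 7):
        raise ValueError("int4 values must lie in [-8, 7]")
    return (lo & 0xF) | ((hi & 0xF) << 4)


def unpack_int4_pair(byte: int) -> tuple[int, int]:
    """Inverse of pack_int4_pair: recover the two signed 4-bit values."""
    def sign_extend(nibble: int) -> int:
        return nibble - 16 if nibble >= 8 else nibble
    return sign_extend(byte & 0xF), sign_extend(byte >> 4)


def int4_bytes_per_token_head(head_size: int, group_size: int = 32) -> int:
    """Packed int4 payload plus one 2-byte fp16 scale per group of channels."""
    return head_size // 2 + (head_size // group_size) * 2


packed = pack_int4_pair(-3, 5)
roundtrip = unpack_int4_pair(packed)
bytes_128 = int4_bytes_per_token_head(128)
```

For a head size of 128 this gives 64 payload bytes plus 8 scale bytes per token and head, compared with 256 bytes in fp16.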

Related: #37319, #36893, #34327

Validation

pytest tests/quantization/test_per_token_kv_cache.py -v
TEST_PER_TOKEN_KV_CACHE_MODEL=meta-llama/Llama-3.2-1B-Instruct pytest tests/models/quantization/test_per_token_kv_cache.py -v -s
pytest tests/kernels/attention/test_attention_selector.py -k 'per_head_quant_scales_backend_selection or flash_attn_rejects_int4_kv_cache' -v
pytest tests/v1/worker/test_gpu_model_runner.py -k 'reshape_kv_cache_tensors_handles_padded_attention_pages or update_hybrid_attention_mamba_layout_with_num_block_2_rewrites_stride' -v
pytest tests/v1/core/test_kv_cache_utils.py -k 'page_size_padded or unify_kv_cache_spec_page_size' -v

Smoke-tested non-eager serving on Gemma4 26B and 31B (TP=2) with --kv-cache-dtype int4_per_token_head, VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1, -O1, and --gpu-memory-utilization 0.90. Tested text, multilingual, long-context, MM, and multi-image inputs.

Qwen 3.5 Benchmark Summary

The FP16 and INT8 columns below are the reference baseline numbers quoted from #36893, measured on non-quantized Qwen weights.

The AutoRound columns are local control benchmarks added for this PR, rerun on 2x RTX 3090. They show that the existing int8_per_token_head path is still working and compare it with the new int4_per_token_head path.

Local AutoRound benchmark conditions:

  • TP=2
  • VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1
  • -O1
  • prefix caching disabled
  • random benchmark dataset, input_len=1024, output_len=128, num_prompts=10
  • 35B MoE used --gpu-memory-utilization 0.85
  • 27B dense used --gpu-memory-utilization 0.87 --max-cudagraph-capture-size 128

35B MoE (Qwen3.5-35B-A3B)

| Metric | FP16 (#36893) | INT8 (#36893) | AutoRound INT8 | AutoRound INT4 | Delta (INT4 vs AutoRound INT8) |
|---|---|---|---|---|---|
| Requests/s | 4.70 | 4.80 | 1.92 | 0.48 | -74.9% |
| Output tokens/s | 601.63 | 614.76 | 122.79 | 30.79 | -74.9% |
| Mean TTFT (ms) | 4897.01 | 4906.67 | 2510.56 | 11362.05 | +352.6% |
| Mean TPOT (ms) | 123.18 | 119.64 | 42.23 | 148.96 | +252.8% |

27B Dense (Qwen3.5-27B)

| Metric | FP16 (#36893) | INT8 (#36893) | AutoRound INT8 | AutoRound INT4 | Delta (INT4 vs AutoRound INT8) |
|---|---|---|---|---|---|
| Requests/s | 1.59 | 1.59 | 0.71 | 0.28 | -61.4% |
| Output tokens/s | 203.16 | 204.12 | 91.38 | 35.24 | -61.4% |
| Mean TTFT (ms) | 24291.17 | 24183.78 | 6720.63 | 22027.65 | +227.8% |
| Mean TPOT (ms) | 218.50 | 217.44 | 56.93 | 112.05 | +96.8% |

The slowdown is expected because AutoRound model quantization and int4 KV cache are both active at the same time.

Notes

  • scale caches initialized with 1.0 so partially-written blocks dequantize to safe values instead of garbage
  • profiling-time padded-page reshape now uses as_strided, same as the runtime path

Changed files

  • tests/kernels/attention/test_attention_selector.py (modified, +37/-0)
  • tests/models/quantization/test_per_token_kv_cache.py (modified, +19/-4)
  • tests/quantization/test_per_token_kv_cache.py (modified, +523/-108)
  • tests/v1/core/test_kv_cache_utils.py (modified, +58/-0)
  • tests/v1/worker/test_gpu_model_runner.py (modified, +84/-0)
  • vllm/config/cache.py (modified, +1/-0)
  • vllm/utils/torch_utils.py (modified, +1/-0)
  • vllm/v1/attention/backends/flash_attn.py (modified, +1/-1)
  • vllm/v1/attention/backends/triton_attn.py (modified, +161/-37)
  • vllm/v1/attention/ops/triton_reshape_and_cache_flash.py (modified, +288/-0)
  • vllm/v1/attention/ops/triton_unified_attention.py (modified, +290/-14)
  • vllm/v1/core/kv_cache_utils.py (modified, +9/-0)
  • vllm/v1/kv_cache_interface.py (modified, +63/-11)
  • vllm/v1/worker/gpu/attn_utils.py (modified, +15/-1)
  • vllm/v1/worker/gpu_model_runner.py (modified, +20/-6)
RAW_BUFFER

vLLM already supports FP8 quantization for the KV cache. This path reduces the memory footprint and allows keeping more tokens in cache, but it relies on per-tensor scales that either come from the checkpoint or are calculated during calibration. The quality cost is bounded and easy to reason about because the scheme already exists and its assumptions are well known.

The INT8 per-token proposal shifts this balance. K and V still occupy one byte per element, just like in FP8, but instead of reusing a global per-tensor scale, a dynamic per-token scale is computed at the time of writing into the cache. This improves independence from the checkpoint, but also introduces an extra cost: besides the quantized tensor, auxiliary float32 buffers must be allocated to store these scales. Therefore, it is not enough to "discover" these buffers at runtime; they must be included from the beginning in profiling and memory planning.

Motivation

The main idea is to offer a more memory-efficient KV cache option that requires no additional configuration. Compared to storing the cache in bf16 or fp16, K and V take up approximately half the space. And unlike FP8, there is no need for the model to bring pre-calibrated scales or to run a prior calibration phase: the kernel obtains the scale directly from the data it is writing.

This makes the mode especially attractive as a "plug and play" option. The user would only need to set kv_cache_dtype="int8_per_token" and the rest should be resolved within the backend. This zero-config behavior is a significant part of the proposal's value, not just an implementation detail.

Design

1. Enum dispatch instead of string matching

One of the foundational changes is moving away from the kv_cache_dtype.startswith("fp8") checks scattered throughout the code. Instead, a KVQuantMode enum is introduced in kv_cache_interface.py with three active modes: NONE, FP8, and PER_TOKEN.

The idea is for the system to reason about a semantic mode, not string prefixes. get_kv_quant_mode() handles mapping the text dtype to this enum, checking suffixes before prefixes to avoid future ambiguities. From there, the legacy is_quantized_kv_cache() logic is rebuilt on top of the enum and kept exported for compatibility. This same quantization mode is propagated to all KVCacheSpec instances (FullAttentionSpec, SlidingWindowSpec, etc.).
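
A minimal sketch of what such a mapping could look like is shown below. It is not the code from the PR: the function bodies and the suffix string "_per_token" are assumptions chosen to illustrate the "suffix before prefix" ordering, and "fp8_per_token" in the comment is a hypothetical future name.

```python
from enum import Enum, auto


class KVQuantMode(Enum):
    """The three active modes named in the RFC."""
    NONE = auto()
    FP8 = auto()
    PER_TOKEN = auto()


def get_kv_quant_mode(kv_cache_dtype: str) -> KVQuantMode:
    # Check the more specific suffix first, so a hypothetical future name
    # such as "fp8_per_token" is not captured by the "fp8" prefix rule.
    if kv_cache_dtype.endswith("_per_token"):
        return KVQuantMode.PER_TOKEN
    if kv_cache_dtype.startswith("fp8"):
        return KVQuantMode.FP8
    return KVQuantMode.NONE


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    # Legacy helper rebuilt on top of the enum.
    return get_kv_quant_mode(kv_cache_dtype) is not KVQuantMode.NONE
```

Callers then branch on the enum value rather than on string prefixes, which keeps the string parsing in one place.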

2. Explicitly described auxiliary buffers

Until now, KV cache auxiliary memory was treated too flatly, as if an aggregated integer was enough. For per-token, this is no longer sufficient, because what matters is not just "how many more bytes" there are, but which buffers exist, with what dtype, and what shape they are allocated per block.

KVCacheSpec now exposes auxiliary_buffer_specs, a list of AuxBufferSpec(name, dtype, shape_per_block). In PER_TOKEN, the specification is straightforward: two float32 buffers, one for k_scale_cache and another for v_scale_cache, both with shape (block_size,). Thus, there is a single source of truth to know what is allocated and how much it costs.
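
As a rough illustration, the specification could be modelled as below. Only the field names (name, dtype, shape_per_block) and the two PER_TOKEN buffers come from the RFC text; the string dtypes, the DTYPE_BYTES table, the bytes_per_block property, and the helper function are assumptions made so the sketch runs without torch.

```python
from dataclasses import dataclass
from math import prod

# Byte widths for the dtypes used here; strings stand in for torch dtypes.
DTYPE_BYTES = {"float32": 4, "float16": 2, "int8": 1}


@dataclass(frozen=True)
class AuxBufferSpec:
    name: str
    dtype: str
    shape_per_block: tuple[int, ...]

    @property
    def bytes_per_block(self) -> int:
        """Memory this buffer adds to every KV cache block."""
        return prod(self.shape_per_block) * DTYPE_BYTES[self.dtype]


def per_token_aux_specs(block_size: int) -> list[AuxBufferSpec]:
    """Two float32 scale buffers, one scale per token slot in the block."""
    return [
        AuxBufferSpec("k_scale_cache", "float32", (block_size,)),
        AuxBufferSpec("v_scale_cache", "float32", (block_size,)),
    ]


specs = per_token_aux_specs(block_size=16)
aux_bytes_per_block = sum(s.bytes_per_block for s in specs)
```

For a block size of 16 the two buffers cost 128 bytes per block in total.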

3. Allocation and binding from the start

During initialize_kv_cache_tensors(), the model runner calls _allocate_auxiliary_buffers(), which iterates through the attention groups, queries auxiliary_buffer_specs, and allocates tensors on the device with shape (num_blocks, *shape_per_block). After allocation, each layer receives them through a new bind_auxiliary_buffers() method in AttentionImpl.
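
The allocate-then-bind flow can be sketched without a GPU. In the example below the flat array stands in for a device tensor, and DummyAttentionImpl, allocate_aux_buffer, and the dict argument to bind_auxiliary_buffers() are hypothetical; the RFC names the method but does not give its signature.

```python
from array import array
from math import prod


def allocate_aux_buffer(num_blocks: int, shape_per_block: tuple[int, ...]):
    """Return (shape, flat float32 storage) for one auxiliary buffer."""
    shape = (num_blocks, *shape_per_block)
    storage = array("f", [0.0]) * prod(shape)
    return shape, storage


class DummyAttentionImpl:
    """Hypothetical stand-in for a backend's AttentionImpl."""

    def bind_auxiliary_buffers(self, buffers: dict) -> None:
        # Keep references so the write and read paths can use the scales.
        self._k_scale_cache = buffers["k_scale_cache"]
        self._v_scale_cache = buffers["v_scale_cache"]


num_blocks, block_size = 4, 16
buffers = {
    name: allocate_aux_buffer(num_blocks, (block_size,))
    for name in ("k_scale_cache", "v_scale_cache")
}
impl = DummyAttentionImpl()
impl.bind_auxiliary_buffers(buffers)
k_shape, k_storage = impl._k_scale_cache
```

Because the buffers exist before any forward pass, profiling and memory planning see them from the start instead of discovering them lazily.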

Backend Extensibility

Triton is the first target backend because it already functions as a relatively self-contained attention backend within vLLM's abstraction. This gives us full control over the cache's write path and the attention kernels that subsequently read it.

However, the design lays the groundwork for future modes and backends. PER_TOKEN_GROUP and NVFP4 are declared as future modes in the enum. The explicit buffer specifications mean PER_TOKEN_GROUP would easily allocate two float32 buffers with shape (block_size, num_kv_heads, ceil(head_size / group_size)), and NVFP4 could request block-scales and global-scales without requiring core architectural rewrites.
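
For comparison with the PER_TOKEN case, the scale-buffer shape quoted for PER_TOKEN_GROUP can be evaluated for a concrete configuration. The function name and the example values (block size 16, 8 KV heads, head size 128, group size 32) are assumptions for illustration only.

```python
from math import ceil, prod


def per_token_group_scale_shape(block_size: int, num_kv_heads: int,
                                head_size: int, group_size: int) -> tuple[int, int, int]:
    """Per-block shape of one scale buffer, as quoted in the RFC."""
    return (block_size, num_kv_heads, ceil(head_size / group_size))


shape = per_token_group_scale_shape(16, 8, 128, 32)
# Two float32 buffers (K and V), 4 bytes per scale.
bytes_per_block = 2 * prod(shape) * 4
```

Under these assumed values the group mode needs 4096 bytes of scales per block, against 128 bytes for two (block_size,) buffers, which is why the shape has to be declared explicitly rather than folded into a single integer.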

Memory and Scope

Realistic memory accounting

Once explicit auxiliary buffers exist, they also need to be properly accounted for when calculating the number of blocks. kv_cache_utils.py is updated to include this memory when deciding how many pages fit into the available budget.

In the uniform page size path, the maximum aux_per_block among groups is taken and added to the page_size before dividing the available memory. In the per-layer size path, total_aux is added to the denominator along with page_size_bytes. The goal is to have the block count reflect the actual footprint and not just that of the quantized KV tensor.
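
The two accounting paths can be written out as a small worked example. This is a sketch of the arithmetic described above, not the code in kv_cache_utils.py: the function names and the example sizes (1 GiB budget, 32768-byte int8 pages, 128 bytes of scales per block) are assumptions.

```python
def num_blocks_uniform(available_bytes: int, page_size_bytes: int,
                       aux_per_block_by_group: list[int]) -> int:
    """Uniform page size path: add the largest per-group aux cost to the page."""
    max_aux = max(aux_per_block_by_group, default=0)
    return available_bytes // (page_size_bytes + max_aux)


def num_blocks_per_layer(available_bytes: int, page_sizes: list[int],
                         total_aux: int) -> int:
    """Per-layer path: total_aux joins the page sizes in the denominator."""
    return available_bytes // (sum(page_sizes) + total_aux)


budget = 1 << 30            # 1 GiB, assumed
page = 2 * 16 * 8 * 128     # int8 K and V, block 16, 8 heads, head size 128
aux = 2 * 16 * 4            # two float32 (block_size,) scale buffers

without_aux = num_blocks_uniform(budget, page, [])
with_aux = num_blocks_uniform(budget, page, [aux])
two_layers = num_blocks_per_layer(budget, [page, page], total_aux=2 * aux)
```

Ignoring the scale buffers would overestimate the block count by 128 blocks in this example, which is the kind of mismatch the explicit accounting is meant to prevent.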

Implementation and Scope

Triton write path

For the PER_TOKEN mode, the write path is separated from the FP8 one with a new kernel: _reshape_cache_per_token. FP8 starts from an already known scale and simply applies it; per-token needs to inspect the token data, compute an absmax, derive a dynamic scale, and then quantize.

The kernel launches one instance per token, computes a global absmax over heads and head_size, derives scale = max(absmax / QUANT_MAX, 1e-6), saves that scale in k_scale_cache[block, slot], and quantizes using tl.clamp(val / scale, QUANT_MIN, QUANT_MAX). The kernel is parameterized using QUANT_MAX and QUANT_MIN, leaving the dispatch ready for other 1-byte formats in the future.
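
A plain-Python reference of the same per-token math may help when reading the kernel. It follows the formula above, with two assumptions that the text does not state: values are rounded to the nearest integer before clamping, and the int8 range is taken as -128..127.

```python
QUANT_MIN, QUANT_MAX = -128, 127  # assumed int8 range


def quantize_token(token: list[list[float]]):
    """token[h][d] holds one token's K (or V) values over heads and head_size."""
    # One absmax over all heads and channels gives one scale per token.
    absmax = max(abs(v) for head in token for v in head)
    scale = max(absmax / QUANT_MAX, 1e-6)
    quantized = [
        [int(min(max(round(v / scale), QUANT_MIN), QUANT_MAX)) for v in head]
        for head in token
    ]
    return quantized, scale


def dequantize_token(quantized: list[list[int]], scale: float):
    return [[q * scale for q in head] for head in quantized]


token = [[0.5, -1.27, 0.0], [1.0, 0.25, -0.75]]
q, scale = quantize_token(token)
restored = dequantize_token(q, scale)
zero_q, zero_scale = quantize_token([[0.0, 0.0]])
```

The 1e-6 floor keeps the division well defined for an all-zero token, and with rounding the reconstruction error per element stays within half a quantization step.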

Triton read path

In TritonAttentionImpl.forward(), the previous FP8 logic is reorganized as an explicit three-way dispatch based on self.kv_quant_mode:

  • PER_TOKEN: The cache is interpreted as torch.int8, and self._k_scale_cache and self._v_scale_cache are passed to unified_attention().
  • FP8: The cache is viewed as the corresponding FP8 dtype, expanded layer._k_scale and layer._v_scale are used.
  • NONE: The existing behavior is maintained.

Inside the Triton attention kernels, for PER_TOKEN, K and V are converted to Q.dtype upon loading, and the per-token scales are fetched from k_scale_cache_ptr and v_scale_cache_ptr using the physical block and offset, applied during the dot product accumulation.
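
Because each cached key token carries a single scale, that scale factors out of the query-key dot product. The sketch below only demonstrates this identity on toy numbers; it makes no claim about the exact order of operations inside the Triton kernel, and the function names are hypothetical.

```python
def score_dequant_first(q: list[float], k_int8: list[int], k_scale: float) -> float:
    """Dequantize the key elementwise, then take the dot product."""
    return sum(qi * (ki * k_scale) for qi, ki in zip(q, k_int8))


def score_scale_after(q: list[float], k_int8: list[int], k_scale: float) -> float:
    """Dot product on the raw int8 values, then one multiply by the token scale."""
    return sum(qi * ki for qi, ki in zip(q, k_int8)) * k_scale


q = [0.5, -1.0, 2.0]
k_int8 = [100, -27, 64]
k_scale = 0.01
a = score_dequant_first(q, k_int8, k_scale)
b = score_scale_after(q, k_int8, k_scale)
```

Both forms give the same attention score up to floating-point rounding, so applying the scale per token during accumulation does not require materializing a dequantized copy of the cache.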

Weight loading and naming

BaseKVCacheMethod.process_weights_after_loading() adds an early return for per-token modes. Checkpoint scales are forced to 1.0 and deleted from the layer, as actual scales are dynamically generated in the write kernel. The exposed dtype is int8_per_token, registered across the required mappings and documentation.

Collateral changes and limitations

  • Encoder Attention: Validation rejects any quantized KV cache using kv_quant_mode != KVQuantMode.NONE.
  • Fused RoPE + KV cache: fused_rope_kvcache_supported() returns False for PER_TOKEN because that path has no way to write the auxiliary scale buffers.

Feedback Period

No response

CC List

@mgoin @LucasWilkinson

Any Other Things

PR: #36893


extent analysis

Fix Plan

To add a new quantized KV cache format, follow these steps:

  1. Register the format:
    • Add the new dtype to _DTYPES_WITH_PER_TOKEN_SCALES in kv_cache_interface.py.
    • Add the prefix to _PER_TOKEN_SCALE_PREFIXES in backend.py.
  2. Add dtype to config:
    • Update CacheDType in config/cache.py to include the new dtype.
  3. Register in target backend(s):
    • Add the new dtype to supported_kv_cache_dtypes in the target backend (e.g., triton_attn.py).
  4. Write format-specific quantization kernel:
    • Create a new function (e.g., triton_reshape_and_cache_flash_nvfp4_per_token) to compute scales and quantize data for the new format.
  5. Update backend's do_kv_cache_update:
    • Add an elif branch to handle the new dtype and call the format-specific quantization kernel.
  6. Add kernel flags for attention dequantization:
    • Update the Python wrapper (e.g., triton_unified_attention.py) to set a flag for the new format.
    • Update the Triton kernel signature to include the new flag.

Example code for adding NVFP4 support:

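# kv_cache_interface.py
# NOTE: torch.nvfp4 is used throughout this example as a placeholder name, not a confirmed PyTorch dtype.
_DTYPES_WITH_PER_TOKEN_SCALES: set[torch.dtype] = {torch.int8, torch.nvfp4}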

# backend.py
_PER_TOKEN_SCALE_PREFIXES: tuple[str, ...] = ("int8", "nvfp4")

# config/cache.py
CacheDType = Literal["auto", ..., "int8", "nvfp4"]

# triton_attn.py
supported_kv_cache_dtypes = [..., "int8", "nvfp4"]

# triton_reshape_and_cache_flash.py
def triton_reshape_and_cache_flash_nvfp4_per_token(
    key, value, key_cache, value_cache,
    k_scale_cache, v_scale_cache, slot_mapping,
):
    # Format-specific: compute scale, quantize to nvfp4, write scale cache
    ...

# triton_attn.py
elif self.kv_cache_dtype.startswith("nvfp4"):
    key_cache = key_cache.view(self.nvfp4_dtype)
    value_cache = value_cache.view(self.nvfp4_dtype)
    k_sc, v_sc = self.ensure_per_token_scale_caches(key_cache)  # FREE
    triton_reshape_and_cache_flash_nvfp4_per_token(
        key, value, key_cache, value_cache, k_sc, v_sc, slot_mapping,
    )
    return

# triton_unified_attention.py
use_nvfp4_kv = k.dtype == torch.nvfp4

# Triton kernel signature
USE_NVFP4_K
