vllm - 💡(How to fix) Fix [RFC]: Standardize KV-cache Layouts [2 pull requests]




Motivation.

Problem:

Right now attention backends can use semantically and physically different KV-cache layouts. This leads to messy KV-connector code (e.g. lots of is_mamba and is_mla flags) and creates tight coupling between the KV connector and the attention backends, e.g. https://github.com/vllm-project/vllm/blob/4140faa4a51d42cb9618949bee28fd47682f611c/vllm/v1/attention/backends/flash_attn.py#L152-L171

This also leads to related bugs, such as https://github.com/vllm-project/vllm/pull/41657#issuecomment-4400641394

More details in: https://docs.google.com/document/d/1-TnWZf8jI6nWgE-xaQkgsWy130CVUqqt1NYqY-mAOe4/edit?usp=sharing

Proposed Change.

Proposed: Unified Semantic KV Layout

Standard Semantic Shape

[num_layers, num_blocks, num_states, num_heads, <state_content>]

Where:

  • num_states: states per block (one per token position, fewer for compressed caches such as DeepSeek V4, or 1 for recurrent state)
  • num_heads: heads (or 1 for headless backends like MLA)
  • <state_content>: backend-specific, always contiguous

Every backend maps to this shape:

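| Backend     | num_states   | num_heads    | state_content          |
|-------------|--------------|--------------|------------------------|
| GQA         | block_size   | num_kv_heads | [2, head_dim]          |
| DeepSeek V4 | block_size/4 | num_kv_heads | [2, head_dim]          |
| MLA         | block_size   | 1            | [latent_size]          |
| Mamba2 Conv | 1            | 1            | [conv_dim/tp, kernel-1] |
| Mamba2 SSM  | 1            | num_heads    | [head_dim, state_size] |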

KVCacheSpec Properties

from dataclasses import dataclass


@dataclass
class KVCacheSpec:
    num_heads: int          # heads (1 if headless)
    tokens_per_state: int   # -1 infinite (recurrent), 1 standard, N compressed
    state_content_size: int # bytes per state per head

NOTE: sliding_window and attention_chunk_size will still be needed in the spec, but we plan to refactor them away in a subsequent attention-backend refactor.

num_states is not known when the spec is constructed. It is derived at allocation time as block_size / tokens_per_state, or is 1 for recurrent state, where tokens_per_state is -1.

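| Backend     | num_heads    | tokens/state | num_states (derived) |
|-------------|--------------|--------------|----------------------|
| GQA         | num_kv_heads | 1            | block_size           |
| DeepSeek V4 | num_kv_heads | 4            | block_size/4         |
| MLA         | 1            | 1            | block_size           |
| Mamba2 Conv | 1            | -1           | 1                    |
| Mamba2 SSM  | num_heads    | -1           | 1                    |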
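The sketch below shows one way to derive num_states from the spec and build the semantic shape. It makes three assumptions: `<state_content>` is flattened to raw bytes (state_content_size), block_size is a multiple of tokens_per_state, and the helper names `num_states` and `semantic_shape` are hypothetical rather than vLLM APIs. The example dimensions are illustrative and are not taken from any specific model.

```python
from dataclasses import dataclass


@dataclass
class KVCacheSpec:
    num_heads: int           # heads (1 if headless)
    tokens_per_state: int    # -1 infinite (recurrent), 1 standard, N compressed
    state_content_size: int  # bytes per state per head


def num_states(spec: KVCacheSpec, block_size: int) -> int:
    """Hypothetical helper: states per block, derived at allocation time."""
    if spec.tokens_per_state == -1:
        # A recurrent state summarizes every token seen so far.
        return 1
    if block_size % spec.tokens_per_state != 0:
        raise ValueError("block_size must be a multiple of tokens_per_state")
    return block_size // spec.tokens_per_state


def semantic_shape(spec: KVCacheSpec, num_layers: int, num_blocks: int,
                   block_size: int) -> tuple[int, int, int, int, int]:
    """Hypothetical helper: [L, B, S, H, C] with C flattened to bytes."""
    return (num_layers, num_blocks, num_states(spec, block_size),
            spec.num_heads, spec.state_content_size)


# Illustrative dimensions only.
head_dim, dtype_bytes, block_size = 128, 2, 16
gqa = KVCacheSpec(num_heads=8, tokens_per_state=1,
                  state_content_size=2 * head_dim * dtype_bytes)  # [2, head_dim]
compressed = KVCacheSpec(num_heads=8, tokens_per_state=4,
                         state_content_size=2 * head_dim * dtype_bytes)
ssm = KVCacheSpec(num_heads=16, tokens_per_state=-1,
                  state_content_size=head_dim * 64 * dtype_bytes)  # [head_dim, state_size]

print(semantic_shape(gqa, 32, 1000, block_size))         # (32, 1000, 16, 8, 512)
print(semantic_shape(compressed, 32, 1000, block_size))  # (32, 1000, 4, 8, 512)
print(semantic_shape(ssm, 32, 1000, block_size))         # (32, 1000, 1, 16, 16384)
```

Because num_states depends only on block_size and tokens_per_state, the same spec can be allocated with different block sizes on the prefill and decode sides.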

Stride Order: Semantic vs Physical Layout

The semantic shape is always [layers, blocks, states, heads, <content>]. The physical layout (memory order) is controlled by a stride_order permutation, which lists the semantic dimensions from outermost to innermost in memory.

NHD (states outer, default):

Semantic:  [L, B, S, H, C]
Physical:  [L, B, S, H, C]  (identity)
stride_order = (0, 1, 2, 3, 4)

Blocks are contiguous in memory. Good for block-level transfers, hetero-block-size disagg.

HND (heads outer):

Semantic:  [L, B, S, H, C]
Physical:  [L, B, H, S, C]
stride_order = (0, 1, 3, 2, 4)

Head slices are contiguous in memory. Good for head-level TP transfers, hetero-TP disagg.

BLSHC (blocks outer, layers inner):

Semantic:  [L, B, S, H, C]
Physical:  [B, L, S, H, C]
stride_order = (1, 0, 2, 3, 4)

All layers' data for one block is contiguous. Good for cross-layer block transfers — the connector can transfer an entire block across all layers in a single RDMA read.

BHLSC (blocks outermost, then heads, then layers):

Semantic:  [L, B, S, H, C]
Physical:  [B, H, L, S, C]
stride_order = (1, 3, 0, 2, 4)

For a given block and head, all layers are contiguous. Good for cross-layer TP transfers — one contiguous region per (block, head) across all layers.
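The helper below makes the four stride orders concrete. It computes the element stride of each semantic dimension when the physical buffer is contiguous. It reads stride_order as "semantic dims from outermost to innermost in memory", which is consistent with the four examples above. The name `semantic_strides` is hypothetical, and the shape is an illustrative [L, B, S, H, C].

```python
def semantic_strides(shape, stride_order):
    """Element strides of each semantic dim for a contiguous physical buffer.

    shape: semantic sizes [L, B, S, H, C].
    stride_order: semantic dims listed from outermost to innermost in memory.
    """
    strides = [0] * len(shape)
    step = 1
    # Walk physical dims from innermost to outermost, accumulating sizes.
    for dim in reversed(stride_order):
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


shape = (2, 3, 4, 5, 6)  # illustrative [L, B, S, H, C]
layouts = {
    "NHD": (0, 1, 2, 3, 4),
    "HND": (0, 1, 3, 2, 4),
    "BLSHC": (1, 0, 2, 3, 4),
    "BHLSC": (1, 3, 0, 2, 4),
}
for name, order in layouts.items():
    print(name, semantic_strides(shape, order))
# NHD (360, 120, 30, 6, 1)
# HND (360, 120, 6, 24, 1)
# BLSHC (120, 240, 30, 6, 1)
# BHLSC (24, 240, 6, 48, 1)
```

In PyTorch you can get the same semantic view in two steps. First allocate the physical shape `[shape[d] for d in stride_order]`. Then call `Tensor.permute` with the inverse permutation (`[stride_order.index(i) for i in range(5)]`). Backends and connectors can then index in semantic order regardless of the physical layout.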


Spec-Driven TP Slicing

Instead of connector byte math, specs provide their own slicing:

from torch import Tensor


class KVCacheSpec:
    def slice_for_tp_transfer(self, my_tensor, my_tp, my_rank,
                              other_spec, other_tp,
                              other_rank) -> list[Tensor]:
        """Return slices of `my_tensor` that need to be transferred.
        Takes both sides' TP configs since tp_size is a deployment
        property, not part of the spec."""
        ...

Each spec subclass knows how to slice its own tensor. The signature takes the other spec so it can compute the correct overlap — GQA slices by heads, Mamba slices by x/B/C components, MLA may be a single unsliced latent.
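For MLA, here is a minimal sketch. It assumes the latent cache is identical on every TP rank, so there is nothing to slice by heads. The only free choice is which source rank serves each destination rank. This hypothetical helper spreads reads by pairing destination rank d with source rank d % src_tp. It takes explicit src/dst roles rather than the my/other signature above, so both sides evaluate the same pairing rule.

```python
from typing import TypeVar

T = TypeVar("T")


def mla_slices_for_tp_transfer(tensor: T, src_tp: int, src_rank: int,
                               dst_tp: int, dst_rank: int) -> list[T]:
    """Hypothetical MLA slicing: the whole latent, from exactly one src rank.

    Assumes every TP rank holds an identical copy of the latent cache, so
    each destination rank reads it once, from source rank dst_rank % src_tp.
    Both sides pass the same (src, dst) roles, so they agree on the pairing.
    """
    if not (0 <= src_rank < src_tp and 0 <= dst_rank < dst_tp):
        raise ValueError("rank out of range")
    if src_rank == dst_rank % src_tp:
        return [tensor]  # unsliced latent
    return []
```

With src_tp=2 and dst_tp=4, destination ranks 1 and 3 both read from source rank 1, and every other (src, dst) pair returns an empty list.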

GQA Slicing

Head slicing, handling replication on either side (or both):

def slice_for_tp_transfer(self, src_tensor, src_tp, src_rank,
                          dst_spec, dst_tp, dst_rank):
    # Total heads on each side; replicated heads are counted once per replica.
    src_heads = src_tp * self.num_heads
    dst_heads = dst_tp * dst_spec.num_heads

    if dst_heads > src_heads:
        # Destination replicates heads: fold dst_rank onto the source rank space.
        dst_rank %= src_tp
    elif src_rank >= dst_tp:
        # Source replicates heads: extra source ranks have nothing new to send.
        return []

    def head_range(tp_rank, num_heads):
        # Global [start, stop) head indices owned by tp_rank.
        return (tp_rank * num_heads, (tp_rank + 1) * num_heads)

    src_head_range = head_range(src_rank, self.num_heads)
    dst_head_range = head_range(dst_rank, dst_spec.num_heads)

    overlap = get_overlap(src_head_range, dst_head_range)
    if overlap is None:
        return []

    # Convert the global overlap to local indices on the heads dim (dim 3).
    local_slice = slice(overlap.start - src_head_range[0],
                        overlap.stop - src_head_range[0])
    return [src_tensor[:, :, :, local_slice]]


def get_overlap(range1, range2):
    # Intersection of two half-open [start, stop) ranges.
    overlap_start = max(range1[0], range2[0])
    overlap_end = min(range1[1], range2[1])

    # Strict < so touching ranges such as [0, 4) and [4, 8) give no overlap
    # instead of an empty slice.
    if overlap_start < overlap_end:
        return slice(overlap_start, overlap_end)
    return None
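The GQA code above derives head counts from tp_size * num_heads, which counts replicas as distinct heads. The sketch below is an alternative. It takes the model's total KV-head count (`total_heads`, a hypothetical extra input not in the spec) and assumes contiguous replica placement, where rank r holds head r // (tp_size // total_heads) when tp_size > total_heads. If a deployment places replicas differently, only `head_range` needs to change. The sketch is written from the source side. Only the first replica of each head sends, while every destination replica still receives its heads. All function names are hypothetical.

```python
def head_range(tp_size: int, tp_rank: int, total_heads: int) -> tuple[int, int]:
    """Global [start, stop) of KV heads held by tp_rank.

    Assumes heads are split evenly when tp_size <= total_heads, and otherwise
    replicated contiguously: rank r holds head r // (tp_size // total_heads).
    """
    if tp_size <= total_heads:
        if total_heads % tp_size:
            raise ValueError("total_heads must be divisible by tp_size")
        per_rank = total_heads // tp_size
        return (tp_rank * per_rank, (tp_rank + 1) * per_rank)
    if tp_size % total_heads:
        raise ValueError("tp_size must be divisible by total_heads")
    head = tp_rank // (tp_size // total_heads)
    return (head, head + 1)


def is_first_replica(tp_size: int, tp_rank: int, total_heads: int) -> bool:
    """Only the first replica of each head sends, so no head is read twice."""
    if tp_size <= total_heads:
        return True
    return tp_rank % (tp_size // total_heads) == 0


def src_head_slice(src_tp, src_rank, dst_tp, dst_rank, total_heads):
    """Local head slice of src_rank's cache that dst_rank needs, or None."""
    if not is_first_replica(src_tp, src_rank, total_heads):
        return None
    s0, s1 = head_range(src_tp, src_rank, total_heads)
    d0, d1 = head_range(dst_tp, dst_rank, total_heads)
    start, stop = max(s0, d0), min(s1, d1)
    if start >= stop:  # empty intersection
        return None
    return slice(start - s0, stop - s0)  # relative to src_rank's local heads


print(src_head_slice(2, 0, 4, 1, 8))  # slice(2, 4, None)
print(src_head_slice(1, 0, 4, 2, 2))  # slice(1, 2, None): replicated dst gets head 1
print(src_head_slice(4, 1, 2, 1, 2))  # None: src rank 1 is a second replica of head 0
```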

Mamba Conv Slicing: 3 Slices for x/B/C

def slice_for_tp_transfer(self, my_tensor, my_tp, my_rank,
                          other_spec, other_tp, other_rank):
    # tensor: [layers, blocks, 1, 1, conv_dim/tp, kernel-1]
    x_range = self._compute_x_range(my_tp, my_rank, other_tp, other_rank)
    b_range = self._compute_b_range(my_tp, my_rank, other_tp, other_rank)
    c_range = self._compute_c_range(my_tp, my_rank, other_tp, other_rank)
    return [
        my_tensor[:, :, 0, 0, x_range, :],  # x slice
        my_tensor[:, :, 0, 0, b_range, :],  # B slice
        my_tensor[:, :, 0, 0, c_range, :],  # C slice
    ]

Same interface — spec knows its own structure.
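The conv state needs three slices because x, B and C are sharded independently and then concatenated on each rank. A TP mismatch therefore cannot be served by one contiguous range of the conv_dim/tp axis. The sketch below is one possible stand-in for the `_compute_x_range` / `_compute_b_range` / `_compute_c_range` helpers. It assumes each component is split evenly and contiguously across TP ranks and stored as [x | B | C]. It does not handle components that are replicated rather than split. The name `component_slices` and the example widths are hypothetical.

```python
def component_slices(sizes, my_tp, my_rank, other_tp, other_rank):
    """Hypothetical x/B/C slicing along the local conv_dim axis.

    sizes: global (unsharded) widths of the concatenated components,
    e.g. (d_x, d_B, d_C). Returns one local slice per component, or None
    when my_rank and other_rank share nothing for that component.
    """
    result = []
    offset = 0  # where this component starts on my rank's conv_dim/tp axis
    for size in sizes:
        if size % my_tp or size % other_tp:
            raise ValueError("component width must divide evenly across TP")
        my_len, other_len = size // my_tp, size // other_tp
        my_start, other_start = my_rank * my_len, other_rank * other_len
        start = max(my_start, other_start)
        stop = min(my_start + my_len, other_start + other_len)
        result.append(slice(offset + start - my_start, offset + stop - my_start)
                      if start < stop else None)
        offset += my_len
    return result


# TP=1 source, TP=2 destination rank 1, with illustrative widths x=8, B=4, C=4.
print(component_slices((8, 4, 4), 1, 0, 2, 1))
# [slice(4, 8, None), slice(10, 12, None), slice(14, 16, None)]
print(component_slices((8, 4, 4), 2, 1, 1, 0))
# [slice(0, 4, None), slice(4, 6, None), slice(6, 8, None)]
```

The two calls describe the same transfer from each side. The source-side ranges are non-contiguous, while the destination side receives them as one packed [x | B | C] row. This is why pairing slices by component, rather than by byte offset, keeps the connector generic.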


Connector Becomes Generic

The connector needs both source and destination specs — if they have different stride orders (e.g., prefill uses HND, decode uses NHD), the source and destination slices map to different physical memory regions. Both specs produce slices in the same semantic order, so they pair directly. No backend-specific code in the connector.

# TP slicing (once at setup — determines which heads to transfer)
src_slices = src_spec.slice_for_tp_transfer(
    src_cache, src_tp, src_rank, dst_spec, dst_tp, dst_rank)
dst_meta = make_meta_kv_cache(dst_spec, num_blocks, dst_block_size)
dst_slices = dst_spec.slice_for_tp_transfer(
    dst_meta, dst_tp, dst_rank, src_spec, src_tp, src_rank)

# Per-block transfer (handles hetero block sizes naturally)
for src_states_slice, dst_states_slice in block_mapping:
    for src_s, dst_s in zip(src_slices, dst_slices):
        src_block = src_s[src_states_slice]
        dst_block = dst_s[dst_states_slice]
        for src_chunk, dst_chunk in jointly_contiguous_chunks(
                src_block, dst_block):
            self._rdma_transfer(src_chunk, dst_chunk)

TP slicing (which heads overlap) is computed once. `block_mapping` pairs source state ranges with destination state ranges, which is how heterogeneous block sizes are handled. jointly_contiguous_chunks decomposes each pair into regions that are contiguous in both physical layouts, and naturally takes advantage (when possible) of layers packed in an interleaved fashion. Destination offsets come from a meta tensor (device='meta'), so no GPU memory is allocated for them.
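The sketch below shows one way such a `block_mapping` could be produced for a single request. It assumes one state per token (tokens_per_state == 1) and that token i lives in block i // block_size at offset i % block_size on each side. Each yielded pair is a [layers, block, states] index that can be applied directly as `src_s[src_states_slice]` in the loop above. The function name `make_block_mapping` and the block IDs are hypothetical.

```python
from typing import Iterator

Index = tuple[slice, int, slice]  # [layers, block, states] index into [L, B, S, H, C]


def make_block_mapping(num_tokens: int,
                       src_blocks: list[int], src_block_size: int,
                       dst_blocks: list[int], dst_block_size: int,
                       ) -> Iterator[tuple[Index, Index]]:
    """Hypothetical: pair src/dst state ranges for one request."""
    pos = 0
    while pos < num_tokens:
        src_idx, src_off = divmod(pos, src_block_size)
        dst_idx, dst_off = divmod(pos, dst_block_size)
        # Largest run that stays inside the current src block and dst block.
        n = min(src_block_size - src_off, dst_block_size - dst_off,
                num_tokens - pos)
        yield ((slice(None), src_blocks[src_idx], slice(src_off, src_off + n)),
               (slice(None), dst_blocks[dst_idx], slice(dst_off, dst_off + n)))
        pos += n


# 40 tokens: 16-token source blocks feeding 32-token destination blocks.
for src_index, dst_index in make_block_mapping(40, [7, 3, 9], 16, [5, 2], 32):
    print(src_index[1:], "->", dst_index[1:])
# (7, slice(0, 16, None)) -> (5, slice(0, 16, None))
# (3, slice(0, 16, None)) -> (5, slice(16, 32, None))
# (9, slice(0, 8, None)) -> (2, slice(0, 8, None))
```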

Connector is backend-agnostic.
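Here is a minimal sketch of what jointly_contiguous_chunks could do. It works on the metadata a PyTorch tensor exposes (its shape, `Tensor.stride()` and `Tensor.storage_offset()`, all in elements) rather than on tensors directly. It merges trailing dimensions for as long as both layouts stay contiguous, then emits one run per index over the remaining outer dimensions. The function name and signature are hypothetical.

```python
from itertools import product


def jointly_contiguous_chunks(shape, src_strides, dst_strides,
                              src_offset=0, dst_offset=0):
    """Hypothetical: yield (src_offset, dst_offset, num_elements) runs.

    Each run is contiguous in both the source and the destination layout.
    """
    # Merge trailing dims while both layouts stay row-major contiguous.
    run, k = 1, len(shape)
    while k > 0 and (shape[k - 1] == 1 or
                     (src_strides[k - 1] == run and dst_strides[k - 1] == run)):
        run *= shape[k - 1]
        k -= 1
    # Every index over the remaining outer dims starts one shared run.
    for idx in product(*(range(n) for n in shape[:k])):
        src = src_offset + sum(i * s for i, s in zip(idx, src_strides))
        dst = dst_offset + sum(i * s for i, s in zip(idx, dst_strides))
        yield (src, dst, run)


# Same block [S=4, H=2, C=3] viewed in NHD (strides 6, 3, 1) and HND (3, 12, 1).
print(list(jointly_contiguous_chunks((4, 2, 3), (6, 3, 1), (6, 3, 1))))
# [(0, 0, 24)]
chunks = list(jointly_contiguous_chunks((4, 2, 3), (6, 3, 1), (3, 12, 1)))
print(len(chunks), chunks[:3])
# 8 [(0, 0, 3), (3, 12, 3), (6, 3, 3)]
```

When both sides share a layout, the whole block collapses into a single transfer. When prefill uses NHD and decode uses HND, the same call falls back to one run per (state, head) without any backend-specific code.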

Feedback Period.

No response

CC List.

@heheda12345 @benchislett @NickLucche @ZhanqiuHu

Any Other Things.

No response

