vllm - ✅(Solved) Fix [RFC][NixlConnector]: Add support for hybrid SSM-FA models [1 pull requests, 1 participants]

GitHub stats
vllm-project/vllm#36780 (fetched 2026-04-08 00:34:46)
Comments: 0
Participants: 1
Timeline: 16
Reactions: 0
Timeline (top): mentioned ×6, subscribed ×6, cross-referenced ×3, labeled ×1

Root Cause

Why do we need to work in the physical plane to begin with? Mainly because, in a heterogeneous setup, we need to index directly into the num_head (H) dimension of the KV cache layout. Using logical sizes alone would prevent this: one logical block (our unit) spans ratio physical blocks, so splitting inside it would be overcomplicated.

Fix Action

Fixed

PR fix notes

PR #36687: [PD][Nixl] Add support for hybrid SSM-FA models

Description (problem / solution / changelog)

For a comprehensive description of the changes proposed here, check out the corresponding RFC https://github.com/vllm-project/vllm/issues/36780.

This PR adds support for hybrid SSM-based models such as nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 with NixlConnector, enabling KV cache transfer of both FA and Mamba states in disaggregated setups. Currently it only supports homogeneous TP sizes, i.e. the same TP size on both P and D.

Note that only the actual Mamba states are transferred; any padding that may be present is skipped, as it might have non-trivial size.

UPDATE: regarding this change:

- curr_tensor_size_bytes = cache.numel() * cache.element_size()
+ curr_tensor_size_bytes = num_blocks * physical_page_size

In this PR I am trying to move further away from relying on tensor views, while unifying the code around kv_cache_config as the single source of truth. This is also necessary for Mamba-like models, where the tensors (cache above) report the unpadded tensor size. That size does not match num_blocks * physical_page_size, so one would otherwise need to account for padding manually.
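The gap between the two size computations can be illustrated with a minimal sketch. It reuses the example conv_state and ssm_state shapes quoted in the RFC; the 2-byte element size and the padded page size of 819200 bytes are assumptions picked for illustration, not values read from vLLM.

```python
from math import prod

# Shapes come from the RFC's example; everything else below is assumed.
NUM_BLOCKS = 18525
CONV_SHAPE = (NUM_BLOCKS, 3, 3328)
SSM_SHAPE = (NUM_BLOCKS, 48, 64, 128)
ELEMENT_SIZE = 2              # assumed: 16-bit states
PHYSICAL_PAGE_SIZE = 819_200  # assumed: padded bytes per block


def view_bytes(shape, element_size):
    """Equivalent of cache.numel() * cache.element_size() for one tensor view."""
    return prod(shape) * element_size


# What the tensor views report (padding is invisible through the views).
unpadded = view_bytes(CONV_SHAPE, ELEMENT_SIZE) + view_bytes(SSM_SHAPE, ELEMENT_SIZE)
# What the region actually occupies according to the cache config.
padded = NUM_BLOCKS * PHYSICAL_PAGE_SIZE

print("bytes seen through views:", unpadded)
print("bytes in the region:     ", padded)
print("padding per block:       ", (padded - unpadded) // NUM_BLOCKS)
```

Under these assumptions the views under-report the region size, which is why sizing the region from num_blocks and the physical page size is the safer choice.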

Important notes

  • TP > 1 currently requires --no-async-scheduling to run correctly. @ZhanqiuHu and I identified a synchronization issue where states may be transferred in a corrupted form, leading to high variance in evaluations. This will be addressed separately, as it is likely unrelated to SSMs.
  • @ZhanqiuHu has identified an issue with the current PD workflow: the first token is recomputed on D, which burns that extra step into the SSM state in-place.

Test with

Enable HMA experimental support with --no-disable-hybrid-kv-cache-manager:

# usual P/D command
vllm serve nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 \
  --trust-remote-code \
  --block-size 64 \
  --no-enable-prefix-caching \
  --no-disable-hybrid-kv-cache-manager \
  --mamba-ssm-cache-dtype float16 \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# usual toy_proxy_server.py command

or

HYBRID_SSM=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh

or check out unit tests added with this PR.

Results from running consecutive full lm-eval runs with no prefix caching:

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5444|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8355|±  |0.0102|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5345|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8340|±  |0.0102|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5398|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8355|±  |0.0102|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5428|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8332|±  |0.0103|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5557|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8506|±  |0.0098|

TODO

  • Address kernel<>logical block size mismatch
  • Benchmark

Changed files

  • tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh (modified, +8/-0)
  • tests/v1/kv_connector/nixl_integration/test_accuracy.py (modified, +1/-0)
  • tests/v1/kv_connector/unit/test_nixl_connector.py (modified, +76/-45)
  • tests/v1/kv_connector/unit/test_nixl_connector_hma.py (modified, +112/-0)
  • vllm/distributed/kv_transfer/kv_connector/utils.py (modified, +38/-8)
  • vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py (modified, +1/-1)
  • vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (modified, +351/-112)


Motivation.

Problem Statement

Supporting hybrid models that combine FullAttention (FA) and Mamba-style
SSM layers introduces several challenges for the KV connector:

  • FA and Mamba layers use different internal state layouts
    (K/V vs Conv/SSM).
  • Kernel constraints may require physical block sizes that differ
    from the logical block abstraction used by the block manager.

As a result, FA and Mamba layers must be able to index the same underlying KV cache tensor while using different block descriptor layouts.

Proposed Design

We introduce two logical descriptor views over the same registered memory regions:

  • Current descriptor view (used by non-Mamba layers)

    • Descriptors correspond to K/V blocks
  • Mamba descriptor view (used by Mamba layers)

    • Descriptors correspond to Conv and SSM state blocks

Both descriptor sets reference the same underlying tensor but use different offsets and sizes.

The descriptor lists are stored continuously, allowing the existing block_id → desc_id mapping logic to be extended with a simple index shift.

All changes proposed here are not meant to modify the existing workflow for "regular" models, but rather extend it for this specific case.

PR here https://github.com/vllm-project/vllm/pull/36687.

Proposed Change.

Prerequisite

This PR builds on the recently introduced HMA interface for NIXL https://github.com/vllm-project/vllm/pull/35758.

HMA

The Hybrid Memory Allocator (HMA) enables efficient KV cache management for hybrid models, i.e. any model that combines multiple attention types (FA/SW/Linear, ...). HMA groups layers with the same attention type together. This leverages the fact that many models follow a repeating pattern such as a 1:n ratio (e.g., three sliding-window layers followed by one full-attention layer).

Some invariants

  • all groups have the same size
    • padding layers are added to match
  • KVCache tensors can be viewed as 2D, num_blocks x page_size, where page_size=block_size*num_heads*head_dim
  • HMA ensures that all groups have equal page_size (in bytes)
    • one way to ensure that when dims do not match is to "bump" the block_size dims to a common multiple
    • this is only possible if max_page_size % layer_page_size == 0
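The page-size invariant above can be sketched as follows. This is an illustrative helper, not vLLM code: the function names and the layer specs are hypothetical, and a dtype_size factor is added to the page_size formula so that the result is in bytes.

```python
def page_size_bytes(block_size, num_heads, head_dim, dtype_size):
    """Bytes per block, following page_size = block_size * num_heads * head_dim."""
    return block_size * num_heads * head_dim * dtype_size


def unify_block_sizes(layers):
    """Scale each layer's block_size so every layer reaches the largest page size.

    `layers` maps a layer name to (block_size, num_heads, head_dim, dtype_size).
    Raises ValueError when max_page_size % layer_page_size != 0.
    """
    pages = {name: page_size_bytes(*spec) for name, spec in layers.items()}
    max_page = max(pages.values())
    unified = {}
    for name, (block_size, _heads, _dim, _dtype) in layers.items():
        if max_page % pages[name] != 0:
            raise ValueError(f"cannot unify page size for layer {name!r}")
        # "Bump" block_size by the integer factor that closes the gap.
        unified[name] = block_size * (max_page // pages[name])
    return unified, max_page


# Hypothetical layers: the sliding-window layer has half as many heads.
example = {"full.0": (16, 8, 128, 2), "sw.0": (16, 4, 128, 2)}
print(unify_block_sizes(example))
```

In this toy case the sliding-window layer ends up with twice the block_size of the full-attention layer, so both share one page size.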

HMA uses memory pooling

HMA implements memory pooling across groups:

  • A single kv cache tensor is shared across groups!
    • all layers at position layer0 in each group share tensor0

This has several implications:

  • only tensors referenced by group0 need to be registered in NIXL
  • other tensors contain replicated addresses and can be skipped

Therefore we move from registering num_layers → num_regions

Consequences:

  • fewer regions are registered
  • each region typically contains more blocks

Importantly:

  • Block IDs across groups never overlap, since they refer to different offsets in the shared tensor.
    • e.g. for block_ids=[[1, 2], [3]], blocks [1, 2] need to be read across all regions, and the same holds for [3].

From kv_cache_utils (group_size = max(len(group.layer_names) for group in kv_cache_groups)):

# General Case:
# We will have group_size memory pools, each is shared by one layer from
# each group. As layers of different groups have different block table,
# they will use different parts of the shared Tensor.

# The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
# (sw.1, padding) will be: (group_size = 2)

# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2
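The pooling described in the comment above can be reproduced with a small sketch. The function name is hypothetical and the padding layer is modeled implicitly by a shorter group; this is only meant to show which layers end up sharing a tensor.

```python
def build_memory_pools(kv_cache_groups):
    """Group layers by their position inside each group.

    Returns group_size pools; pool i holds the i-th layer of every group
    that has one (shorter groups are treated as padded).
    """
    group_size = max(len(group) for group in kv_cache_groups)
    pools = [[] for _ in range(group_size)]
    for group in kv_cache_groups:
        for position, layer_name in enumerate(group):
            pools[position].append(layer_name)
    return pools


# The 3 groups from the comment: (full.0, full.1), (sw.0, sw.2), (sw.1, padding)
groups = [["full.0", "full.1"], ["sw.0", "sw.2"], ["sw.1"]]
pools = build_memory_pools(groups)
for index, pool in enumerate(pools):
    print(f"tensor{index}: {pool}")
```

Each pool corresponds to one shared tensor, which is why the number of registered regions follows group_size rather than the number of layers.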

SSM

For SSM layers, block_size is inherently one: there is no concept of tokens, as the state is collapsed into a fixed-size representation. This makes it tricky to operate alongside FA layers, which reason in terms of blocks of tokens. In practice, block_size for SSM layers is assumed to be 1, which results in a page_size much bigger than what we usually have for FA layers.

To address that, block_size for FA layers is scaled until its page size is bigger than that of the Conv + temporal (SSM) state. The SSM page is then padded accordingly to satisfy the constant page_size invariant.
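A rough sketch of this scaling step is shown below. The element sizes (2 bytes), the inclusion of both K and V in the per-token size, and the alignment of 16 tokens are all assumptions made for illustration; the state shapes are the ones from the RFC's example.

```python
def ceil_div(a, b):
    """Integer ceiling division."""
    return -(-a // b)


def scale_fa_block_size(mamba_state_bytes, fa_bytes_per_token, alignment=16):
    """Smallest multiple of `alignment` tokens whose FA page holds the Mamba state."""
    tokens_needed = ceil_div(mamba_state_bytes, fa_bytes_per_token)
    return alignment * ceil_div(tokens_needed, alignment)


# Per-block Mamba state, using the example shapes and an assumed 2-byte dtype.
conv_bytes = 3 * 3328 * 2
ssm_bytes = 48 * 64 * 128 * 2
mamba_state_bytes = conv_bytes + ssm_bytes

# Per-token FA size: K and V, 4 heads, head_dim 128, assumed 2-byte dtype.
fa_bytes_per_token = 2 * 4 * 128 * 2

block_size = scale_fa_block_size(mamba_state_bytes, fa_bytes_per_token)
page_size = block_size * fa_bytes_per_token
padding = page_size - mamba_state_bytes
print(block_size, page_size, padding)
```

The padding value is the part of each SSM page that carries no state and therefore does not need to be transferred.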

Example shape

# SSM padding not visible through tensor view

torch.Size([18525, 3, 3328])      # conv_state
torch.Size([18525, 48, 64, 128])  # ssm_state

torch.Size([18525, 400, 4, 128])  # FA KV, note the block_size 400 dim

The diagram in the original RFC summarizes the layout (padding in SSM is not shown):

[Figure: layout of the shared KV cache tensor for FA and Conv/SSM states]

Note that unlike K and V, SSM and Conv states are of different sizes (and may include additional padding).

Also note that the FI (FlashInfer) layout is assumed for all FullAttention layers, as that matches how conv|ssm are laid out and ensures no data corruption when blocks are overwritten (Mamba blocks can be re-used by FA at any time).

Nixl work

The NixlConnector workflow for handling blocks can be summarized as follows:

  • a. register memory regions
  • b. register local xfer handles and block descriptors (our unit of transfer is a block not a region)
  • c. register remote xfer handles and block descriptors (during handshake)
  • d. map block ids to descriptor ids (_get_block_descs_ids)

Registering two descriptor "views"

Going through the workflow, registering regions (a.) can be done much like what we already do for dense models, by registering the whole shared, memory-pooled KVCache tensor once. Note that the actual tensors passed to register_kv_caches are a tuple (ssm, conv), each indexing the same shared tensor at a different offset. To avoid inefficient registration of overlapping areas, we use the base address pointer provided by ssm.

(ssm conv)
  |   |          Tensor
  +-------->+-------------+
      |     |             |
      +----------...      |

Note that this registration process is very much akin to what we do for FI-like layouts, where we register the whole tensor and then split descriptors on K/V dim. The main difference is that K/V have identical sizes, while conv/ssm do not.
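The region registration step can be sketched as below. This is not the connector's code: the function name, the dictionary layout, and the addresses are made up for illustration, and the only idea taken from the text is that each shared tensor is registered once, starting from the ssm base address.

```python
def regions_to_register(kv_caches, num_blocks, page_size):
    """Return one (base_addr, length_bytes) region per distinct shared tensor.

    `kv_caches` maps a layer name to the (ssm_addr, conv_addr) pair of its
    two views. The conv view lives inside the same tensor, so only the ssm
    base address is used to describe the region.
    """
    seen = set()
    regions = []
    for ssm_addr, _conv_addr in kv_caches.values():
        if ssm_addr in seen:
            continue  # this shared tensor is already covered by another layer
        seen.add(ssm_addr)
        regions.append((ssm_addr, num_blocks * page_size))
    return regions


# Hypothetical addresses: mamba.0 and mamba.1 share one tensor, mamba.2 another.
SSM_BYTES = 786_432
caches = {
    "mamba.0": (0x1000_0000, 0x1000_0000 + SSM_BYTES),
    "mamba.1": (0x1000_0000, 0x1000_0000 + SSM_BYTES),
    "mamba.2": (0x9000_0000, 0x9000_0000 + SSM_BYTES),
}
regions = regions_to_register(caches, num_blocks=4, page_size=819_200)
print(regions)
```

Descriptors for K/V or SSM/Conv are then carved out of these regions in a later step, so registering overlapping views separately would add nothing.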

More importantly, both FA and SSM layers need to be able to index into these regions with descriptors. To address that (moving to step b. and, symmetrically, step c.), we propose creating two separate sets of block descriptors for the same tensor: one for FullAttn layers and one for LinearAttn layers, as in the diagram below:

Assuming: M regions, N num_blocks

+------------------------------------------------------+<----nixl_xfer_handle
|  # FA descriptors                                    |
|                                                      |
|  Region 0                                            |
|    FA_block_desc_K0                                  |
|    ...                                               |
|    FA_block_desc_Kn                                  |
|    FA_block_desc_V0                                  |
|    ...                                               |
|    FA_block_desc_Vn                                  |
|                                                      |
|  ...                                                 |
|                                                      |
|  Region M                                            |
|    FA_block_desc_K0                                  |
|    ...                                               |
|    FA_block_desc_Kn                                  |
|    FA_block_desc_V0                                  | ^
|    ...                                               | | 
|    FA_block_desc_Vn                                  | | num_descs offset
|                                                      | |
|  --------------------------------------------------  | v
|  # Mamba descriptors                                 |
|                                                      |
|  Region 0                                            |
|    Mamba_block_desc_SSM0                             |
|    ...                                               |
|    Mamba_block_desc_SSMn                             |
|    Mamba_block_desc_Conv0                            |
|    ...                                               |
|    Mamba_block_desc_Convn                            |
|                                                      |
|  ...                                                 |
|                                                      |
|  Region M                                            |
|    Mamba_block_desc_SSM0                             |
|    ...                                               |
|    Mamba_block_desc_SSMn                             |
|    Mamba_block_desc_Conv0                            |
|    ...                                               |
|    Mamba_block_desc_Convn                            |
|                                                      |
+------------------------------------------------------+

The rationale is that block_ids coming from the manager can be freely allocated to either an FA or a Mamba layer. Given that we register K/V block descs separately, and that Conv/SSM have non-matching sizes that are independent of both K and V, having separate views allows us to select the right block descriptor (addr, len) based on whether the corresponding block_id belongs to a Mamba group or not. In practice, this can be implemented by simply shifting the current block_id->desc_index mapping logic by num_descs positions.
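The index shift can be sketched as follows, assuming the descriptor ordering drawn in the diagram above (per region: all first-half descriptors, then all second-half descriptors). The function name and signature are hypothetical and do not mirror _get_block_descs_ids.

```python
def block_desc_ids(block_ids, num_regions, num_blocks, is_mamba):
    """Map block ids to descriptor ids across all regions.

    Each view stores, per region, num_blocks descriptors for the first half
    (K or SSM) followed by num_blocks for the second half (V or Conv).
    The Mamba view is stored right after the FA view.
    """
    descs_per_region = 2 * num_blocks
    num_descs = num_regions * descs_per_region  # size of the FA view
    shift = num_descs if is_mamba else 0
    ids = []
    for region in range(num_regions):
        base = shift + region * descs_per_region
        for half in range(2):
            ids.extend(base + half * num_blocks + b for b in block_ids)
    return ids


# 2 regions with 4 blocks each: the FA view occupies descriptor ids 0..15.
fa_ids = block_desc_ids([1, 2], num_regions=2, num_blocks=4, is_mamba=False)
mamba_ids = block_desc_ids([3], num_regions=2, num_blocks=4, is_mamba=True)
print(fa_ids)
print(mamba_ids)
```

The only difference between the two calls is the constant shift, which is what keeps the change to the existing mapping logic small.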

SSM + kernel_block size

A similar problem has also been described here: https://github.com/vllm-project/vllm/pull/28677.

Going back to the HMA assumptions, every KV cache tensor in storage is assumed to be laid out as num_blocks * page_size. Due to the kernel block_size requirements of some backends (such as https://github.com/vllm-project/vllm/blob/545d18d81bf11761e51c2b11a006573c2ae366c1/vllm/v1/attention/backends/flashinfer.py#L304), the resulting physical tensor may actually be represented with a different number of blocks than the logical manager sees. This creates two different num_blocks values: a physical one and a logical one. In practice, to satisfy kernel_block_size requirements, num_blocks may be scaled by ratio = logical_block_size / physical_block_size, like so:

- FA tensor: `[num_logical * ratio, 2, kernel_block_size, heads, head_dim]` -- e.g. `[780390, 2, 16, 2, 128]`
    
- SSM tensor: `[num_logical, state_dim, hidden]` -- e.g. `[2990, 3, 6144]`

Effectively, one logical block comprises ratio contiguous physical blocks.

Therefore, to address this we allow for two separate num_blocks values in code, depending on whether we are dealing with an SSM region. This affects block descriptor registration (steps b. and c.), leading to the introduction of N1 and N2 in the diagram above. More importantly, it also affects step d., as the mapping now has to take into account that different regions may use different num_blocks values.
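The relation between the two block counts can be sketched with the example shapes above. The helper names are hypothetical; the sketch only shows that FA regions are addressed with physical block ids while SSM regions keep the logical ones.

```python
# Leading dimensions of the example tensors quoted above.
LOGICAL_BLOCKS = 2990      # SSM tensor
PHYSICAL_BLOCKS = 780390   # FA tensor
RATIO = PHYSICAL_BLOCKS // LOGICAL_BLOCKS


def logical_to_physical(block_ids, ratio):
    """Expand logical block ids into the contiguous physical ids they cover."""
    return [b * ratio + i for b in block_ids for i in range(ratio)]


def region_block_ids(block_ids, is_ssm_region, ratio=RATIO):
    """Pick the id space that matches the region's num_blocks."""
    if is_ssm_region:
        return list(block_ids)  # SSM regions are indexed logically
    return logical_to_physical(block_ids, ratio)


print("ratio:", RATIO)
print(logical_to_physical([0, 2], 3))
print(len(region_block_ids([5], is_ssm_region=False)))
```

With the example shapes each logical block expands to RATIO physical blocks of kernel_block_size 16 in FA regions, while the same block id is used unchanged in SSM regions.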

Why do we need to work in the physical plane to begin with? Mainly because, in a heterogeneous setup, we need to index directly into the num_head (H) dimension of the KV cache layout. Using logical sizes alone would prevent this: one logical block (our unit) spans ratio physical blocks, so splitting inside it would be overcomplicated.

Feedback Period.

No response

CC List.

@tlrmchlsmth @roikoren755 @robertgshaw2-redhat @ZhanqiuHu @orozery @tdoublep

Any Other Things.

No response


extent analysis

Fix Plan

To address the challenges introduced by supporting hybrid models that combine FullAttention (FA) and Mamba-style SSM layers, we need to implement the following steps:

  • Introduce two logical descriptor views over the same registered memory regions:
    • Current descriptor view (used by non-Mamba layers)
    • Mamba descriptor view (used by Mamba layers)
  • Modify the block descriptor registration process to accommodate both FA and SSM layers
  • Update the block_id to desc_index mapping logic to account for the separate descriptor views

Code Changes

Here's an example of how the code changes could be implemented:

# Illustrative sketch only: all names are hypothetical, not the vLLM API.
from dataclasses import dataclass


# A descriptor is one transferable slice of a registered region.
@dataclass(frozen=True)
class BlockDescriptor:
    addr: int    # start address in bytes
    length: int  # size in bytes


# Build both descriptor views over the same region.
def register_block_descriptors(base_addr, num_blocks, page_size, ssm_bytes, conv_bytes):
    # FA view: K descriptors first, then V descriptors.
    # Assumes each page is split evenly between K and V.
    half = page_size // 2
    fa_descs = [BlockDescriptor(base_addr + i * page_size, half)
                for i in range(num_blocks)]
    fa_descs += [BlockDescriptor(base_addr + i * page_size + half, half)
                 for i in range(num_blocks)]

    # Mamba view: SSM descriptors first, then Conv descriptors.
    # Sizes differ from K/V and any trailing padding is left out.
    mamba_descs = [BlockDescriptor(base_addr + i * page_size, ssm_bytes)
                   for i in range(num_blocks)]
    mamba_descs += [BlockDescriptor(base_addr + i * page_size + ssm_bytes, conv_bytes)
                    for i in range(num_blocks)]

    return fa_descs, mamba_descs


# Update the block_id to desc_index mapping logic
def get_block_desc_id(block_id, num_descs, is_mamba):
    # Mamba descriptors are stored after the num_descs FA descriptors,
    # so blocks from a Mamba group are shifted by num_descs positions.
    return block_id + num_descs if is_mamba else block_id

Verification

To verify that the fix worked, you can test the following scenarios:

  • Register a hybrid model that combines FA and SSM layers
  • Verify that the block descriptors are correctly registered for both FA and SSM layers
  • Test the block_id to desc_index mapping logic to ensure that it correctly handles both FA and SSM layers

Extra Tips

  • Make sure to update the documentation to reflect the changes made to the code
  • Consider adding additional logging or debugging statements to help diagnose any issues that may arise
  • Review the code changes to ensure that they do not introduce any performance regressions or other unintended consequences.
