vllm - ✅(Solved) Fix [Tracking Issue]: Mamba Heterogeneous TP for NIXL P/D Disaggregation [5 pull requests, 1 participants]

vllm-project/vllm#37638 (fetched 2026-04-08 01:04:26)

PR fix notes

PR #36957: [NIXL][Mamba][1/N] Heterogeneous TP: full conv transfer + local extract

Description (problem / solution / changelog)

Note: This draft is based on top of #36687 (NickLucche/nixl-ssm-rebase), which adds homogeneous TP support for hybrid SSM-FA models. To see only the heterogeneous TP changes, view the last commit.

This exploration was conducted with assistance from Cursor (claude-4.6-opus-high).

This extends #36687 to support heterogeneous tensor parallelism (different TP sizes on prefiller vs decoder) for hybrid SSM-FA models (e.g., Nemotron-H) in NIXL-based P/D disaggregation.

#36687 adds P/D support for hybrid models with homogeneous TP only (tp_ratio == 1). This draft lifts that restriction, enabling configurations like P_TP=1, D_TP=2.

Progress

  • Found and fixed some issues blocking heterogeneous TP for hybrid SSM-FA models
  • Documented design choices (Options A–D)
  • Implemented an experimental staging-buffer approach (Option B): read the full conv_states into a staging buffer, then split and truncate locally. This is the most straightforward option, but may not be the best.
  • Validated Option B with lm_eval gsm8k: 1p2d accuracy matches standalone baseline
  • Experimental implementation for Option D: permute to change conv_states layout
  • Experimental implementation for Option A: separate RDMA reads for x, B, C
  • Option C: Claude proposed this, but I think it's a little weird and too invasive to the code.
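
The "read full conv_states, then split and truncate" step of Option B can be pictured with a small NumPy sketch. It assumes the SD layout (state_len, conv_dim) with columns ordered [x | B | C]; the function name and its arguments are hypothetical and are not taken from nixl_connector.py.

```python
import numpy as np

def extract_local_conv_shard(staged, x_dim, b_dim, tp_ratio, offset):
    """Slice one decoder rank's [x | B | C] columns out of a full conv state.

    staged:   (state_len, x_dim + 2 * b_dim) array holding the prefiller's
              full conv state after one read into the staging buffer.
    tp_ratio: D_TP // P_TP (assumed to divide x_dim and b_dim evenly).
    offset:   this decoder rank's index within the ratio.
    """
    x, b, c = np.split(staged, [x_dim, x_dim + b_dim], axis=1)
    xs, bs = x_dim // tp_ratio, b_dim // tp_ratio
    return np.concatenate(
        [
            x[:, offset * xs:(offset + 1) * xs],
            b[:, offset * bs:(offset + 1) * bs],
            c[:, offset * bs:(offset + 1) * bs],
        ],
        axis=1,
    )

# Toy sizes: state_len=3, x_dim=8, B_dim=C_dim=2, P_TP=1 -> D_TP=2.
staged = np.arange(3 * 12).reshape(3, 12)
shard0 = extract_local_conv_shard(staged, x_dim=8, b_dim=2, tp_ratio=2, offset=0)
shard1 = extract_local_conv_shard(staged, x_dim=8, b_dim=2, tp_ratio=2, offset=1)
```

Each decoder rank keeps only its own columns, which is why the staging buffer has to hold the full conv state before extraction.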

Approaches Considered (In Progress)

|Option|Approach|Conv State|SSM State|Notes|
|------|--------|----------|---------|-----|
|A|Separate RDMA per component|3 RDMA reads (x, B, C)|rank_offset|Exact transfer, no staging buffer, but 3x descriptors|
|B|Full conv read + local extract|1 RDMA read -> staging buffer -> post-process|rank_offset|Implemented in this exploration|
|C|Recompute conv on decoder|Skip conv transfer entirely|rank_offset|Most bandwidth-efficient, but needs model weight access|
|D|Permute to group-interleaved layout|Rearrange on sender side|rank_offset|Implemented|

All options use rank_offset for SSM (temporal) state since its head dimension is contiguous and cleanly splittable, identical to attention heads.
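
As a minimal sketch of why a plain rank offset is enough for the SSM state, the snippet below assumes the head dimension is outermost, so each rank's heads form one contiguous run in flat memory. The helper name and toy sizes are illustrative assumptions, not the connector's code.

```python
import numpy as np

def rank_offset_slice(flat_state, tp_ratio, rank_offset):
    """Return the contiguous slice of a flat SSM state owned by one rank."""
    shard = flat_state.size // tp_ratio
    return flat_state[rank_offset * shard:(rank_offset + 1) * shard]

num_heads, per_head = 4, 6                      # toy sizes
ssm = np.arange(num_heads * per_head).reshape(num_heads, per_head)
flat = ssm.reshape(-1)                          # heads are outermost

got = rank_offset_slice(flat, tp_ratio=2, rank_offset=1)
want = ssm[2:4].reshape(-1)                     # heads 2 and 3 belong to rank 1
```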

Option B Results (only Option B has results so far; it should be the most straightforward to implement and does not touch homogeneous TP cases)

Model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8, max_model_len=8192, gpu_mem_util=0.8, num_concurrent=100.

lm_eval gsm8k 5-shot accuracy (N=1319, greedy, test script):

|Config|gsm8k strict-match|Notes|
|------|-----------------:|-----|
|Standalone TP=2|0.8324|Baseline (expected ~0.84)|
|1p1d (P_TP=1, D_TP=1)|0.8234|Homogeneous TP|
|1p2d (P_TP=1, D_TP=2)|0.8453|Heterogeneous TP|

TODO — remaining validation:

  • More TP combos
  • Test cases
  • Failures and corner cases
  • Logging?

Summary of Changes (nixl_connector.py)

  • Per-engine descriptor count tracking
  • Correct Mamba SSM state remote address
  • Staging buffer + post-process for Mamba conv state (HMA multi-region)
  • Correct attention V address for HMA hetero TP

Limitations

  • Only tested P_TP=1 -> D_TP=2; reverse direction not yet implemented
  • Staging buffer adds ~1.68 GiB on D (2.1% of H100 80GB) on top of gpu_memory_utilization
  • Diagnostic logging is verbose (should be gated or removed before production)

Related

  • #36687 — Homogeneous TP support for hybrid SSM-FA models
  • #36780 — NixlConnector hybrid SSM-FA design


Changed files

  • tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh (modified, +7/-0)
  • tests/v1/kv_connector/nixl_integration/test_accuracy.py (modified, +1/-0)
  • tests/v1/kv_connector/unit/test_nixl_connector_hma.py (modified, +112/-0)
  • vllm/distributed/kv_transfer/kv_connector/utils.py (modified, +41/-8)
  • vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py (modified, +1/-1)
  • vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (modified, +775/-108)

PR #37603: [NIXL][Mamba][2/N] Heterogeneous TP : chunk-interleaved layout

Description (problem / solution / changelog)

Related: https://github.com/vllm-project/vllm/pull/36957

Purpose

Enable prefill/decode disaggregation with different tensor parallelism sizes (e.g., P_TP=2, D_TP=4) for hybrid attention+Mamba models.

High-level idea

Original: (conv_rows, conv_dim) — row-major, components concatenated.

The transformation has two conceptual steps: grouping then transpose.

Original layout (conv_rows=3, conv_dim = x_dim + B_dim + C_dim):

  Row 0: [x0  x1  x2  x3  x4  x5  x6  x7  | B0  B1  | C0  C1 ]
  Row 1: [x0' x1' x2' x3' x4' x5' x6' x7' | B0' B1' | C0' C1']
  Row 2: [x0" x1" x2" x3" x4" x5" x6" x7" | B0" B1" | C0" C1"]
          ◄──────── x_dim = 8 ──────────────►◄─ B=2 ──►◄─ C=2 ──►

  All x columns first, then all B, then all C.
  A TP=2 shard needs [x0..x3 | B0 | C0] — NOT contiguous.

Step 1 — Grouped (columns rearranged, still row-major):

  Row 0: [x0  x1  x2  x3  | B0 | C0 ‖ x4  x5  x6  x7  | B1 | C1]
  Row 1: [x0' x1' x2' x3' | B0'| C0'‖ x4' x5' x6' x7' | B1'| C1']
  Row 2: [x0" x1" x2" x3" | B0"| C0"‖ x4" x5" x6" x7" | B1"| C1"]
          ◄──── chunk 0 ──────────►  ◄──── chunk 1 ──────────►

  TP=2 shards are now grouped, but NOT contiguous in flat memory
  (row-major means row 0 of chunk 0 is far from row 1 of chunk 0).

Step 2 — Transposed (each chunk flattened column-first):

  ┌─ chunk 0 ──────────────────────────────────────────────────────────────┐
  │ x0 x0' x0" │ x1 x1' x1" │ x2 x2' x2" │ x3 x3' x3" │ B0 B0' B0" │ C0 C0' C0" │
  ├─ chunk 1 ──────────────────────────────────────────────────────────────┤
  │ x4 x4' x4" │ x5 x5' x5" │ x6 x6' x6" │ x7 x7' x7" │ B1 B1' B1" │ C1 C1' C1" │
  └────────────────────────────────────────────────────────────────────────┘

  TP=2 → rank 0 reads chunk 0, rank 1 reads chunk 1 (1 contiguous read each)
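
The two steps above (grouping, then transpose) can also be expressed as a single gather index over the flat block. The sketch below uses the toy sizes from the diagram (conv_rows=3, x_dim=8, B_dim=C_dim=2); the function name is hypothetical and this is not the code in hetero_tp_conv_utils.py.

```python
import numpy as np

def chunk_interleave_index(conv_rows, x_dim, b_dim, num_chunks):
    """Build `perm` such that permuted_flat = original_flat[perm].

    original_flat is the row-major (conv_rows, x_dim + 2 * b_dim) conv state.
    """
    conv_dim = x_dim + 2 * b_dim
    xw, bw = x_dim // num_chunks, b_dim // num_chunks
    perm = []
    for k in range(num_chunks):
        cols = (
            list(range(k * xw, (k + 1) * xw))                                    # x_k
            + list(range(x_dim + k * bw, x_dim + (k + 1) * bw))                  # B_k
            + list(range(x_dim + b_dim + k * bw, x_dim + b_dim + (k + 1) * bw))  # C_k
        )
        for col in cols:                  # column-first inside each chunk
            for row in range(conv_rows):
                perm.append(row * conv_dim + col)
    return np.array(perm)

conv_rows, x_dim, b_dim = 3, 8, 2
original = np.arange(conv_rows * (x_dim + 2 * b_dim))
perm = chunk_interleave_index(conv_rows, x_dim, b_dim, num_chunks=2)

permuted = original[perm]          # what P would expose for reading
restored = np.empty_like(permuted)
restored[perm] = permuted          # inverse permutation on D (scatter)
```

With this index, rank 0's shard is `permuted[:18]` and rank 1's is `permuted[18:]`, each a single contiguous range.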

Changes

  1. Remote kernel block addressing — Use the remote engine's physical-per-logical block ratio instead of the local one. P and D can have different ratios (e.g., P=261, D=131 for Nemotron-Nano 2p4d), so using the local ratio caused D to read from wrong memory offsets on P.

  2. Chunk-interleave permutation — Mamba conv state has interleaved X/B/C columns. With heterogeneous TP, a naive byte-offset RDMA read gives D the wrong subset. The permutation reorders P's conv blocks so each D rank's shard is a contiguous byte range. Inverse permutation on D restores the original layout.

  3. Mamba rank_offset — Mamba conv state is always TP-sharded (even when attention KV is replicated), so needs_rank_offset must be true for mamba groups regardless of indexes_into_remote.

  4. Skip mamba trimming — Partial prefix cache hit trimming must skip mamba groups. Their blocks represent full state (conv + ssm), not per-token data.

  5. V-split addressing — Correct the SSM temporal state offset for heterogeneous TP: base + block_offset + full_conv_size + rank*ssm_shard instead of base + block_offset + rank*conv_shard + full_conv_size. Only affects mamba — attention K/V are symmetric so the original formula happened to work.
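
The difference between the two offset formulas in item 5 can be checked with a few lines of arithmetic. The function names and byte sizes below are made up for illustration; only the two formulas come from the description above.

```python
def ssm_addr_original(base, block_offset, rank, conv_shard, full_conv_size):
    # Formula before the fix: the rank term is scaled by the conv shard size.
    return base + block_offset + rank * conv_shard + full_conv_size

def ssm_addr_fixed(base, block_offset, rank, ssm_shard, full_conv_size):
    # Formula after the fix: skip the whole conv region, then index SSM shards.
    return base + block_offset + full_conv_size + rank * ssm_shard

# Hypothetical sizes in bytes: conv and SSM shards differ for Mamba.
full_conv_size, conv_shard, ssm_shard = 1200, 600, 2000
old = ssm_addr_original(0, 0, 1, conv_shard, full_conv_size)
new = ssm_addr_fixed(0, 0, 1, ssm_shard, full_conv_size)

# When both halves have the same shard size (the attention K/V case),
# the two formulas agree, which is why the original one went unnoticed.
same_old = ssm_addr_original(0, 0, 1, 600, 1200)
same_new = ssm_addr_fixed(0, 0, 1, 600, 1200)
```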

Interacts with #37416 (conv state layout). Our code assumes SD layout (state_len, dim) with assertions that fail clearly if DS layout is used.

Test Plan

  • lm_eval 1p2d, 2p4d with Nemotron-Nano-30B-A3B-FP8
  • 1p4d, 2p1d, 4p2d, 4p1d, etc
  • n_groups=1 model
  • Unit tests for chunk permutation round-trip
  • E2E tests

Test Result

lm_eval

gsm8k 5-shot, Temperature: 0.0, Model: NVIDIA-Nemotron-3-Nano-30B-A3B-FP8

Setting: 1p2d Scores: 0.8544

Setting: 2p4d Scores: 0.8635


Changed files

  • vllm/distributed/kv_transfer/kv_connector/v1/hetero_tp_conv_utils.py (added, +186/-0)
  • vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (modified, +279/-42)

PR #37635: [NIXL][Mamba][3/N] Heterogeneous TP: 3-read conv state transfer

Description (problem / solution / changelog)

Related: #36957, #37603, #37416

Purpose

Enable prefill/decode disaggregation with different tensor parallelism sizes (e.g., P_TP=1, D_TP=2 or P_TP=2, D_TP=4) for hybrid attention+Mamba models — alternative to the chunk-interleaved permutation approach in #37603.

High-level idea

Instead of permuting P's conv state and reading it in one RDMA transfer, D issues 3 separate RDMA reads for the x, B, C sub-projections directly from P's memory. This eliminates all P-side and D-side permutation logic.

Key requirement: DS conv state layout (VLLM_SSM_CONV_STATE_LAYOUT=DS, introduced in #37416), which stores conv state as (dim, state_len) so that x/B/C sub-projections are contiguous in memory along the dim axis.
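
The contiguity claim can be checked directly in NumPy. This sketch uses toy dimensions and NumPy arrays as a stand-in for the real cache tensors; it only illustrates why a slice along dim is contiguous in DS but not in SD.

```python
import numpy as np

state_len, x_dim, b_dim = 3, 8, 2
dim = x_dim + 2 * b_dim

sd = np.zeros((state_len, dim), dtype=np.float16)   # SD: (state_len, dim)
ds = np.zeros((dim, state_len), dtype=np.float16)   # DS: (dim, state_len)

# First half of the x sub-projection in each layout.
x_half_sd = sd[:, 0:x_dim // 2]   # strided: one short run per row
x_half_ds = ds[0:x_dim // 2, :]   # one contiguous byte range

sd_contiguous = bool(x_half_sd.flags["C_CONTIGUOUS"])
ds_contiguous = bool(x_half_ds.flags["C_CONTIGUOUS"])
```

Only the DS slice can be described by a single (address, length) pair, which is what an RDMA read descriptor needs.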

Memory layout (DS, per block, per P rank with TP_P):

|------- x (d_inner/TP_P) -------|- B (gs/TP_P) -|- C (gs/TP_P) -|-- SSM --|

D rank j reads from P rank r = j // tp_ratio, offset = j % tp_ratio:
  1. x read: P's x_start + offset × x_D_bytes, size = x_D_bytes
  2. B read: P's B_start + offset × b_D_bytes, size = b_D_bytes
  3. C read: P's C_start + offset × c_D_bytes, size = c_D_bytes
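
The three read offsets can be sketched as below. The sketch assumes P's x region starts at byte 0 of the block and that each of P's x/B/C regions is tp_ratio times the corresponding D-side size; the dataclass and function names are hypothetical.

```python
from dataclasses import dataclass

@dataclass
class ConvRead:
    name: str
    remote_offset: int  # byte offset inside P's block
    size: int           # bytes to read

def plan_three_reads(d_rank, tp_ratio, x_d_bytes, b_d_bytes, c_d_bytes):
    """Return (P rank to read from, list of the three conv reads)."""
    p_rank = d_rank // tp_ratio
    offset = d_rank % tp_ratio
    # Assumed layout of P's block: |x|B|C|, each tp_ratio times D's share.
    x_start = 0
    b_start = x_start + tp_ratio * x_d_bytes
    c_start = b_start + tp_ratio * b_d_bytes
    reads = [
        ConvRead("x", x_start + offset * x_d_bytes, x_d_bytes),
        ConvRead("B", b_start + offset * b_d_bytes, b_d_bytes),
        ConvRead("C", c_start + offset * c_d_bytes, c_d_bytes),
    ]
    return p_rank, reads

# D rank 3 with tp_ratio=2 (e.g. 2p4d) and illustrative shard sizes.
p_rank, reads = plan_three_reads(3, 2, x_d_bytes=400, b_d_bytes=100, c_d_bytes=100)
```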

Changes

New: conv_decomp_utils.py

  • ConvDecomp dataclass with per-rank x/B/C column counts and offset computation methods
  • derive_conv_decomp() — extracts decomposition from MambaSpec
  • compute_mamba_phys_ratio() — per-engine physical block ratio for HMA

Modified: nixl_connector.py

  1. 3-read descriptor registration: _register_mamba_3read_local() and _register_mamba_3read_remote() create 4 desc regions per mamba layer (x, B, C, ssm) instead of the original 2 (conv, ssm)
  2. Remote kernel block addressing — use remote engine's physical-per-logical block ratio (_mamba_phys_ratio) instead of local. P and D can have different ratios under HMA.
  3. Mamba rank_offset — mamba conv state is always TP-sharded (even when attention KV is replicated due to num_kv_heads < tp_size), so local_offset always uses tp_rank % effective_ratio
  4. Skip mamba trimming — partial prefix cache hit trimming skips mamba groups (their blocks are full state, not per-token)
  5. Block_len validation — HMA hybrid models pad block_len to max(attn_page, mamba_page), so the linear tp_ratio scaling assumption is gated behind not _has_mamba
  6. tp_ratio < 0 — raises NotImplementedError (P_TP > D_TP not yet supported for Mamba)
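
To illustrate item 2, the sketch below assumes each logical block maps to a run of `phys_ratio` consecutive kernel blocks; that mapping and the helper name are assumptions made for illustration. The ratios 261 and 131 are the Nemotron-Nano 2p4d example quoted in #37603.

```python
def kernel_block_ids(logical_block_id, phys_ratio):
    """Kernel block indices covered by one logical block (assumed mapping)."""
    start = logical_block_id * phys_ratio
    return range(start, start + phys_ratio)

p_ratio, d_ratio = 261, 131   # P and D physical-per-logical ratios
logical_block = 5

correct = kernel_block_ids(logical_block, p_ratio)  # index P's memory with P's ratio
wrong = kernel_block_ids(logical_block, d_ratio)    # local ratio lands elsewhere on P
```

Under this assumed mapping, using the local ratio makes D start reading at kernel block 655 instead of 1305.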

Comparison with permute approach (#37603)

|Aspect|Permute (1-read)|3-Read|
|------|----------------|------|
|P-side work|In-place permutation per block|None|
|D-side work|Inverse permutation after read|None|
|RDMA reads/block|1 (conv) + 1 (ssm)|3 (x,B,C) + 1 (ssm)|
|Descriptors/block|2|4|
|Conv layout req|SD (default)|DS (env var)|
|Code complexity|~250 lines (perm indices)|~100 lines (offset calc)|

Interacts with #37416 (DS conv state layout). Code asserts DS layout is active at init.

Test Plan

  • lm_eval gsm8k 5-shot with Nemotron-Nano-30B-A3B-FP8
  • 1p2d and 2p4d hetero-TP configs
  • Quick sanity (coherent output) + full accuracy + KV hit rate + transfer error checks

Test Result

gsm8k 5-shot, Temperature: 0.0 Model: NVIDIA-Nemotron-3-Nano-30B-A3B-FP8

|Config|strict-match|KV hit rate|Transfer errors|Status|
|------|-----------:|----------:|--------------:|------|
|1p2d|0.8453|99.90%|0|PASS|
|2p4d|0.8446|99.90%|0|PASS|

Expected baseline: ~0.84

Changed files

  • vllm/distributed/kv_transfer/kv_connector/v1/mamba_conv_transfer_utils.py (added, +151/-0)
  • vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (modified, +232/-40)
  • vllm/envs.py (modified, +8/-0)
  • vllm/model_executor/layers/mamba/mamba_mixer.py (modified, +6/-1)
  • vllm/model_executor/layers/mamba/mamba_mixer2.py (modified, +6/-2)
  • vllm/model_executor/layers/mamba/mamba_utils.py (modified, +85/-17)
  • vllm/model_executor/layers/mamba/ops/causal_conv1d.py (modified, +0/-4)
  • vllm/model_executor/layers/mamba/short_conv.py (modified, +6/-1)

PR #36687: [PD][Nixl] Add support for hybrid SSM-FA models

Description (problem / solution / changelog)

For a comprehensive description of the changes proposed here, check out the corresponding RFC https://github.com/vllm-project/vllm/issues/36780.

This PR adds support for hybrid SSM-based models such as nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 with NixlConnector, enabling KVCache transfer of both FA and Mamba states in disaggregated setups. Currently it only supports Homogeneous TP sizes on both P and D.

Note that we're only transferring actual mamba states and skipping the padding that may be present, as that might have non-trivial size.

UPDATE: regarding this change:

- curr_tensor_size_bytes = cache.numel() * cache.element_size()
+ curr_tensor_size_bytes = num_blocks * physical_page_size

In this PR I am trying to move further away from relying on tensor views and to make kv_cache_config the single source of truth in the code. This is also necessary for Mamba-like models, where the tensor (cache above) gives the unpadded size, which does not match num_blocks * physical_page_size; otherwise one would have to account for padding manually.
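
A small arithmetic sketch of why the two size computations diverge once blocks are padded. All numbers and helper names are hypothetical.

```python
def unpadded_bytes(numel, element_size):
    # What cache.numel() * cache.element_size() reports: payload only.
    return numel * element_size

def padded_bytes(num_blocks, physical_page_size):
    # Size of the region actually reserved for the cache, padding included.
    return num_blocks * physical_page_size

def block_base(base_addr, block_id, physical_page_size):
    # Blocks are spaced by the padded page size, not by the payload size.
    return base_addr + block_id * physical_page_size

num_blocks = 4
payload_elems_per_block, element_size = 1000, 2   # 2000 payload bytes per block
physical_page_size = 4096                         # padded page size

small = unpadded_bytes(num_blocks * payload_elems_per_block, element_size)
full = padded_bytes(num_blocks, physical_page_size)
third_block = block_base(0, 2, physical_page_size)
```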

Important notes

  • TP > 1 currently requires --no-async-scheduling to run correctly. @ZhanqiuHu and I identified a synchronization issue where states may be transferred in a corrupted form, leading to high variance in evaluations. This will be addressed separately, as it is likely unrelated to SSMs.
  • @ZhanqiuHu has identified an issue with current PD workflow in which we're recomputing the first token on D, leading to burning-in that extra step into the SSM state in-place.

Test with

Enable HMA experimental support with --no-disable-hybrid-kv-cache-manager:

# usual P/D command
vllm serve nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 \
  --trust-remote-code \
  --block-size 64 \
  --no-enable-prefix-caching \
  --no-disable-hybrid-kv-cache-manager \
  --mamba-ssm-cache-dtype float16 \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# usual toy_proxy_server.py command

or

HYBRID_SSM=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh

or check out unit tests added with this PR.

Results from running consecutive full lm-eval runs with no prefix caching:

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5444|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8355|±  |0.0102|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5345|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8340|±  |0.0102|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5398|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8355|±  |0.0102|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5428|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8332|±  |0.0103|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5557|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8506|±  |0.0098|

TODO

  • Address kernel<>logical block size mismatch
  • Benchmark

Changed files

  • tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh (modified, +8/-0)
  • tests/v1/kv_connector/nixl_integration/test_accuracy.py (modified, +1/-0)
  • tests/v1/kv_connector/unit/test_nixl_connector.py (modified, +76/-45)
  • tests/v1/kv_connector/unit/test_nixl_connector_hma.py (modified, +112/-0)
  • vllm/distributed/kv_transfer/kv_connector/utils.py (modified, +38/-8)
  • vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py (modified, +1/-1)
  • vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (modified, +351/-112)

PR #37416: [Kernel] Mamba support different layout for Conv state

Description (problem / solution / changelog)

Will fill in soon.

Changed files

  • vllm/envs.py (modified, +8/-0)
  • vllm/model_executor/layers/mamba/mamba_mixer.py (modified, +6/-1)
  • vllm/model_executor/layers/mamba/mamba_mixer2.py (modified, +9/-2)
  • vllm/model_executor/layers/mamba/mamba_utils.py (modified, +85/-17)
  • vllm/model_executor/layers/mamba/ops/causal_conv1d.py (modified, +0/-4)
  • vllm/model_executor/layers/mamba/short_conv.py (modified, +6/-1)


Builds on the hybrid SSM-FA NIXL support (#36687, RFC #36780) to enable different TP sizes between prefill and decode engines for Mamba models.

Challenge

With hetero-TP (P_TP ≠ D_TP), each D rank RDMA-reads a slice of P's cache. Example: Prefill TP = 1; Decode TP = 2

Attention KV — all heads are equal size, so a flat split works:

P's KV:  |-- h0 --|-- h1 --|-- h2 --|-- h3 --|
          └─ D rank 0 ─┘    └─ D rank 1 ─┘       ← 1 contiguous read each ✓

Mamba conv state — x, B, C have different sizes (x_dim ≠ B_dim = C_dim), so a flat split is wrong:

P's conv: |-------- x --------|- B -|- C -|

D rank 0 needs: x[first half] + B[first half] + C[first half]  ← scattered, not contiguous ✗
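
The mismatch is easy to reproduce with labelled columns. The sketch uses the toy sizes x_dim=8 and B_dim=C_dim=2 and compares a flat half split with what D rank 0 actually needs.

```python
x = [f"x{i}" for i in range(8)]
B = ["B0", "B1"]
C = ["C0", "C1"]
row = x + B + C                      # P's conv columns: [x | B | C]

half = len(row) // 2
flat_split_rank0 = row[:half]        # what a naive byte-range split returns
needed_rank0 = x[:4] + B[:1] + C[:1] # what a TP=2 shard actually needs
```

The flat split returns six x columns and nothing from B or C, so D rank 0 would receive the wrong state.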

Approaches explored (experimental)

Two approaches have been prototyped and validated with lm_eval gsm8k on Nemotron-Nano-30B-A3B-FP8. Neither is a comprehensive implementation yet — they are experimental explorations to evaluate trade-offs (code complexity, RDMA count, P/D-side compute, layout requirements).

  • Approach 1: Earlier exploration (#36957): full conv read + local staging buffer.
    • Requires extra GPU memory
    • Probably not ideal
P's conv: |-------- x --------|- B -|- C -|
               1 RDMA read (entire conv)
D staging:  |-------- x --------|- B -|- C -|
              extract my slice locally
D rank 0:   |--- x[0:half] ---|- B[0:half] -|- C[0:half] -|
  • Approach 2: Chunk-interleaved permutation (#37603) — tested 1p2d, 2p4d
    • conv_states layout: [x | B | C] --> chunk-interleaved transpose [x0 B0 C0 | x1 B1 C1 | ...]
    • Note:
      • B_dim = C_dim;
      • x_dim ≠ B_dim in general -->
      • so x chunk size ≠ B chunk size -->
      • so need to find chunk size for grouping (e.g., with gcd)
    • permutation before transfer (on P)
    • inverse permutation after transfer (on D)
P's conv (original):  |-------- x --------|- B -|- C -|
                    permute on P (chunk-interleaved)
P's conv (permuted):  |-- x₀ B₀ C₀ --|-- x₁ B₁ C₁ --|
                      └─ D rank 0 ─┘  └─ D rank 1 ─┘
                      1 RDMA read each (contiguous)
                    inverse permute on D
D rank 0:             |--- x₀ ---|- B₀ -|- C₀ -|
  • Approach 3: 3-read conv state transfer with DS layout (#37635), tested 1p2d, 2p4d
    • keep conv_states layout: [x | B | C]
    • Use 3 NIXL RDMA reads for conv_states instead of 1 (i.e., 4 RDMA reads per SSM block instead of 2)
P's conv (DS layout): |-------- x --------|- B -|- C -|
                        ▲          ▲        ▲     ▲
                        │          │        │     │
D rank 0 reads:     read 1 ───┘          read 2  read 3
                   (first half of x)  (first half (first half
                                        of B)      of C)

                   3 RDMA reads, no permutation on either side
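
One plausible reading of the "chunk size with gcd" note in Approach 2: if the conv state is cut into gcd(x_dim, B_dim) chunks, every chunk holds a whole number of x, B, and C columns, and any TP size that divides the chunk count maps to a contiguous run of chunks. This is an interpretation, not the code in #37603, and the helper names are hypothetical.

```python
from math import gcd

def finest_chunk_count(x_dim, b_dim):
    """Largest number of equal chunks that splits both x and B (= C) evenly."""
    return gcd(x_dim, b_dim)

def chunks_for_rank(num_chunks, tp_size, rank):
    """Contiguous run of chunk indices owned by `rank` for a given TP size."""
    if num_chunks % tp_size != 0:
        raise ValueError("tp_size must divide the number of chunks")
    per_rank = num_chunks // tp_size
    return range(rank * per_rank, (rank + 1) * per_rank)

small = finest_chunk_count(8, 2)         # toy sizes from the diagrams
large = finest_chunk_count(4096, 1024)   # illustrative sizes only
run = chunks_for_rank(large, tp_size=4, rank=3)
```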

Remaining work

  • Comprehensive code review and cleanup.
  • Handle tp_ratio < 0 (P_TP > D_TP, e.g. 2p1d)
  • Homo-TP fast path, opt-out of hetero TP support (if user is certain). Maybe?
  • Performance comparison: permute vs 3-read (vs full conv read).
  • Mixed prefill and decode TP testing (e.g., P_TP=1, D1_TP=2, D2_TP=4)
  • Consider DS layout as default for disagg setups
  • Decide which approach to adopt for upstream

Related

  • #36780 — RFC: NixlConnector hybrid SSM-FA support
  • #36687 — Homogeneous TP support (merged)
  • #37416 — DS conv state layout support

extent analysis

Fix Plan

One way to address the non-contiguous reads for Mamba conv state with hetero-TP is Approach 2, the chunk-interleaved permutation. The tracking issue has not yet settled on which approach to adopt upstream, so treat this as one candidate.

Here are the steps:

  • Modify the conv_states layout to chunk-interleaved transpose: [x | B | C] --> [x0 B0 C0 | x1 B1 C1 | ...]
  • Find the chunk size for grouping (e.g., with gcd), because x_dim and B_dim (= C_dim) generally differ
  • Implement permutation before transfer on P and inverse permutation after transfer on D

Example code snippet in Python:

import numpy as np

def chunk_interleaved_permute(conv_state, x_dim, b_dim, num_chunks):
    """Reorder an SD-layout conv state (conv_rows, x_dim + 2 * b_dim) so that
    each chunk [x_k | B_k | C_k] occupies one contiguous range of the output."""
    # Split the concatenated columns back into x, B, and C.
    x, B, C = np.split(conv_state, [x_dim, x_dim + b_dim], axis=1)
    chunks = zip(np.split(x, num_chunks, axis=1),
                 np.split(B, num_chunks, axis=1),
                 np.split(C, num_chunks, axis=1))
    # Step 1: group columns per chunk. Step 2: flatten each chunk column-first.
    return np.concatenate(
        [np.concatenate(parts, axis=1).T.reshape(-1) for parts in chunks])

def inverse_permute(permuted, conv_rows, x_dim, b_dim, num_chunks):
    """Restore the original (conv_rows, x_dim + 2 * b_dim) layout."""
    xw, bw = x_dim // num_chunks, b_dim // num_chunks
    xs, Bs, Cs = [], [], []
    for chunk in np.split(permuted, num_chunks):
        # Undo the column-first flattening of this chunk.
        grouped = chunk.reshape(xw + 2 * bw, conv_rows).T
        x_k, B_k, C_k = np.split(grouped, [xw, xw + bw], axis=1)
        xs.append(x_k)
        Bs.append(B_k)
        Cs.append(C_k)
    return np.concatenate(xs + Bs + Cs, axis=1)

# Example usage: conv_rows=3, x_dim=8, B_dim=C_dim=2, two chunks (P_TP=1, D_TP=2)
conv_rows, x_dim, b_dim, num_chunks = 3, 8, 2, 2
conv_state = np.arange(conv_rows * (x_dim + 2 * b_dim)).reshape(conv_rows, -1)
permuted = chunk_interleaved_permute(conv_state, x_dim, b_dim, num_chunks)
restored = inverse_permute(permuted, conv_rows, x_dim, b_dim, num_chunks)
assert np.array_equal(restored, conv_state)  # round trip restores the layout

Verification

To verify the fix, test the implementation with different TP sizes and conv state layouts. Check that the permutation and inverse permutation are correct and that the RDMA reads are contiguous.

Extra Tips

  • Consider adding a comprehensive code review and cleanup to ensure the implementation is efficient and easy to maintain.
  • Handle the case where tp_ratio < 0 (P_TP > D_TP) and consider adding a fast path for homo-TP.
  • Compare the performance of the different approaches (permute, 3-read, and full conv read).
