vllm - ✅(Solved) Fix [Bug][NIXL]: TRITON_ATTN ignores `VLLM_KV_CACHE_LAYOUT=HND`, breaks heterogeneous TP with NIXL [3 pull requests, 3 comments, 3 participants]

GitHub stats
vllm-project/vllm#37703 · Fetched 2026-04-08 01:08:47
Comments: 3 · Participants: 3 · Timeline: 11 · Reactions: 1
Timeline (top): commented ×3, cross-referenced ×3, closed ×1, labeled ×1

TRITON_ATTN's get_kv_cache_stride_order always returns NHD (identity) stride order, ignoring VLLM_KV_CACHE_LAYOUT=HND. This breaks heterogeneous TP disaggregated serving (P_TP != D_TP) via the NIXL connector.

The NIXL connector sets VLLM_KV_CACHE_LAYOUT=HND and computes a byte rank_offset to split KV heads across D-side TP ranks. This offset assumes heads are physically contiguous (HND layout: [num_heads, block_size, head_dim]). Because TRITON_ATTN keeps the NHD layout ([block_size, num_heads, head_dim]), the offset splits along the token dimension instead of the head dimension, causing each D-rank to read corrupted KV data.
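
To make the failure mode concrete, here is a minimal sketch in plain Python (not vLLM code). It models one KV block as a list of (head, token) cells in physical memory order, where each cell stands in for head_dim contiguous elements, and then takes the contiguous chunk that a D-side rank would read when the block is split evenly by offset. The helper names and the even-split rule are illustrative assumptions, not the connector's actual implementation.

```python
from itertools import product


def flatten(layout, num_heads, block_size):
    """Return (head, token) labels in physical memory order for one block."""
    if layout == "HND":  # [num_heads, block_size, head_dim]
        return [(h, t) for h, t in product(range(num_heads), range(block_size))]
    if layout == "NHD":  # [block_size, num_heads, head_dim]
        return [(h, t) for t, h in product(range(block_size), range(num_heads))]
    raise ValueError(f"Unknown layout: {layout}")


def rank_slice(layout, num_heads, block_size, tp_size, rank):
    """Contiguous chunk of the block that one rank reads via an offset."""
    cells = flatten(layout, num_heads, block_size)
    chunk = len(cells) // tp_size
    return cells[rank * chunk:(rank + 1) * chunk]


# 4 heads, 4 tokens per block, split across 2 decode ranks; look at rank 1.
hnd = rank_slice("HND", 4, 4, 2, 1)
nhd = rank_slice("NHD", 4, 4, 2, 1)

heads_hnd = sorted({h for h, _ in hnd})    # [2, 3]: the rank's own heads
tokens_hnd = sorted({t for _, t in hnd})   # [0, 1, 2, 3]: every token
heads_nhd = sorted({h for h, _ in nhd})    # [0, 1, 2, 3]: every head
tokens_nhd = sorted({t for _, t in nhd})   # [2, 3]: only half of the tokens
```

With HND the offset selects a clean subset of heads for every token. With NHD the same offset selects all heads for only half of the tokens, which is the token-dimension split described above.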

FLASH_ATTN and FlashInfer both respect VLLM_KV_CACHE_LAYOUT and are unaffected.

Root Cause

triton_attn.py get_kv_cache_stride_order (before fix):

# Always returns NHD identity — ignores VLLM_KV_CACHE_LAYOUT
return (0, 1, 2, 3, 4)

The write kernel in triton_reshape_and_cache_flash.py also uses flat indexing (block_stride + page_stride + tile_pos) which only works for contiguous NHD layout.

Fix Action


PR fix notes

PR #35444: [Bugfix] Fix Triton attention layout when used in combination with the NIXL connector

Description (problem / solution / changelog)

Purpose

Fixes a bug when the NIXL connector is used in combination with the triton attention backend.

The NIXL connector requests the HND layout in get_required_kvcache_layout() for byte-offset head splitting during heterogeneous TP transfers. But TritonAttentionBackend.get_kv_cache_stride_order() returns NHD.

This causes NIXL's splitting along heads to grab garbage data on ROCm, where Triton is the default backend.

Test Plan

  • Added a simple unit test that catches this invariant across different attention backends.
  • Ran gsm8k against a TP=1 prefill, TP=2 decode deployment of RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic.

Test Result

main:

local-completions ({'model': 'RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic', 'base_url': 'http://infra-disagg-inference-gateway-istio.tms.svc.cluster.local/v1/completions', 'num_concurrent': 2000, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: 10, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|    10|exact_match|↑  |    0|±  |     0|
|     |       |strict-match    |    10|exact_match|↑  |    0|±  |     0|

This PR:

local-completions ({'model': 'RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic', 'base_url': 'http://infra-disagg-inference-gateway-istio.tms.svc.cluster.local/v1/completions', 'num_concurrent': 2000, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: 10, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|    10|exact_match|↑  |0.9386|±  |0.0066|
|     |       |strict-match    |    10|exact_match|↑  |0.8901|±  |0.0086|

Changed files

  • tests/v1/kv_connector/unit/test_hetero_tp_layout_bug.py (added, +169/-0)
  • vllm/v1/attention/backends/triton_attn.py (modified, +13/-4)
  • vllm/v1/attention/ops/triton_reshape_and_cache_flash.py (modified, +17/-3)

PR #37940: [NIXL][BUG] Fix Triton heterogeneous TP

Description (problem / solution / changelog)

co-authored with @ZhanqiuHu

Purpose

  • Fix Triton Attn Heterogeneous TP Disagg: #37703
  • Also fixes Gemma with Heterogeneous TP bug, also caused by Triton Backend: #37333
  • Enable cross-layer TP disagg for Triton, which now has the same KV cache layout as FlashInfer

Test Plan

In tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh, replace tp_configs with the following:

Fixed GEMMA tests:

"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=google/gemma-3-4b-it PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192"

Fixed Triton backend tests:

"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 VLLM_SERVE_EXTRA_ARGS=--attention-backend,TRITON_ATTN"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=1 VLLM_SERVE_EXTRA_ARGS=--attention-backend,TRITON_ATTN"

Run cd tests && v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh.

Test Result

All tests now pass. Triton backend was tested with CROSS_LAYERS_BLOCKS=0 and CROSS_LAYERS_BLOCKS=1.

cc @NickLucche



Changed files

  • tests/v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh (modified, +7/-0)
  • tests/v1/kv_connector/unit/test_nixl_connector.py (modified, +17/-15)
  • vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py (modified, +16/-0)
  • vllm/v1/attention/backends/triton_attn.py (modified, +13/-4)
  • vllm/v1/attention/ops/triton_reshape_and_cache_flash.py (modified, +9/-3)

PR #37943: [NIXL] Strengthen TpKVTopology validation

Description (problem / solution / changelog)

Purpose

Strengthen TpKVTopology config validation. Previous validation contained dead code and did not take layer dimension into account when validating cross-layer layouts. It now throws an error when an unexpected KV cache shape is encountered instead of continuing.

TpKVTopology Config Validation

Computes three boolean properties:

  • split_k_and_v
    • K and V are in separate regions, i.e. tensor shape (2, num_blocks, ...).
  • is_kv_layout_blocks_first
    • K and V are interleaved within each block, i.e. tensor shape (num_blocks, 2, ...). Each region gets two descriptors (one for K, one for V).
  • cross_layers_blocks
    • All layers share a single tensor with an extra num_layers dimension instead of one tensor per layer.

split_k_and_v and is_kv_layout_blocks_first are mutually exclusive.

<details> <summary>Full backend / layout flag matrix</summary>

M = num_blocks, N = block_size, H = num_kv_heads, D = head_dim, L = num_layers.

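|Backend       |Layout               |Physical shape       |Should: split_k_and_v|Should: is_kv_layout_blocks_first|num_regions / layer|
|--------------|---------------------|---------------------|---------------------|---------------------------------|------------------:|
|FLASH_ATTN    |NHD                  |(2, M, N, H, D)      |True                 |False                            |                  2|
|FLASH_ATTN    |HND                  |(2, M, H, N, D)      |True                 |False                            |                  2|
|FLASH_ATTN    |Cross Layer NHD      |(M, L, 2, N, H, D)   |False                |False                            |                  1|
|FLASH_ATTN    |Cross Layer HND      |(M, H, L, 2, N, D)   |False                |False                            |                  1|
|TRITON        |NHD                  |(M, 2, N, H, D)      |False                |True                             |                  2|
|TRITON        |HND                  |(M, 2, H, N, D)      |False                |True                             |                  2|
|TRITON        |Cross Layer NHD      |(M, L, 2, N, H, D)   |False                |False                            |                  1|
|TRITON        |Cross Layer HND      |(M, 2, H, L, N, D)   |False                |True                             |                  2|
|MLA           |                     |(M, N, D)            |False                |False                            |                  1|
|MAMBA (hybrid)|                     |blocks-first (forced)|False                |True                             |                  2|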
</details>

The flag selection should inspect the first two physical dimensions of the KV cache tensor:

  • (2, num_blocks, ...) → split_k_and_v = True. K and V are separate contiguous regions.
  • (num_blocks, 2, ...) or Mamba → is_kv_layout_blocks_first = True. K and V (or conv and SSM) are interleaved within each block.
  • (num_blocks, ...) with no size-2 KV dimension (e.g. MLA) → both False.
  • Other layouts are not supported
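
The selection rules above can be sketched roughly as follows. This is an illustrative helper, not the TpKVTopology code: the function name, its parameters, and the is_mamba switch are hypothetical, and the sketch assumes num_blocks is not 2 and that the second dimension is not coincidentally 2 (for example a cross-layer tensor with two layers), cases the real validation would need to handle separately.

```python
def classify_kv_layout(shape, num_blocks, is_mamba=False):
    """Return (split_k_and_v, is_kv_layout_blocks_first) for a physical shape.

    Hypothetical helper: only the first two dimensions are inspected.
    """
    if is_mamba:
        # Mamba (hybrid) caches are treated as blocks-first.
        return (False, True)
    if len(shape) < 2:
        raise ValueError(f"Unsupported KV cache shape: {shape}")
    first, second = shape[0], shape[1]
    if first == 2 and second == num_blocks:
        # (2, num_blocks, ...): K and V are separate contiguous regions.
        return (True, False)
    if first == num_blocks and second == 2:
        # (num_blocks, 2, ...): K and V interleaved within each block.
        return (False, True)
    if first == num_blocks:
        # (num_blocks, ...) with no size-2 KV dimension next, e.g. MLA.
        return (False, False)
    # Anything else is rejected instead of silently continuing.
    raise ValueError(f"Unsupported KV cache shape: {shape}")


flash_nhd = classify_kv_layout((2, 10, 16, 8, 64), num_blocks=10)
triton_hnd = classify_kv_layout((10, 2, 8, 16, 64), num_blocks=10)
mla = classify_kv_layout((10, 16, 64), num_blocks=10)
```

Because each branch returns at most one True flag, the two flags stay mutually exclusive by construction.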

Future Work

Heterogeneous TP: Requires N, D and adjacent H dimensions to be contiguous. Not only does this mean only HND backends are supported, it also expects the layout of each region to be (H, ..., N, D) (not (..., H, N, D)). This is true for all current HND backends but is not explicitly validated.

Cross Layer NHD: Was not enabled, possibly due to the previous TpKVTopology not handling it correctly. It's now fixed and would work with both Flash Attn and Triton backends.

post_process_device_kv_on_receive: Assumes only the [num_blocks, n_kv_head, block_size, head_dim] layout, so it breaks when is_kv_layout_blocks_first = True or when cross_layer = True.

Test Plan

Cross layer:

bash v1/kv_connector/nixl_integration/run_accuracy_test.sh --attention-backend FLASH_ATTN --enable-cross-layers
bash v1/kv_connector/nixl_integration/run_accuracy_test.sh --attention-backend TRITON --enable-cross-layers

Test Result

Cross layer tests pass for HND. They also pass for NHD if I remove the guard in nixl_connector.py::L336 disallowing NHD with cross-layer.

cc @NickLucche



Changed files

  • vllm/distributed/kv_transfer/kv_connector/utils.py (modified, +47/-32)
  • vllm/v1/worker/kv_connector_model_runner_mixin.py (modified, +2/-2)

Code Example

Collecting environment information...
==============================
        System Info
==============================
OS                           : Red Hat Enterprise Linux 9.6 (Plow) (x86_64)
GCC version                  : (GCC) 11.5.0 20240719 (Red Hat 11.5.0-5)
Clang version                : Could not collect
CMake version                : version 3.26.5
Libc version                 : glibc-2.34

==============================
       PyTorch Info
==============================
PyTorch version              : 2.10.0+cu130
Is debug build               : False
CUDA used to build PyTorch   : 13.0
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.9 (main, Aug 14 2025, 00:00:00) [GCC 11.5.0 20240719 (Red Hat 11.5.0-5)] (64-bit runtime)
Python platform              : Linux-5.14.0-570.73.1.el9_6.x86_64-x86_64-with-glibc2.34

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : 13.0.88
CUDA_MODULE_LOADING set to   : 
GPU models and configuration : 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version        : 590.48.01
cuDNN version                : Could not collect
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           46 bits physical, 57 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  160
On-line CPU(s) list:                     0-159
Vendor ID:                               GenuineIntel
Model name:                              Intel Xeon Processor (SapphireRapids)
CPU family:                              6
Model:                                   143
Thread(s) per core:                      2
Core(s) per socket:                      40
Socket(s):                               2
Stepping:                                4
BogoMIPS:                                4200.00
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Virtualization:                          VT-x
Hypervisor vendor:                       KVM
Virtualization type:                     full
L1d cache:                               5 MiB (160 instances)
L1i cache:                               5 MiB (160 instances)
L2 cache:                                320 MiB (80 instances)
L3 cache:                                32 MiB (2 instances)
NUMA node(s):                            2
NUMA node0 CPU(s):                       0-79
NUMA node1 CPU(s):                       80-159
Vulnerability Gather data sampling:      Not affected
Vulnerability Indirect target selection: Mitigation; Aligned branch/return thunks
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Unknown: No mitigations
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Not affected

==============================
Versions of relevant libraries
==============================
[pip3] flashinfer-python==0.6.6
[pip3] numpy==2.2.6
[pip3] nvidia-cublas==13.1.0.3
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-cu13==9.15.1.9
[pip3] nvidia-cudnn-frontend==1.18.0
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-cufile==1.15.1.6
[pip3] nvidia-cufile-cu12==1.13.1.3
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-cutlass-dsl==4.4.2
[pip3] nvidia-cutlass-dsl-libs-base==4.4.2
[pip3] nvidia-ml-py==13.590.48
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nccl-cu13==2.28.9
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvshmem-cu12==3.4.5
[pip3] nvidia-nvshmem-cu13==3.4.5
[pip3] nvidia-nvtx==13.0.85
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pyzmq==27.1.0
[pip3] torch==2.10.0+cu130
[pip3] torch_c_dlpack_ext==0.1.5
[pip3] torchaudio==2.10.0+cu130
[pip3] torchvision==0.25.0+cu130
[pip3] transformers==4.57.6
[pip3] triton==3.6.0
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.17.2rc1.dev139+gebd77f59d (git sha: ebd77f59d)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    0-79    0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    0-79    0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    0-79    0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    0-79    0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    80-159  1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    80-159  1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    80-159  1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      80-159  1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

==============================
     Environment Variables
==============================
LD_LIBRARY_PATH=:/usr/local/cuda-13.0/lib64:/usr/local/cuda-13.0/lib64:/usr/local/cuda-13.0/lib64:/usr/local/cuda-13.0/lib64
CUDA_HOME=/usr/local/cuda-13.0
CUDA_HOME=/usr/local/cuda-13.0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_ZhanqiuHu

---

# Qwen3-0.6B with forced TRITON_ATTN, PTP=1 DTP=2
vllm serve Qwen/Qwen3-0.6B \
  --attention-backend TRITON_ATTN \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
  --tensor-parallel-size 2  # decoder side
# With a PTP=1 prefiller and proxy routing to this decoder

---

Prompt:     "The capital of France is"
Completion: "is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is"

---

# Always returns NHD identity — ignores VLLM_KV_CACHE_LAYOUT
return (0, 1, 2, 3, 4)

---

from vllm.v1.attention.backends.utils import get_kv_cache_layout

@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    cache_layout = get_kv_cache_layout()
    if cache_layout == "NHD" and include_num_layers_dimension:
        return (1, 0, 2, 3, 4, 5)
    elif cache_layout == "NHD":
        return (0, 1, 2, 3, 4)
    elif cache_layout == "HND" and include_num_layers_dimension:
        return (1, 2, 4, 0, 3, 5)
    elif cache_layout == "HND":
        return (0, 1, 3, 2, 4)
    else:
        raise ValueError(f"Unknown cache layout: {cache_layout}")
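
As a rough illustration of what these tuples mean, the sketch below applies each stride order to a symbolic logical shape. It assumes the logical per-layer shape is (num_blocks, 2, block_size, num_kv_heads, head_size), that the layer dimension is prepended in the cross-layer case, and that the physical shape is obtained by indexing the logical shape with the stride order. The helper name physical_shape is hypothetical.

```python
def physical_shape(logical_shape, stride_order):
    """Permute a logical shape into physical (memory) order."""
    return tuple(logical_shape[i] for i in stride_order)


# Symbolic dimensions: M = num_blocks, N = block_size, H = num_kv_heads,
# D = head_size, L = num_layers.
M, N, H, D, L = "M", "N", "H", "D", "L"
per_layer = (M, 2, N, H, D)        # assumed logical shape, one tensor per layer
cross_layer = (L, M, 2, N, H, D)   # assumed logical shape with a layer dimension

nhd = physical_shape(per_layer, (0, 1, 2, 3, 4))
hnd = physical_shape(per_layer, (0, 1, 3, 2, 4))
nhd_layers = physical_shape(cross_layer, (1, 0, 2, 3, 4, 5))
hnd_layers = physical_shape(cross_layer, (1, 2, 4, 0, 3, 5))

print(nhd)         # ('M', 2, 'N', 'H', 'D')
print(hnd)         # ('M', 2, 'H', 'N', 'D')
print(nhd_layers)  # ('M', 'L', 2, 'N', 'H', 'D')
print(hnd_layers)  # ('M', 2, 'H', 'L', 'N', 'D')
```

Under these assumptions the four results agree with the TRITON rows of the backend / layout matrix in PR #37943.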

---

# Before (assumes contiguous NHD):
tgt_base = block_idx * block_stride + block_offset * page_stride
tgt_idx_k = tgt_base + tile_pos
tgt_idx_v = tgt_base + tile_pos

# After (works for both NHD and HND via strides):
cur_head = tile_pos // head_size
cur_dim = tile_pos % head_size
tgt_idx_k = (
    block_idx * block_stride
    + block_offset * page_stride
    + cur_head * head_stride
    + cur_dim
)
tgt_idx_v = tgt_idx_k
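
The difference between the two index formulas can be checked with a small plain-Python sketch (no Triton involved). The variable names follow the snippet above; the concrete sizes and the contiguous_strides helper are assumptions chosen for illustration.

```python
def contiguous_strides(shape):
    """Element strides of a contiguous row-major tensor with this shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def flat_index(block_idx, block_offset, tile_pos, block_stride, page_stride):
    """'Before' formula: only correct when heads are innermost (NHD)."""
    return block_idx * block_stride + block_offset * page_stride + tile_pos


def strided_index(block_idx, block_offset, tile_pos,
                  block_stride, page_stride, head_stride, head_size):
    """'After' formula: splits tile_pos into head and dim, then uses strides."""
    cur_head = tile_pos // head_size
    cur_dim = tile_pos % head_size
    return (block_idx * block_stride + block_offset * page_stride
            + cur_head * head_stride + cur_dim)


num_blocks, block_size, num_heads, head_size = 2, 4, 2, 8
tile_pos = 1 * head_size + 3  # head 1, dim 3
token = 1                     # block_offset within block 0

# NHD physical order [M, N, H, D]: strides are (block, page, head, dim).
b, p, h, _ = contiguous_strides((num_blocks, block_size, num_heads, head_size))
nhd_flat = flat_index(0, token, tile_pos, b, p)
nhd_strided = strided_index(0, token, tile_pos, b, p, h, head_size)

# HND physical order [M, H, N, D]: strides are (block, head, page, dim).
b, h, p, _ = contiguous_strides((num_blocks, num_heads, block_size, head_size))
hnd_flat = flat_index(0, token, tile_pos, b, p)
hnd_strided = strided_index(0, token, tile_pos, b, p, h, head_size)

print(nhd_flat, nhd_strided)  # 27 27
print(hnd_flat, hnd_strided)  # 19 43
```

For NHD both formulas agree. For HND only the strided formula lands on the slot for (head 1, token 1, dim 3); the flat formula writes into head 0's region instead.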

Fixes https://github.com/vllm-project/vllm/issues/37333.

Your current environment

<details> <summary>The output of <code>python collect_env.py</code></summary>
Collecting environment information...
==============================
        System Info
==============================
OS                           : Red Hat Enterprise Linux 9.6 (Plow) (x86_64)
GCC version                  : (GCC) 11.5.0 20240719 (Red Hat 11.5.0-5)
Clang version                : Could not collect
CMake version                : version 3.26.5
Libc version                 : glibc-2.34

==============================
       PyTorch Info
==============================
PyTorch version              : 2.10.0+cu130
Is debug build               : False
CUDA used to build PyTorch   : 13.0
ROCM used to build PyTorch   : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.9 (main, Aug 14 2025, 00:00:00) [GCC 11.5.0 20240719 (Red Hat 11.5.0-5)] (64-bit runtime)
Python platform              : Linux-5.14.0-570.73.1.el9_6.x86_64-x86_64-with-glibc2.34

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : 13.0.88
CUDA_MODULE_LOADING set to   : 
GPU models and configuration : 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version        : 590.48.01
cuDNN version                : Could not collect
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           46 bits physical, 57 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  160
On-line CPU(s) list:                     0-159
Vendor ID:                               GenuineIntel
Model name:                              Intel Xeon Processor (SapphireRapids)
CPU family:                              6
Model:                                   143
Thread(s) per core:                      2
Core(s) per socket:                      40
Socket(s):                               2
Stepping:                                4
BogoMIPS:                                4200.00
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Virtualization:                          VT-x
Hypervisor vendor:                       KVM
Virtualization type:                     full
L1d cache:                               5 MiB (160 instances)
L1i cache:                               5 MiB (160 instances)
L2 cache:                                320 MiB (80 instances)
L3 cache:                                32 MiB (2 instances)
NUMA node(s):                            2
NUMA node0 CPU(s):                       0-79
NUMA node1 CPU(s):                       80-159
Vulnerability Gather data sampling:      Not affected
Vulnerability Indirect target selection: Mitigation; Aligned branch/return thunks
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Unknown: No mitigations
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Not affected

==============================
Versions of relevant libraries
==============================
[pip3] flashinfer-python==0.6.6
[pip3] numpy==2.2.6
[pip3] nvidia-cublas==13.1.0.3
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-cu13==9.15.1.9
[pip3] nvidia-cudnn-frontend==1.18.0
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-cufile==1.15.1.6
[pip3] nvidia-cufile-cu12==1.13.1.3
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-cutlass-dsl==4.4.2
[pip3] nvidia-cutlass-dsl-libs-base==4.4.2
[pip3] nvidia-ml-py==13.590.48
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nccl-cu13==2.28.9
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvshmem-cu12==3.4.5
[pip3] nvidia-nvshmem-cu13==3.4.5
[pip3] nvidia-nvtx==13.0.85
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pyzmq==27.1.0
[pip3] torch==2.10.0+cu130
[pip3] torch_c_dlpack_ext==0.1.5
[pip3] torchaudio==2.10.0+cu130
[pip3] torchvision==0.25.0+cu130
[pip3] transformers==4.57.6
[pip3] triton==3.6.0
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.17.2rc1.dev139+gebd77f59d (git sha: ebd77f59d)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    0-79    0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    0-79    0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    0-79    0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    0-79    0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    80-159  1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    80-159  1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    80-159  1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      80-159  1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

==============================
     Environment Variables
==============================
LD_LIBRARY_PATH=:/usr/local/cuda-13.0/lib64:/usr/local/cuda-13.0/lib64:/usr/local/cuda-13.0/lib64:/usr/local/cuda-13.0/lib64
CUDA_HOME=/usr/local/cuda-13.0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_ZhanqiuHu
</details>

🐛 Describe the bug

Description

TRITON_ATTN's get_kv_cache_stride_order always returns NHD (identity) stride order, ignoring VLLM_KV_CACHE_LAYOUT=HND. This breaks heterogeneous TP disaggregated serving (P_TP != D_TP) via the NIXL connector.

The NIXL connector sets VLLM_KV_CACHE_LAYOUT=HND and computes a byte rank_offset to split KV heads across D-side TP ranks. This offset assumes heads are physically contiguous (HND layout: [num_heads, block_size, head_dim]). Because TRITON_ATTN keeps the NHD layout ([block_size, num_heads, head_dim]), the offset splits along the token dimension instead of the head dimension, causing each D-rank to read corrupted KV data.

FLASH_ATTN and FlashInfer both respect VLLM_KV_CACHE_LAYOUT and are unaffected.
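
A toy model can make the failure concrete. The sketch below is not the NIXL connector's code: it lists one KV block's elements in physical memory order for each layout, then takes the contiguous chunk a decode rank would read if the block were split evenly by offset. The function names and the tiny sizes are illustrative assumptions.

```python
from itertools import product


def elements_in_memory_order(layout, block_size, num_heads, head_dim):
    """List (token, head, dim) for one KV block in physical memory order."""
    if layout == "HND":  # physical order [num_heads, block_size, head_dim]
        return [(t, h, d) for h, t, d in
                product(range(num_heads), range(block_size), range(head_dim))]
    if layout == "NHD":  # physical order [block_size, num_heads, head_dim]
        return [(t, h, d) for t, h, d in
                product(range(block_size), range(num_heads), range(head_dim))]
    raise ValueError(f"Unknown cache layout: {layout}")


def rank_chunk(layout, rank, num_ranks, block_size=4, num_heads=2, head_dim=2):
    """Contiguous slice of the block that `rank` reads under an even split."""
    elems = elements_in_memory_order(layout, block_size, num_heads, head_dim)
    chunk = len(elems) // num_ranks
    return elems[rank * chunk:(rank + 1) * chunk]


def heads_and_tokens(chunk):
    """Distinct head and token indices covered by a chunk."""
    return sorted({h for _, h, _ in chunk}), sorted({t for t, _, _ in chunk})


if __name__ == "__main__":
    for layout in ("HND", "NHD"):
        heads, tokens = heads_and_tokens(rank_chunk(layout, rank=1, num_ranks=2))
        print(f"{layout}: rank 1 reads heads {heads}, tokens {tokens}")
```

In this model, rank 1 gets head 1 for every token under HND. Under NHD the same offset gives it both heads for only the last two tokens, which is the token-dimension split described above.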

Affected Models

Any model using TRITON_ATTN with heterogeneous TP + NIXL.

Reproduction

# Qwen3-0.6B with forced TRITON_ATTN, PTP=1 DTP=2
vllm serve Qwen/Qwen3-0.6B \
  --attention-backend TRITON_ATTN \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
  --tensor-parallel-size 2  # decoder side
# With a PTP=1 prefiller and proxy routing to this decoder

Quick sanity — before fix:

Prompt:     "The capital of France is"
Completion: "is is is is is is is is is is is is is is is is is is is is is is is is is is is is is is"

Root Cause

triton_attn.py get_kv_cache_stride_order (before fix):

# Always returns NHD identity — ignores VLLM_KV_CACHE_LAYOUT
return (0, 1, 2, 3, 4)

The write kernel in triton_reshape_and_cache_flash.py also uses flat indexing (block_stride + page_stride + tile_pos), which only works for a contiguous NHD layout.

Proposed Fix (2 files, ~15 lines changed)

1. vllm/v1/attention/backends/triton_attn.py

Read get_kv_cache_layout() and return the correct stride order, matching FLASH_ATTN's implementation:

from vllm.v1.attention.backends.utils import get_kv_cache_layout

@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    cache_layout = get_kv_cache_layout()
    if cache_layout == "NHD" and include_num_layers_dimension:
        return (1, 0, 2, 3, 4, 5)
    elif cache_layout == "NHD":
        return (0, 1, 2, 3, 4)
    elif cache_layout == "HND" and include_num_layers_dimension:
        return (1, 2, 4, 0, 3, 5)
    elif cache_layout == "HND":
        return (0, 1, 3, 2, 4)
    else:
        raise ValueError(f"Unknown cache layout: {cache_layout}")
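
To see what these tuples encode, the following sketch derives per-dimension strides from a stride order. It assumes a 5-D logical shape whose dimensions 2, 3, and 4 are block_size, num_kv_heads, and head_size, which is what the HND permutation (0, 1, 3, 2, 4) suggests. The helper names and example sizes are hypothetical, and this is plain Python rather than vLLM's allocation code.

```python
def physical_shape(logical_shape, stride_order):
    """Shape of the underlying buffer: logical dims arranged in stride order."""
    return tuple(logical_shape[dim] for dim in stride_order)


def logical_strides(logical_shape, stride_order):
    """Element strides of each logical dim when memory follows stride_order."""
    phys = physical_shape(logical_shape, stride_order)
    # Contiguous (row-major) strides of the physical buffer.
    phys_strides = [1] * len(phys)
    for i in range(len(phys) - 2, -1, -1):
        phys_strides[i] = phys_strides[i + 1] * phys[i + 1]
    # Map each physical position back to the logical dim stored there.
    strides = [0] * len(logical_shape)
    for phys_pos, logical_dim in enumerate(stride_order):
        strides[logical_dim] = phys_strides[phys_pos]
    return tuple(strides)


if __name__ == "__main__":
    # Assumed logical order: (2, num_blocks, block_size, num_kv_heads, head_size)
    shape = (2, 8, 16, 4, 64)
    print("NHD strides:", logical_strides(shape, (0, 1, 2, 3, 4)))
    print("HND strides:", logical_strides(shape, (0, 1, 3, 2, 4)))
```

With these sizes, the head stride is head_size (64) under NHD but block_size * head_size (1024) under HND, which is why any code that hard-codes the NHD head stride breaks once HND is honored.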

2. vllm/v1/attention/ops/triton_reshape_and_cache_flash.py

Replace flat indexing in the 4D write path with stride-based indexing:

# Before (assumes contiguous NHD):
tgt_base = block_idx * block_stride + block_offset * page_stride
tgt_idx_k = tgt_base + tile_pos
tgt_idx_v = tgt_base + tile_pos

# After (works for both NHD and HND via strides):
cur_head = tile_pos // head_size
cur_dim = tile_pos % head_size
tgt_idx_k = (
    block_idx * block_stride
    + block_offset * page_stride
    + cur_head * head_stride
    + cur_dim
)
tgt_idx_v = tgt_idx_k
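
The difference between the two index computations can be checked outside Triton. The sketch below reimplements both formulas as ordinary Python functions (the function names are hypothetical, and the stride values assume block_size=16, num_heads=4, head_size=64) and compares them for each layout.

```python
def flat_index(block_idx, block_offset, tile_pos, block_stride, page_stride):
    """Original formula: tile_pos is added directly to the base offset."""
    return block_idx * block_stride + block_offset * page_stride + tile_pos


def strided_index(block_idx, block_offset, tile_pos,
                  block_stride, page_stride, head_stride, head_size):
    """Proposed formula: split tile_pos into (head, dim) and use head_stride."""
    cur_head = tile_pos // head_size
    cur_dim = tile_pos % head_size
    return (block_idx * block_stride
            + block_offset * page_stride
            + cur_head * head_stride
            + cur_dim)


# Element strides for block_size=16, num_heads=4, head_size=64 (assumed sizes).
NHD = dict(block_stride=4096, page_stride=256, head_stride=64)
HND = dict(block_stride=4096, page_stride=64, head_stride=1024)

if __name__ == "__main__":
    head_size = 64
    tile_pos = 2 * head_size + 3  # head 2, dim 3
    for name, s in (("NHD", NHD), ("HND", HND)):
        flat = flat_index(2, 3, tile_pos, s["block_stride"], s["page_stride"])
        strided = strided_index(2, 3, tile_pos, head_size=head_size, **s)
        print(f"{name}: flat={flat} strided={strided} match={flat == strided}")
```

The two formulas agree only when head_stride equals head_size, that is, for NHD. Under HND the flat formula writes to a slot that belongs to a different head and token.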

3. Optionally add an assertion to enforce the layout
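
One possible shape for such an assertion is sketched below. This is a hypothetical helper, not an existing vLLM function: it assumes head_dim is the innermost contiguous dimension and that the caller can supply the element strides of the head and token dimensions within a block.

```python
def check_kv_layout(requested_layout, head_stride, token_stride, head_size):
    """Assert that a KV block's element strides match the requested layout.

    HND ([num_heads, block_size, head_dim]) implies token_stride == head_size.
    NHD ([block_size, num_heads, head_dim]) implies head_stride == head_size.
    """
    if requested_layout == "HND":
        assert token_stride == head_size, (
            f"HND requested but token stride is {token_stride}, "
            f"expected {head_size}"
        )
    elif requested_layout == "NHD":
        assert head_stride == head_size, (
            f"NHD requested but head stride is {head_stride}, "
            f"expected {head_size}"
        )
    else:
        raise ValueError(f"Unknown cache layout: {requested_layout}")


if __name__ == "__main__":
    # Strides for block_size=16, num_heads=4, head_size=64 (assumed sizes).
    check_kv_layout("HND", head_stride=1024, token_stride=64, head_size=64)
    check_kv_layout("NHD", head_stride=64, token_stride=256, head_size=64)
    print("layout checks passed")
```

A check like this would fail fast when a backend keeps NHD strides while HND was requested, instead of letting corrupted KV data reach the decoder.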

Results After Fix

Qwen3-0.6B, --attention-backend TRITON_ATTN, PTP=1 DTP=2, gsm8k 5-shot:

Check                       Value                                                                             Status
Quick sanity                "Paris. The capital of France is also the capital of the Republic of France..."   PASS
Accuracy (strict-match)     0.4155 (expected ~0.41)                                                           PASS
External KV cache hit rate  100.00%                                                                           PASS
Transfer errors in logs     0                                                                                 PASS


extent analysis

Fix Plan

To make TRITON_ATTN's get_kv_cache_stride_order honor VLLM_KV_CACHE_LAYOUT=HND instead of always returning the NHD (identity) stride order, make two changes:

  • Modify vllm/v1/attention/backends/triton_attn.py to read get_kv_cache_layout() and return the correct stride order.
  • Update vllm/v1/attention/ops/triton_reshape_and_cache_flash.py to use stride-based indexing instead of flat indexing.

Code Changes

1. vllm/v1/attention/backends/triton_attn.py

from vllm.v1.attention.backends.utils import get_kv_cache_layout

@staticmethod
def get_kv_cache_stride_order(
    include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
    cache_layout = get_kv_cache_layout()
    if cache_layout == "NHD" and include_num_layers_dimension:
        return (1, 0, 2, 3, 4, 5)
    elif cache_layout == "NHD":
        return (0, 1, 2, 3, 4)
    elif cache_layout == "HND" and include_num_layers_dimension:
        return (1, 2, 4, 0, 3, 5)
    elif cache_layout == "HND":
        return (0, 1, 3, 2, 4)
    else:
        raise ValueError(f"Unknown cache layout: {cache_layout}")

2. vllm/v1/attention/ops/triton_reshape_and_cache_flash.py

# Replace flat indexing with stride-based indexing
cur_head = tile_pos // head_size
cur_dim = tile_pos % head_size
tgt_idx_k = (
    block_idx * block_stride
    + block_offset * page_stride
    + cur_head * head_stride
    + cur_dim
)
tgt_idx_v = tgt_idx_k

Verification

To verify the fix, start the decoder with the following command, paired with a PTP=1 prefiller and a proxy routing to this decoder, as in the reproduction:

vllm serve Qwen/Qwen3-0.6B \
  --attention-backend TRITON_ATTN \
  --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}' \
  --tensor-parallel-size 2

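Then send the prompt "The capital of France is" and confirm that the completion is coherent rather than a repeated token. The issue also reports gsm8k 5-shot strict-match accuracy of 0.4155 (expected ~0.41) after the fix.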
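
The sanity check on the completion can be automated with a small heuristic. The helper below is a hypothetical sketch, not part of vLLM: it flags a completion when the same whitespace-separated token repeats many times in a row, as in the "is is is ..." output reported before the fix. The max_run threshold is an arbitrary assumption.

```python
def longest_token_run(text):
    """Length of the longest run of identical consecutive tokens."""
    longest = current = 0
    previous = None
    for token in text.split():
        current = current + 1 if token == previous else 1
        longest = max(longest, current)
        previous = token
    return longest


def looks_degenerate(text, max_run=5):
    """True if any token repeats more than max_run times consecutively."""
    return longest_token_run(text) > max_run


if __name__ == "__main__":
    before_fix = " ".join(["is"] * 30)
    after_fix = ("Paris. The capital of France is also the capital "
                 "of the Republic of France...")
    print("before fix degenerate:", looks_degenerate(before_fix))
    print("after fix degenerate:", looks_degenerate(after_fix))
```

This only catches the repeated-token failure mode shown in the issue. An accuracy run such as the gsm8k check remains the stronger signal.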

Extra Tips

  • Test the fix with other models and other P_TP/D_TP combinations to confirm it holds beyond the Qwen3-0.6B, PTP=1 DTP=2 case.
  • Consider adding an assertion that enforces the layout, so a backend that ignores VLLM_KV_CACHE_LAYOUT fails loudly instead of corrupting KV data.
  • If problems remain, search the existing issues and the documentation, or ask the documentation chatbot.
