vllm - [RFC]: Prefill Context Parallel for Qwen3.5 Hybrid Attention [3 comments, 3 participants]

GitHub stats: vllm-project/vllm#37995 (fetched 2026-04-08 01:22:05)
Comments: 3 · Participants: 3 · Timeline events: 13 · Reactions: 0
Timeline (top): subscribed ×6, commented ×3, mentioned ×3, labeled ×1


Summary

This RFC describes the implementation of Prefill Context Parallel (PCP) for Qwen3.5's hybrid attention architecture, which combines full attention and linear attention (GatedDeltaNet) layers. PCP reduces long-context prefill latency by:

  • Full attention layers: distributing sequence tokens across ranks using zigzag ring attention (up to a cp_size× speedup per layer)
  • Linear attention layers: splitting the batch dimension instead of the sequence, eliminating replicated computation and all_gather overhead
  • MoE layers: splitting the token dimension

Motivation

Long-context prefill is a major bottleneck for LLM inference:

  1. Quadratic attention complexity: Full attention is O(N^2) in sequence length, making 256K+ contexts extremely slow
  2. Memory bandwidth: large KV caches stress the memory subsystem

Tensor Parallelism (TP) helps by distributing model weights, but each GPU still processes the full sequence. Context Parallel (CP) addresses this by distributing the sequence itself across GPUs.

Challenge: Hybrid Attention

Qwen3.5 uses a hybrid architecture:

  • 15 full attention layers: Standard transformer attention
  • 45 linear attention layers: GatedDeltaNet (GDN) with recurrent state

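Linear attention has a key constraint: the recurrent state at any position depends on all previous tokens of the same sequence. Naively splitting the sequence across ranks therefore breaks correctness for linear attention layers, because a rank holding a later segment would not have the state accumulated over the earlier segments.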
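A minimal sketch of the failure mode, assuming a simplified gated linear recurrence (S_t = g_t * S_{t-1} + k_t v_t^T, o_t = q_t S_t) as a stand-in for GatedDeltaNet; the real GDN update uses a gated delta rule that is not reproduced here, and `gated_linear_recurrence` is a hypothetical helper. Splitting one sequence in half and starting the second half from a zero state changes its outputs; they only match once the second half receives the first half's final state, which forces the ranks to run one after another.

```python
import numpy as np

def gated_linear_recurrence(q, k, v, g, state=None):
    """Simplified gated linear attention over one sequence (stand-in, not GatedDeltaNet).

    q, k: [T, d_k]; v: [T, d_v]; g: [T] decay gates in (0, 1).
    Returns (outputs [T, d_v], final state [d_k, d_v]).
    """
    S = np.zeros((k.shape[1], v.shape[1])) if state is None else state.copy()
    out = np.empty((q.shape[0], v.shape[1]))
    for t in range(q.shape[0]):
        S = g[t] * S + np.outer(k[t], v[t])   # state carries every earlier token
        out[t] = q[t] @ S
    return out, S

rng = np.random.default_rng(0)
T, dk, dv = 16, 4, 3
q, k = rng.normal(size=(T, dk)), rng.normal(size=(T, dk))
v, g = rng.normal(size=(T, dv)), rng.uniform(0.5, 1.0, size=T)

full, _ = gated_linear_recurrence(q, k, v, g)

# Naive sequence split across 2 "ranks": rank 1 starts its half from a zero state.
h = T // 2
first, state_after_first = gated_linear_recurrence(q[:h], k[:h], v[:h], g[:h])
second_naive, _ = gated_linear_recurrence(q[h:], k[h:], v[h:], g[h:])
# Correct only if rank 1 receives rank 0's final state, i.e. it must wait for rank 0.
second_carried, _ = gated_linear_recurrence(q[h:], k[h:], v[h:], g[h:], state_after_first)

print(np.allclose(first, full[:h]))            # True
print(np.allclose(second_naive, full[h:]))     # False
print(np.allclose(second_carried, full[h:]))   # True
```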

Design

Architecture Overview

The complete picture showing data distribution and communication across all layer types:

[Image: architecture overview] https://github.com/user-attachments/assets/efd4ab72-099d-46b9-8aa5-7d6104a2dd5a

Token Distribution

PCP uses zigzag distribution to balance tokens across ranks:

Input sequence: [0, 1, 2, 3, 4, 5, 6, 7, 8, ...]

With cp_size = C:

Rank 0: [0, 2C-1, 2C, 4C-1, 4C, ...]  (positions where pos % 2C in {0, 2C-1})
Rank 1: [1, 2C-2, 2C+1, 4C-2, ...]  (positions where pos % 2C in {1, 2C-2})
...
Rank r: [r, 2C-1-r, 2C+r, 4C-1-r, ...]  (positions where pos % 2C in {r, 2C-1-r})

This interleaving gives each rank tokens from every part of the sequence. Under causal attention, position p attends to p+1 keys, so pairing an early position r with a late position 2C-1-r in every period of 2C tokens gives each rank the same amount of attention work (exactly equal when the sequence length is a multiple of 2C).

For cp_size = 2 specifically:

Rank 0: [0, 3, 4, 7, 8, ...]  (positions where pos % 4 in {0, 3})
Rank 1: [1, 2, 5, 6, 9, ...]  (positions where pos % 4 in {1, 2})
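A small sketch of the layout and its load-balancing effect, generalizing the cp_size = 2 example to a period of 2C tokens (an inference from that example, not a copy of vLLM's indexing code). `zigzag_positions` and `causal_cost` are hypothetical helpers; the cost simply counts query-key pairs under a causal mask.

```python
def zigzag_positions(seq_len, cp_size, rank):
    """Positions owned by `rank`: pos % (2 * cp_size) in {rank, 2 * cp_size - 1 - rank}."""
    period = 2 * cp_size
    return [p for p in range(seq_len) if p % period in (rank, period - 1 - rank)]

def causal_cost(positions):
    """Query-key pairs a rank computes under causal masking: position p sees p + 1 keys."""
    return sum(p + 1 for p in positions)

N, C = 8, 2
for r in range(C):
    print(r, zigzag_positions(N, C, r), causal_cost(zigzag_positions(N, C, r)))
# 0 [0, 3, 4, 7] 18
# 1 [1, 2, 5, 6] 18

# Contiguous split for comparison: rank 0 gets [0..3], rank 1 gets [4..7].
contiguous = [list(range(r * N // C, (r + 1) * N // C)) for r in range(C)]
print([causal_cost(p) for p in contiguous])  # [10, 26]
```

With a contiguous split the last rank does most of the causal work; the zigzag pairing evens it out.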

Layer-Type Handling

Full Attention Layers (Ring Attention)

Full attention layers use zigzag ring attention:

┌────────────────────────────────────────────────────────────────────┐
│                     Ring Attention (Full Attn)                     │
│                                                                    │
│  Rank r                                       Rank (r+1) mod C     │
│  ┌────────────┐                               ┌────────────┐       │
│  │ Q (local)  │                               │ Q (local)  │       │
│  │ K (local)  │ ─────────── P2P send ───────► │ K (recv)   │       │
│  │ V (local)  │ ─────────── P2P send ───────► │ V (recv)   │       │
│  └────────────┘                               └────────────┘       │
│        │                                            │              │
│        ▼                                            ▼              │
│  Attention(Q_local, K_all, V_all)        Attention(Q_local, ...)   │
│        │                                            │              │
│        └──────────────── LSE merge ─────────────────┘              │
└────────────────────────────────────────────────────────────────────┘
where C = cp_size (number of PCP ranks)

Each rank:

  1. Computes local Q @ local K/V
  2. Sends K/V to next rank via P2P
  3. Receives K/V from previous rank
  4. Accumulates attention outputs with LSE merge
  5. Repeats steps 2 to 4 until it has processed the K/V blocks of all cp_size ranks (cp_size attention rounds in total)
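The five steps above can be simulated in one process to check that the LSE merge reproduces ordinary causal attention. This NumPy sketch assumes a single head, 2D [seq_len, head_dim] tensors, and the per-token zigzag layout generalized from the cp_size = 2 example; `ring_attention_sim`, `block_attention` and `lse_merge` are illustrative names, not vLLM functions, and each P2P hop is modeled by indexing the block that rank (r - step) mod C originally held.

```python
import numpy as np

def zigzag_positions(seq_len, cp_size, rank):
    """Per-token zigzag layout: pos % (2 * cp_size) in {rank, 2 * cp_size - 1 - rank}."""
    period = 2 * cp_size
    pos = np.arange(seq_len)
    return pos[(pos % period == rank) | (pos % period == period - 1 - rank)]

def block_attention(q, k, v, q_pos, k_pos):
    """Causal attention of local queries against one K/V block -> (output, LSE)."""
    scores = (q @ k.T) / np.sqrt(q.shape[-1])
    scores = np.where(q_pos[:, None] >= k_pos[None, :], scores, -np.inf)
    m = scores.max(axis=1, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)        # rows whose keys are all masked
    p = np.exp(scores - m)                       # masked entries become 0
    s = p.sum(axis=1, keepdims=True)
    with np.errstate(divide="ignore"):
        lse = (m + np.log(s))[:, 0]              # -inf for fully masked rows
    return (p @ v) / np.maximum(s, 1.0), lse     # s >= 1 whenever any key is visible

def lse_merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial softmax-attention results over disjoint key sets."""
    lse = np.logaddexp(lse_a, lse_b)
    return (np.exp(lse_a - lse)[:, None] * out_a
            + np.exp(lse_b - lse)[:, None] * out_b), lse

def ring_attention_sim(q, k, v, cp_size):
    """Simulate cp_size ranks in one process; returns outputs in original order."""
    pos = [zigzag_positions(q.shape[0], cp_size, r) for r in range(cp_size)]
    out = np.empty_like(v)
    for r in range(cp_size):
        qp = pos[r]
        acc, lse = None, None
        for step in range(cp_size):
            kp = pos[(r - step) % cp_size]       # K/V block held after `step` P2P hops
            o, l = block_attention(q[qp], k[kp], v[kp], qp, kp)
            acc, lse = (o, l) if acc is None else lse_merge(acc, lse, o, l)
        out[qp] = acc
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(16, 8)) for _ in range(3))
ref, _ = block_attention(q, k, v, np.arange(16), np.arange(16))  # single-GPU causal attention
for C in (1, 2, 4):
    print(C, np.allclose(ring_attention_sim(q, k, v, C), ref))
```

Blocks that lie entirely in a query's future produce LSE = -inf for that query and therefore receive zero weight in the merge.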

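Benefit: Each rank holds N/cp_size query tokens but, through the ring exchange, still attends to all N keys (subject to the causal mask). Computation per rank is therefore (N/cp_size) × N = N²/cp_size (about half that with causal masking, which the zigzag layout divides evenly across ranks). This is a cp_size× speedup rather than cp_size², because only the queries are partitioned: every query still needs the K/V of all earlier positions.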

Linear Attention Layers (Batch Split)

Because each item's recurrent state depends on that item's entire history, linear attention needs the full sequence of every item. We therefore split along the batch dimension instead of the sequence:

  • Linear attention has no cross-item dependency: each item's recurrent state only depends on that item's history
  • Each rank holds B/C items (B = batch size), each with its full sequence of length S

Key insight: Split by batch, not sequence. This gives each rank the full sequence it needs without any all_gather communication.
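A minimal sketch of the batch split, using a simplified gated linear recurrence as a hypothetical stand-in for GatedDeltaNet and assuming B is divisible by cp_size. Each simulated rank runs the recurrence over the full sequences of its own B/C items, and concatenating the per-rank outputs reproduces the unsplit result exactly because no state crosses item boundaries.

```python
import numpy as np

def gated_linear_recurrence(q, k, v, g):
    """Simplified gated linear attention for one item (stand-in for GatedDeltaNet):
    S_t = g_t * S_{t-1} + k_t v_t^T,  o_t = q_t S_t."""
    S = np.zeros((k.shape[1], v.shape[1]))
    out = np.empty((q.shape[0], v.shape[1]))
    for t in range(q.shape[0]):
        S = g[t] * S + np.outer(k[t], v[t])
        out[t] = q[t] @ S
    return out

def linear_attn_layer(q, k, v, g):
    """Apply the recurrence independently to every item of a batch [B, S, ...]."""
    return np.stack([gated_linear_recurrence(*item) for item in zip(q, k, v, g)])

rng = np.random.default_rng(0)
B, S, dk, dv, C = 4, 12, 4, 3, 2           # batch, sequence length, dims, cp_size
q, k = rng.normal(size=(B, S, dk)), rng.normal(size=(B, S, dk))
v, g = rng.normal(size=(B, S, dv)), rng.uniform(0.5, 1.0, size=(B, S))

reference = linear_attn_layer(q, k, v, g)  # what a single GPU would compute

# Batch split: rank r owns items r*B/C .. (r+1)*B/C - 1, each with its FULL sequence.
per_rank = B // C
rank_outputs = [
    linear_attn_layer(*(x[r * per_rank:(r + 1) * per_rank] for x in (q, k, v, g)))
    for r in range(C)
]
print(np.allclose(np.concatenate(rank_outputs), reference))  # True
```

The granularity of this split is one item, so a batch with fewer items than ranks would leave some ranks idle in these layers.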

Work Plan

Each item below is scoped for a separate pull request.

| PR | Item | Description |
|----|------|-------------|
| 1 | Ring attention on full-attention layers | Integrate zigzag_ring_prefill_kv into full-attention layers; handle zigzag token distribution, P2P K/V exchange, and LSE merge (flash_attn.py, ring.py) |
| 2 | Reshuffle calls | Implement all-to-all activation reshuffle between seq-split and batch-split at layer group boundaries |
| 3 | Decode path | Implement per-step LSE-weighted merge across PCP ranks for full-attn layers during decode (flash_attn.py) |
| 4 | Chunked prefill | Support chunked prefill with PCP, including handling cached KV attention and ring attention for the current chunk |
| 5 | Support prefix cache | Enable prefix cache when PCP is enabled |


Fix Plan

To implement Prefill Context Parallel (PCP) for Qwen3.5's hybrid attention architecture, follow these steps:

Step 1: Ring Attention on Full-Attention Layers

  • Integrate zigzag_ring_prefill_kv into full-attention layers.
  • Handle zigzag token distribution, P2P K/V exchange, and LSE merge in flash_attn.py and ring.py.
  • Example code:
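import torch

def zigzag_indices(n, cp_size, rank):
    # Zigzag token distribution: rank r owns positions p with
    # p % (2 * cp_size) in {r, 2 * cp_size - 1 - r}
    pos = torch.arange(n)
    phase = pos % (2 * cp_size)
    return pos[(phase == rank) | (phase == 2 * cp_size - 1 - rank)]

def zigzag_ring_prefill_kv(q, k, v, cp_size):
    # Illustrative single-process simulation of the PCP ranks (not vLLM's actual
    # signature): q, k, v are [seq_len, head_dim]; the ring is simulated in a loop.
    n, scale = q.shape[0], q.shape[-1] ** -0.5
    idx = [zigzag_indices(n, cp_size, r) for r in range(cp_size)]
    out = torch.empty_like(v)
    for r in range(cp_size):
        q_pos, acc, lse = idx[r], None, None
        for step in range(cp_size):
            # P2P K/V exchange: after `step` hops rank r holds rank (r - step)'s K/V
            k_pos = idx[(r - step) % cp_size]
            s = (q[q_pos] @ k[k_pos].T) * scale
            s = s.masked_fill(q_pos[:, None] < k_pos[None, :], float("-inf"))  # causal
            blk_lse = torch.logsumexp(s, dim=-1)  # -inf for fully masked rows
            blk_out = torch.nan_to_num(torch.softmax(s, dim=-1)) @ v[k_pos]
            if acc is None:
                acc, lse = blk_out, blk_lse
            else:
                # LSE merge of the running result with this block's partial result
                new_lse = torch.logaddexp(lse, blk_lse)
                acc = (acc * torch.exp(lse - new_lse)[:, None]
                       + blk_out * torch.exp(blk_lse - new_lse)[:, None])
                lse = new_lse
        out[q_pos] = acc
    return out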

Step 2: Reshuffle Calls

  • Implement all-to-all activation reshuffle between seq-split and batch-split at layer group boundaries.
  • Example code:
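import torch
import torch.distributed as dist

def reshuffle_activations(activations, seq_index, cp_size, group=None):
    # Seq-split -> batch-split all-to-all (illustrative sketch, not vLLM's API).
    # activations: [B, S/C, H], this rank's sequence shard of every item.
    # seq_index: [C, S/C] LongTensor of the global positions held by each rank
    # (e.g. the zigzag layout). Assumes the process group has exactly cp_size ranks.
    # Returns [B/C, S, H]: the full sequences of this rank's B/C items.
    b, s_local, h = activations.shape
    assert b % cp_size == 0, "batch size must be divisible by cp_size"
    # Item chunk j (items j*B/C .. (j+1)*B/C - 1) is sent to rank j.
    send = activations.reshape(cp_size, b // cp_size, s_local, h).contiguous()
    recv = torch.empty_like(send)
    dist.all_to_all_single(recv, send, group=group)  # argument order: (output, input)
    # recv[j] holds rank j's positions for our items; scatter them back into order.
    out = activations.new_empty(b // cp_size, cp_size * s_local, h)
    out[:, seq_index.reshape(-1)] = recv.permute(1, 0, 2, 3).reshape(b // cp_size, -1, h)
    return out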
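The reshuffle's data movement can be checked without a process group. This single-process NumPy sketch simulates the all-to-all with a hypothetical `all_to_all` helper (rank r's send[j] lands in rank j's recv[r], matching how torch.distributed.all_to_all_single splits dim 0 evenly across ranks), assumes the cp_size = 2 zigzag layout and B divisible by cp_size, and checks that seq-split to batch-split and back round-trips exactly. All function names are illustrative.

```python
import numpy as np

def all_to_all(send_per_rank):
    """Simulated all-to-all: rank r's send[j] arrives at rank j as recv[r]."""
    C = len(send_per_rank)
    return [[send_per_rank[src][dst] for src in range(C)] for dst in range(C)]

def seq_to_batch(shards, seq_index):
    """shards[r]: [B, S/C, H] holding positions seq_index[r] of every item.
    Returns per-rank [B/C, S, H]: full sequences for that rank's items."""
    C = len(shards)
    send = [np.split(x, C, axis=0) for x in shards]      # item chunk j goes to rank j
    recv = all_to_all(send)
    order = np.concatenate(seq_index)
    out = []
    for parts in recv:                                   # parts[j]: [B/C, S/C, H] from rank j
        full = np.empty((parts[0].shape[0], len(order), parts[0].shape[2]))
        full[:, order] = np.concatenate(parts, axis=1)   # scatter back to global positions
        out.append(full)
    return out

def batch_to_seq(items, seq_index):
    """Inverse reshuffle: per-rank [B/C, S, H] -> per-rank [B, S/C, H]."""
    C = len(items)
    send = [[x[:, seq_index[j]] for j in range(C)] for x in items]  # positions of rank j
    recv = all_to_all(send)
    return [np.concatenate(parts, axis=0) for parts in recv]

B, S, H, C = 4, 8, 3, 2
x = np.random.default_rng(0).normal(size=(B, S, H))           # reference activations
seq_index = [np.array([0, 3, 4, 7]), np.array([1, 2, 5, 6])]  # zigzag layout, cp_size = 2
shards = [x[:, idx] for idx in seq_index]                     # seq-split (full-attn layers)

items = seq_to_batch(shards, seq_index)                       # batch-split (linear-attn layers)
print(all(np.array_equal(items[r], x[r * B // C:(r + 1) * B // C]) for r in range(C)))  # True
back = batch_to_seq(items, seq_index)
print(all(np.array_equal(back[r], shards[r]) for r in range(C)))  # True
```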

Step 3: Decode Path

  • Implement per-step LSE-weighted merge across PCP ranks for full-attn layers during decode.
  • Example code:
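import torch

def lse_weighted_merge(attention_outputs, lses):
    # Partial outputs [C, heads, head_dim] and LSEs [C, heads] from the C PCP ranks.
    out, lse = torch.stack(attention_outputs), torch.stack(lses)
    weights = torch.softmax(lse, dim=0)  # exp(lse_i - logsumexp_j lse_j)
    return (weights.unsqueeze(-1) * out).sum(dim=0)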
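Why the LSE-weighted merge is exact for decode: if each PCP rank attends the new token's query to its local shard of the KV cache and returns (output, LSE), weighting each partial output by exp(lse_i - logsumexp_j lse_j) recovers attention over the whole cache. This NumPy sketch assumes the cache keeps the zigzag layout from prefill and uses illustrative function names.

```python
import numpy as np

def partial_attention(q, k, v):
    """One decode query [H, D] against a local KV shard [T_local, H, D] -> (out [H, D], lse [H])."""
    scores = np.einsum("hd,thd->ht", q, k) / np.sqrt(q.shape[-1])  # [H, T_local]
    m = scores.max(axis=1, keepdims=True)
    p = np.exp(scores - m)
    s = p.sum(axis=1, keepdims=True)
    out = np.einsum("ht,thd->hd", p / s, v)
    return out, (m + np.log(s))[:, 0]

def lse_weighted_merge(outs, lses):
    """Merge per-rank partials: weight_i = exp(lse_i - logsumexp_j lse_j)."""
    lse = np.stack(lses)                                   # [C, H]
    w = np.exp(lse - np.logaddexp.reduce(lse, axis=0))     # softmax over ranks
    return np.einsum("ch,chd->hd", w, np.stack(outs))

rng = np.random.default_rng(0)
T, H, D, C = 16, 2, 8, 2                                    # cached tokens, heads, head_dim, cp_size
q = rng.normal(size=(H, D))
k, v = rng.normal(size=(T, H, D)), rng.normal(size=(T, H, D))
shards = [np.array([p for p in range(T) if p % (2 * C) in (r, 2 * C - 1 - r)]) for r in range(C)]

partials = [partial_attention(q, k[idx], v[idx]) for idx in shards]  # computed on each rank
merged = lse_weighted_merge([o for o, _ in partials], [l for _, l in partials])
reference, _ = partial_attention(q, k, v)                           # whole cache on one GPU
print(np.allclose(merged, reference))  # True
```

Only one output vector and one scalar LSE per head travel between ranks at each decode step, not the KV shards themselves.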

Step 4: Chunked Prefill

  • Support chunked prefill with PCP, including handling cached KV attention and ring attention for current chunk.
  • Example code:
def chunked_prefill(q, k, v, cp_size, chunk_size):
    # Chunked prefill
    chunks = torch.split(q, chunk_size)  # chunk_size tokens per piece; torch.chunk(q, n) would make n pieces
    attention_outputs = []
    for chunk in chunks:
        attention_output = zigzag_ring_prefill_kv(chunk
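Chunked prefill splits each chunk's attention into two terms over disjoint keys: the KV cache from earlier chunks (fully visible) and the current chunk (causal). Both produce an LSE, so they can be merged the same way as ring-attention blocks. In this single-process NumPy sketch the current-chunk term is computed directly; under PCP it would instead come from zigzag ring attention across ranks. `attend`, `merge` and `chunked_prefill_sim` are illustrative names.

```python
import numpy as np

def attend(q, k, v, q_pos, k_pos):
    """Causal softmax attention of q against (k, v) -> (out, lse).
    Assumes every query sees at least one key, which holds for both calls below."""
    scores = (q @ k.T) / np.sqrt(q.shape[-1])
    scores = np.where(q_pos[:, None] >= k_pos[None, :], scores, -np.inf)
    m = scores.max(axis=1, keepdims=True)
    p = np.exp(scores - m)
    s = p.sum(axis=1, keepdims=True)
    return (p @ v) / s, (m + np.log(s))[:, 0]

def merge(out_a, lse_a, out_b, lse_b):
    """LSE-weighted combination of two partial results over disjoint keys."""
    lse = np.logaddexp(lse_a, lse_b)
    return np.exp(lse_a - lse)[:, None] * out_a + np.exp(lse_b - lse)[:, None] * out_b

def chunked_prefill_sim(q, k, v, chunk_size):
    """Prefill chunk by chunk: cached-KV attention and current-chunk attention
    are computed separately, then LSE-merged."""
    outs = []
    for start in range(0, q.shape[0], chunk_size):
        cur = np.arange(start, min(start + chunk_size, q.shape[0]))
        o, l = attend(q[cur], k[cur], v[cur], cur, cur)   # current chunk (ring attention under PCP)
        if start > 0:
            cached = np.arange(start)                     # KV cache from earlier chunks
            o_c, l_c = attend(q[cur], k[cached], v[cached], cur, cached)
            o = merge(o, l, o_c, l_c)
        outs.append(o)
    return np.concatenate(outs)

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(20, 8)) for _ in range(3))
ref, _ = attend(q, k, v, np.arange(20), np.arange(20))   # unchunked causal attention
print(np.allclose(chunked_prefill_sim(q, k, v, chunk_size=6), ref))  # True
```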
