pytorch - ✅(Solved) Fix [RFC] All-to-all permute for Ulysses Parallel [1 pull request, 3 comments, 3 participants]

GitHub stats
pytorch/pytorch#178066 · Fetched 2026-04-08 01:12:19
Comments: 3 · Participants: 3 · Timeline: 32 · Reactions: 0
Timeline (top): mentioned ×12, subscribed ×12, commented ×3, cross-referenced ×2

Fix Action

Fixed

PR fix notes

PR #178230: [SymmMem] Add all_to_all_permute for Ulysses-style exchange

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #178230
  • #177791

Resolves #178066.

Added torch.distributed._symmetric_memory.all_to_all_permute with explicit scatter_dim and gather_dim:

def all_to_all_permute(
    input: Tensor, out: Tensor, scatter_dim: int, gather_dim: int, *, group: str
)

scatter_dim (int): 0 or 1 — dimension along which the input is partitioned.

gather_dim (int): 0 or 1 — dimension along which received chunks are concatenated in out.

Supported combinations: (1, 0) and (0, 1).

The op avoids the tensor permutation that otherwise has to run before or after a traditional all-to-all op (which can only send/receive peer data on dim 0). It reads peers’ shards directly at strided offsets from the NCCL symmetric window and writes them in the specified output layout, using per-CTA LSA barriers for synchronization.

Accept 2-D inputs and equivalent row-major 3-D shapes where documented.
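
The shape relationship implied by scatter_dim and gather_dim can be sketched in plain Python. This is an illustration only: permute_out_shape is a hypothetical helper, not part of the PR, and it assumes that out keeps the input shape except that the scatter dimension shrinks by the group size and the gather dimension grows by it. The real op's validation rules may differ.

```python
def permute_out_shape(in_shape, scatter_dim, gather_dim, world_size):
    """Hypothetical helper: expected 2-D `out` shape for a given 2-D input shape.

    Assumes the input is split into `world_size` equal chunks along
    `scatter_dim` and that received chunks are concatenated along `gather_dim`.
    """
    if (scatter_dim, gather_dim) not in {(1, 0), (0, 1)}:
        raise ValueError("supported combinations are (1, 0) and (0, 1)")
    if len(in_shape) != 2:
        raise ValueError("this sketch only handles 2-D shapes")
    if in_shape[scatter_dim] % world_size != 0:
        raise ValueError("scatter dimension must divide evenly by world_size")
    out = list(in_shape)
    out[scatter_dim] //= world_size  # each rank keeps 1/P of this dimension
    out[gather_dim] *= world_size    # and receives P chunks along this one
    return tuple(out)


# Local [seq/P, hidden] -> [seq, hidden/P] with P = 8, then the inverse.
forward = permute_out_shape((8192, 4096), scatter_dim=1, gather_dim=0, world_size=8)
backward = permute_out_shape(forward, scatter_dim=0, gather_dim=1, world_size=8)
```

Under these assumptions the (1, 0) combination is the exchange before attention and (0, 1) is its inverse afterwards.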

Wire the CUDA translation unit into the build and extend distributed NCCL tests to cover the new API.

Changed files

  • build_variables.bzl (modified, +1/-0)
  • caffe2/CMakeLists.txt (modified, +1/-0)
  • test/distributed/test_nccl.py (modified, +101/-0)
  • torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp (modified, +2/-0)
  • torch/csrc/distributed/c10d/symm_mem/nccl_extension.cu (modified, +1/-0)
  • torch/csrc/distributed/c10d/symm_mem/nccl_extension.hpp (modified, +11/-0)
  • torch/csrc/distributed/c10d/symm_mem/ops/nccl_all_to_all_permute.cu (added, +365/-0)
  • torch/distributed/_symmetric_memory/__init__.py (modified, +49/-0)


🚀 The feature, motivation and pitch

What is Ulysses Parallel

Ulysses Parallel (UP) is an optimization for attention modules in Transformers that handle long sequences, commonly seen in multimodal or omni-modal training and inference. Before entering attention, each rank holds 1/P of the sequence but all of the heads. For attention to run in parallel over different heads, the data must be rearranged so that each rank holds the full sequence but only 1/P of the heads.

More details of UP can be found in this DeepSpeed blog.

UP is being used in:

  • vLLM-Omni
  • Hugging Face accelerate
  • Megatron-DeepSpeed
  • Arctic plugin to vLLM ("Shift Parallelism")

What is the communication pattern

It is a permutation + all-to-all (can be understood as a "global permutation").

On the input side, we need to cut on the hidden dimension (dim 1) and send chunk i, i.e. input[:, i, :], to destination i. On the receiving side, the chunk that arrives from source rank r becomes chunk r along the seq dimension (dim 0) of the output, i.e. out[r, :].
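
The pattern can be simulated in a single process to make the indexing concrete. The sketch below is not the distributed implementation: ulysses_exchange is a hypothetical helper that models each rank's local block as a nested list (rows are sequence positions, columns are hidden entries) and labels every element with its (global sequence index, hidden index).

```python
def ulysses_exchange(inputs):
    """Simulate the exchange for P ranks in one process.

    inputs[r] is rank r's local block: seq_local rows x hidden columns.
    Returns outputs[d]: (P * seq_local) rows x (hidden // P) columns.
    """
    num_ranks = len(inputs)
    hidden = len(inputs[0][0])
    assert hidden % num_ranks == 0, "hidden size must divide evenly across ranks"
    width = hidden // num_ranks
    outputs = []
    for dst in range(num_ranks):
        rows = []
        for src in range(num_ranks):  # gather along dim 0, ordered by source rank
            for row in inputs[src]:
                # scatter along dim 1: destination dst receives column block dst
                rows.append(row[dst * width:(dst + 1) * width])
        outputs.append(rows)
    return outputs


# Two ranks, two sequence positions each, hidden size 4.
P, seq_local, hidden = 2, 2, 4
inputs = [
    [[(r * seq_local + s, h) for h in range(hidden)] for s in range(seq_local)]
    for r in range(P)
]
outputs = ulysses_exchange(inputs)
```

After the exchange, every simulated rank holds all four sequence positions but only half of the hidden columns, which is the layout attention needs.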

Example code pointer in vLLM-Omni:

def all_to_all_4D(
    input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False
) -> torch.tensor

Alternatives

dist.all_to_all_single can only cut input data on dim 0. As a result, today's implementation always involves a permutation of the input tensor first:

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
        input_t = input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous()

        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)

Additional context

Typical values:

  • bs=1 (long-sequence training typically uses bs=1)
  • num_total_heads=32, head_dim=128 (LLaMA-2 style)
  • p=8 (sequence parallel degree)
  • global_seq=65536 (64K), so local_seq=8192
  • dtype = bf16 (2 bytes)

Tensor size (one of Q/K/V): 8192 * 1 * 32 * 128 * 2 bytes = 64 MB


The permute operation

  • Reshape: [8192, 1, 32, 128] → [8192, 1, 8, 4, 128] (free, just metadata)
  • Permute (2,0,1,3,4): [8192, 1, 8, 4, 128] → [8, 8192, 1, 4, 128]
  • .contiguous() ← this is the actual cost: forces a full memory copy

The permute itself is free — .contiguous() is what triggers the copy. It reads 64 MB and writes 64 MB = 128 MB HBM traffic.
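
Why .contiguous() has to copy can be checked with stride arithmetic alone. The sketch below uses plain tuples instead of tensors; the helper names are illustrative, and the contiguity check follows the usual row-major rule that skips size-1 dimensions.

```python
def row_major_strides(shape):
    """Element strides of a densely packed row-major tensor."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(values, order):
    """Reorder a shape or stride tuple the way a dimension permute would."""
    return tuple(values[i] for i in order)


def is_contiguous(shape, strides):
    """True if the strides describe a dense row-major layout (size-1 dims ignored)."""
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True


shape = (8192, 1, 8, 4, 128)          # after the free reshape
strides = row_major_strides(shape)
order = (2, 0, 1, 3, 4)               # the permute from the list above
permuted_shape = permute(shape, order)
permuted_strides = permute(strides, order)
needs_copy = not is_contiguous(permuted_shape, permuted_strides)
```

The permute only reorders the shape and stride tuples. Because the reordered strides no longer describe a dense row-major layout, a contiguous result requires rewriting every element.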


Key hardware specs

  ┌──────────────────────────────────┬───────────┬───────────┐
  │                                  │ A100 SXM  │ H100 SXM5 │
  ├──────────────────────────────────┼───────────┼───────────┤
  │ HBM bandwidth                    │ 2.0 TB/s  │ 3.35 TB/s │
  ├──────────────────────────────────┼───────────┼───────────┤
  │ NVLink version                   │ 3.0       │ 4.0       │
  ├──────────────────────────────────┼───────────┼───────────┤
  │ NVLink bandwidth (per direction) │ ~300 GB/s │ ~450 GB/s │
  ├──────────────────────────────────┼───────────┼───────────┤
  │ HBM / NVLink ratio               │ 6.7×      │ 7.4×      │
  └──────────────────────────────────┴───────────┴───────────┘

Per Q/K/V tensor

Permute (128 MB HBM traffic, ~70% efficiency):

  ┌──────┬──────────────┬────────┐                                                                                                                                                                                                                                           
  │      │ Effective BW │  Time  │                          
  ├──────┼──────────────┼────────┤
  │ A100 │ ~1.40 TB/s   │ ~91 µs │
  ├──────┼──────────────┼────────┤
  │ H100 │ ~2.35 TB/s   │ ~54 µs │                                                                                                                                                                                                                                           
  └──────┴──────────────┴────────┘

All-to-all (64 MB over NVLink unidirectional):

  ┌──────┬───────────┬─────────┐                                                                                                                                                                                                                                             
  │      │ Bandwidth │  Time   │                            
  ├──────┼───────────┼─────────┤
  │ A100 │ ~300 GB/s │ ~213 µs │
  ├──────┼───────────┼─────────┤
  │ H100 │ ~450 GB/s │ ~142 µs │                                                                                                                                                                                                                                             
  └──────┴───────────┴─────────┘

Combined Q+K+V

  ┌──────┬────────────┬───────────────┬─────────────────────┐
  │      │ Permute ×3 │ All-to-all ×3 │ Permute as % of A2A │
  ├──────┼────────────┼───────────────┼─────────────────────┤
  │ A100 │ ~273 µs    │ ~639 µs       │ ~43%                │
  ├──────┼────────────┼───────────────┼─────────────────────┤
  │ H100 │ ~162 µs    │ ~426 µs       │ ~38%                │
  └──────┴────────────┴───────────────┴─────────────────────┘
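
The estimates in the tables above can be reproduced with a few lines of arithmetic. This sketch assumes decimal megabytes (1 MB = 1e6 bytes), the ~70% HBM efficiency stated above, and the full 64 MB crossing NVLink in one direction; the variable names are illustrative, and small rounding differences against the tables are expected.

```python
TENSOR_MB = 64          # one of Q/K/V, from the sizing above
HBM_EFFICIENCY = 0.70   # assumed achievable fraction of peak HBM bandwidth
SPECS = {
    "A100": {"hbm_tb_s": 2.0, "nvlink_gb_s": 300},
    "H100": {"hbm_tb_s": 3.35, "nvlink_gb_s": 450},
}


def estimate_us(gpu):
    """Return (permute_us, all_to_all_us) for one Q/K/V tensor."""
    spec = SPECS[gpu]
    tensor_bytes = TENSOR_MB * 1e6
    # The contiguous copy reads the tensor once and writes it once.
    permute_s = 2 * tensor_bytes / (spec["hbm_tb_s"] * 1e12 * HBM_EFFICIENCY)
    a2a_s = tensor_bytes / (spec["nvlink_gb_s"] * 1e9)
    return permute_s * 1e6, a2a_s * 1e6


for gpu in SPECS:
    permute_us, a2a_us = estimate_us(gpu)
    share = 100 * permute_us / a2a_us
    print(f"{gpu}: permute x3 = {3 * permute_us:.0f} us, "
          f"all-to-all x3 = {3 * a2a_us:.0f} us, permute/A2A = {share:.0f}%")
```

The ratio in the last column does not depend on the tensor size: it is 2 × NVLink bandwidth / (HBM bandwidth × efficiency), so the permute overhead stays at roughly 40% of the all-to-all time for any sequence length on these parts.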

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan

extent analysis

Fix Plan

To optimize the communication pattern in Ulysses Parallel, we need to reduce the overhead of the permute operation.

Here are the steps:

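  • Avoid unnecessary memory copies: the copy comes from calling .contiguous() on the transposed tensor. Swapping it for view() or reshape() does not help, because view() cannot express the permuted layout without moving data and reshape() falls back to a copy on a non-contiguous tensor. The copy goes away only when the communication op itself reads chunks at strided offsets, which is what all_to_all_permute in PR #178230 does.
  • Use a custom all-to-all where the fused op is unavailable: since dist.all_to_all_single can only cut input data on dim 0, a custom function can split on another dimension and exchange the chunks. Treat such a function as a correctness reference; it does not remove the copy unless the backend can send strided chunks directly.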

Example code:

import torch
import torch.distributed as dist

def custom_all_to_all(input_tensor, scatter_dim, gather_dim, group=None):
    """
    Baseline all-to-all that scatters along one dimension and gathers along another.

    Each chunk is still copied into a contiguous buffer before it is sent, so
    this is a correctness reference, not a fused permute + all-to-all.

    Args:
    input_tensor (torch.Tensor): The local input tensor.
    scatter_dim (int): The dimension to split across ranks (must divide evenly).
    gather_dim (int): The dimension along which received chunks are concatenated.
    group (optional): The communication group.

    Returns:
    torch.Tensor: The output tensor after all-to-all communication.
    """
    # Use the size of the given group, not the default group
    world_size = dist.get_world_size(group)

    # One chunk per destination rank; the collective needs contiguous buffers
    input_chunks = [c.contiguous() for c in input_tensor.chunk(world_size, dim=scatter_dim)]
    output_chunks = [torch.empty_like(c) for c in input_chunks]

    # A single collective avoids the deadlock risk of paired blocking send/recv:
    # input_chunks[i] goes to rank i, output_chunks[j] arrives from rank j
    dist.all_to_all(output_chunks, input_chunks, group=group)

    # Concatenate received chunks in source-rank order
    return torch.cat(output_chunks, dim=gather_dim)

# Example usage (requires an initialized process group, e.g. 8 ranks with NCCL):
# local [seq/P, bs, heads, head_dim] -> [seq, bs, heads/P, head_dim]
input_tensor = torch.randn(8192, 1, 32, 128, device="cuda")
output_tensor = custom_all_to_all(input_tensor, scatter_dim=2, gather_dim=0)

Verification

To verify that the fix worked, you can compare the performance of the custom all-to-all communication function with the original implementation using dist.all_to_all_single. You can measure the time it takes to complete the communication and check if the output tensors are correct.

Extra Tips

  • Make sure to handle any potential errors that may occur during the communication, such as timeouts or failed sends/receives.
  • Consider using a more efficient communication algorithm, such as a ring-based all-to-all, to reduce the overhead of the communication.
  • If you're using a GPU, make sure to use the correct CUDA stream to
