pytorch - ✅(Solved) Fix [distributed] CUDNN SDPA backend + CP stride failure [1 pull request, 6 comments, 6 participants]

GitHub stats
pytorch/pytorch#176915 (fetched 2026-04-08 00:23:52)
Comments: 6 · Participants: 6 · Timeline events: 53 · Reactions: 0
Timeline (top): mentioned ×16, subscribed ×16, commented ×6, labeled ×3

Error Message

RuntimeError: same_strides(o, dO_) INTERNAL ASSERT FAILED at "/data/users/pianpwk/pytorch/aten/src/ATen/native/cudnn/MHA.cpp":1628, please report a bug to PyTorch. cuDNN SDPA expected grad_output.strides() == output.strides(), the previous step probably failed to materialize a grad_output with matching strides...

The assert fires during `loss.backward()`, when `_templated_ring_attention_backward()` in `torch/distributed/tensor/experimental/_context_parallel/_attention.py` (line 574) calls the cuDNN SDPA backward op.

Root Cause

Root cause: In `_templated_ring_attention_backward()`, the merged `out` produced by `_SDPAMerger` has non-standard strides (a result of its arithmetic ops plus `torch.chunk`/`torch.cat`), while `grad_out` arrives from autograd with standard contiguous strides. The cuDNN SDPA backward asserts that these strides match. The fix is to call `.contiguous()` on `out` and `dout` before passing them to the cuDNN backward op, as `_attention.py` already does for `logsumexp` (line 568).
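The stride condition behind the assert can be observed without GPUs or cuDNN. This CPU-only sketch rests on one assumption: it manufactures a non-standard layout by transposing a (bs, seq, nheads, head_dim) buffer, which is not necessarily how `_SDPAMerger` ends up with its strides. It shows that two tensors of the same shape can disagree on strides, that elementwise arithmetic keeps the non-standard layout, and that `.contiguous()` brings both back to the standard layout.

```python
import torch

bs, nheads, seq, head_dim = 2, 4, 8, 16

# Assumption for illustration: allocate a (bs, seq, nheads, head_dim) buffer
# and transpose it, giving logical shape (bs, nheads, seq, head_dim) with
# non-standard strides.
out = torch.randn(bs, seq, nheads, head_dim).transpose(1, 2)

# Elementwise arithmetic on a dense, non-overlapping input keeps its layout,
# so the result is still non-contiguous.
out = out * 0.5 + 1.0

# A freshly allocated tensor of the same shape has standard contiguous strides,
# like the grad_out that autograd hands to the backward op.
grad_out = torch.randn(bs, nheads, seq, head_dim)

print(out.shape == grad_out.shape)        # True
print(out.stride(), grad_out.stride())    # (512, 16, 64, 1) (512, 128, 16, 1)
print(out.stride() == grad_out.stride())  # False

# Materializing both as contiguous makes the strides agree.
out_c, grad_out_c = out.contiguous(), grad_out.contiguous()
print(out_c.stride() == grad_out_c.stride())  # True
```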

Fix Action

Fix / Workaround

Proposed fix (from the root-cause analysis): in `_templated_ring_attention_backward()`, call `.contiguous()` on `out` and `dout` before passing them to the cuDNN SDPA backward op so that their strides match, mirroring the existing `.contiguous()` call on `logsumexp` (line 568 of `_attention.py`).

The linked PR #2480 (Qwen3 TP + CP with local_map region) takes a different route, using a `local_map` region for CP in torchtitan's Qwen3 parallelization.
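Here is a minimal sketch of the proposed `.contiguous()` fix (calling it on `out` and `dout` before the cuDNN backward op), written as a standalone helper. `match_cudnn_backward_strides` is a hypothetical name, not a PyTorch function, and the transposed `out` is only an assumed stand-in for the merged output. In PyTorch the change would live inside `_templated_ring_attention_backward()` and operate on CUDA tensors.

```python
import torch


def match_cudnn_backward_strides(out, grad_out):
    """Hypothetical helper (not a PyTorch API): return `out` and `grad_out`
    with matching strides, as cuDNN SDPA backward requires.

    .contiguous() returns the tensor unchanged when it is already contiguous,
    and otherwise makes a packed copy, so both results end up with the
    standard strides for their (shared) shape.
    """
    return out.contiguous(), grad_out.contiguous()


# Usage: a non-contiguous `out` (transposed layout, assumed for the demo)
# paired with a contiguous `grad_out` of the same shape (2, 4, 16, 32).
out = torch.randn(2, 16, 4, 32).transpose(1, 2)
grad_out = torch.randn(2, 4, 16, 32)

out_fixed, grad_out_fixed = match_cudnn_backward_strides(out, grad_out)
print(out.stride() == grad_out.stride())                # False
print(out_fixed.stride() == grad_out_fixed.stride())    # True
print(grad_out_fixed.data_ptr() == grad_out.data_ptr()) # True: no copy made
```

Because `.contiguous()` is a no-op for tensors that are already contiguous, the call only costs a copy when `out` actually arrives with a non-standard layout.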

PR fix notes

PR #2480: Qwen3 TP + CP with local_map region

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #2480

Fixes https://github.com/pytorch/torchtitan/issues/2446 by using local_map region for CP

Changed files

  • tests/integration_tests/models.py (modified, +13/-0)
  • torchtitan/models/qwen3/parallelize.py (modified, +56/-0)


🐛 Describe the bug

With torchtitan Qwen3 CP:

    torchrun --nproc_per_node=2 -m torchtitan.train --module qwen3 --config qwen3_debugmodel --parallelism.context_parallel_degree 2

    File "/data/users/pianpwk/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 367, in wrapper
      return f(*args, **kwargs)
    File "/data/users/pianpwk/torchtitan/torchtitan/trainer.py", line 819, in train
      self.train_step(data_iterator)
    File "/data/users/pianpwk/torchtitan/torchtitan/trainer.py", line 729, in train_step
      loss = self.forward_backward_step(
    File "/data/users/pianpwk/torchtitan/torchtitan/trainer.py", line 687, in forward_backward_step
      loss.backward()
    File "/data/users/pianpwk/pytorch/torch/_tensor.py", line 631, in backward
      torch.autograd.backward(
    File "/data/users/pianpwk/pytorch/torch/autograd/__init__.py", line 379, in backward
      _engine_run_backward(
    File "/data/users/pianpwk/pytorch/torch/autograd/graph.py", line 877, in _engine_run_backward
      return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    File "/data/users/pianpwk/pytorch/torch/distributed/tensor/experimental/_context_parallel/_attention.py", line 926, in _sdpa_handler
      local_results = call_maps[op_call](
    File "/data/users/pianpwk/pytorch/torch/distributed/tensor/experimental/_context_parallel/_attention.py", line 873, in _scaled_dot_product_ring_cudnn_attention_backward
      return _templated_ring_attention_backward(
    File "/data/users/pianpwk/pytorch/torch/distributed/tensor/experimental/_context_parallel/_attention.py", line 574, in _templated_ring_attention_backward
      grad_query_, grad_key_, grad_value_, *rest = op(
    File "/data/users/pianpwk/pytorch/torch/_ops.py", line 871, in __call__
      return self._op(*args, **kwargs)
  RuntimeError: same_strides(o, dO_) INTERNAL ASSERT FAILED at "/data/users/pianpwk/pytorch/aten/src/ATen/native/cudnn/MHA.cpp":1628, please report a bug to PyTorch. cuDNN SDPA expected grad_output.strides() == output.strides(), the previous step probably failed to materialize a grad_output with matching strides...
  
============================================================

Standalone repro in PyTorch (no torchtitan):

"""
Minimal repro: cuDNN SDPA backward stride mismatch with Context Parallel.

This is a pre-existing PyTorch bug where cuDNN SDPA backward requires
grad_output.strides() == output.strides(), but CP ring attention backward
produces mismatched strides due to _SDPAMerger arithmetic + torch.chunk/cat.

Run with 2 GPUs (must be H100 or other GPU that supports cuDNN SDPA):

    torchrun --nproc_per_node=2 repro_cudnn_cp_stride.py

Expected error:
    RuntimeError: same_strides(o, dO_) INTERNAL ASSERT FAILED at
    .../aten/src/ATen/native/cudnn/MHA.cpp:1628

Root cause:
    In _templated_ring_attention_backward(), the merged `out` from
    _SDPAMerger has non-standard strides (due to arithmetic ops + chunk/cat),
    while `grad_out` has standard contiguous strides from autograd.
    cuDNN backward asserts these must match. The fix is to add .contiguous()
    for out_ and dout before passing to the cuDNN backward op, similar to
    what's already done for logsumexp on line 568 of _attention.py.
"""

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental._context_parallel import (
    _context_parallel_shard,
    _ContextParallel,
    _enable_context_parallel_dispatcher,
    _HeadTailLoadBalancer,
)
from torch.distributed.tensor.parallel import parallelize_module
from torch.nn.attention import sdpa_kernel, SDPBackend


class SDPAModule(nn.Module):
    """Thin wrapper so we can parallelize_module with _ContextParallel."""

    def forward(self, q, k, v, **kwargs):
        return F.scaled_dot_product_attention(q, k, v, **kwargs)


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    device_mesh = init_device_mesh("cuda", (world_size,))
    device = f"cuda:{rank}"
    dtype = torch.bfloat16

    # Attention dimensions
    bs = 2
    nheads = 8
    seq_len = 1024  # per-rank seq_len after sharding
    head_dim = 32
    seq_dim = 2  # (bs, nheads, seq, head_dim)

    # Create identical tensors on all ranks
    torch.manual_seed(42)
    full_seq = seq_len * world_size
    q = torch.randn(bs, nheads, full_seq, head_dim, device=device, dtype=dtype)
    k = torch.randn(bs, nheads, full_seq, head_dim, device=device, dtype=dtype)
    v = torch.randn(bs, nheads, full_seq, head_dim, device=device, dtype=dtype)

    with torch.no_grad():
        dist.broadcast(q, src=0)
        dist.broadcast(k, src=0)
        dist.broadcast(v, src=0)

    # Set up CP via parallelize_module (same pattern as torchtitan)
    cp_plan = _ContextParallel(
        seq_dim=seq_dim,
        attention_type=_ContextParallel.AttentionType.SDPA,
    )
    attention = SDPAModule()
    attention = parallelize_module(attention, device_mesh, cp_plan)

    # Shard Q, K, V along sequence dim with load balancing
    load_balancer = _HeadTailLoadBalancer(full_seq, world_size, device)
    cp_q, cp_k, cp_v = _context_parallel_shard(
        device_mesh, (q, k, v), (seq_dim,) * 3, load_balancer=load_balancer
    )
    cp_q.requires_grad_(True)
    cp_k.requires_grad_(True)
    cp_v.requires_grad_(True)

    _enable_context_parallel_dispatcher()

    # Force cuDNN backend — this triggers the stride mismatch on backward
    if rank == 0:
        print("Running CP ring attention with cuDNN SDPA backend...")
        print(f"  world_size={world_size}, bs={bs}, nheads={nheads}, "
              f"seq_len={full_seq} (per-rank: {seq_len}), head_dim={head_dim}")

    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
        out = attention(cp_q, cp_k, cp_v, is_causal=True)
        # Backward triggers the stride mismatch assertion in cuDNN
        out.sum().backward()

    if rank == 0:
        print("SUCCESS — no stride mismatch (bug may be fixed!)")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Versions

Collecting environment information...
PyTorch version: 2.12.0a0+git9bd00c2
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-14)
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.34

Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.4.3-0_fbk15_hardened_2630_gf27365f948db-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: 
GPU models and configuration: 
GPU 0: NVIDIA PG509-210
GPU 1: NVIDIA PG509-210
GPU 2: NVIDIA PG509-210
GPU 3: NVIDIA PG509-210
GPU 4: NVIDIA PG509-210
GPU 5: NVIDIA PG509-210
GPU 6: NVIDIA PG509-210
GPU 7: NVIDIA PG509-210

Nvidia driver version: 550.90.07
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.9.1.0
/usr/lib64/libcudnn_adv.so.9.1.0
/usr/lib64/libcudnn_cnn.so.9.1.0
/usr/lib64/libcudnn_engines_precompiled.so.9.1.0
/usr/lib64/libcudnn_engines_runtime_compiled.so.9.1.0
/usr/lib64/libcudnn_graph.so.9.1.0
/usr/lib64/libcudnn_heuristic.so.9.1.0
/usr/lib64/libcudnn_ops.so.9.1.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             168
On-line CPU(s) list:                0-167
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8339HC CPU @ 1.80GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 21
Socket(s):                          4
Stepping:                           11
BogoMIPS:                           3591.76
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 arat vnmi umip pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization:                     VT-x
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          5.3 MiB (168 instances)
L1i cache:                          5.3 MiB (168 instances)
L2 cache:                           336 MiB (84 instances)
L3 cache:                           64 MiB (4 instances)
NUMA node(s):                       4
NUMA node0 CPU(s):                  0-41
NUMA node1 CPU(s):                  42-83
NUMA node2 CPU(s):                  84-125
NUMA node3 CPU(s):                  126-167
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Vulnerable
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:           Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Vulnerable
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] bert_pytorch==0.0.1a4
[pip3] botorch==0.14.0
[pip3] executorch==0.8.0a0+a27dd42
[pip3] flake8==7.3.0
[pip3] flake8-bugbear==24.12.12
[pip3] flake8-comprehensions==3.16.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==2024.24.12
[pip3] flake8-pyi==25.5.0
[pip3] flake8_simplify==0.22.0
[pip3] gpytorch==1.14
[pip3] intel-cmplr-lib-ur==2025.1.1
[pip3] intel-openmp==2025.1.1
[pip3] mkl==2025.2.0
[pip3] mkl-include==2025.1.0
[pip3] mkl-static==2025.1.0
[pip3] mypy==1.16.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.28.9
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnx==1.16.1
[pip3] onnx-ir==0.1.2
[pip3] onnxruntime==1.22.0
[pip3] onnxscript==0.3.0
[pip3] optimum-executorch==0.0.0.dev0
[pip3] optree==0.17.0
[pip3] pytorch-lightning==2.5.1.post0
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] pytorch_tokenizers==0.1.0
[pip3] pytorch-triton==3.4.0+gitf7888497
[pip3] tbb==2022.1.0
[pip3] tbb-devel==2022.1.0
[pip3] tcmlib==1.3.0
[pip3] torch==2.12.0a0+git9bd00c2
[pip3] torch_geometric==2.4.0
[pip3] torch-mlir==20250607.491
[pip3] torchao==0.15.0+git01374eb58
[pip3] torchaudio==2.7.0
[pip3] torchdata==0.11.0
[pip3] torchmetrics==1.0.3
[pip3] torchmultimodal-nightly==2024.4.1
[pip3] torchrl==0.7.2
[pip3] torchsr==1.0.4
[pip3] torchtune==0.6.1
[pip3] torchvision==0.22.0
[pip3] torchx==0.7.0
[pip3] triton==3.6.0+git9844da95
[pip3] tritonbench==0.0.1
[pip3] umf==0.10.0
[conda] bert-pytorch              0.0.1a4                   dev_0    <develop>
[conda] botorch                   0.14.0                   pypi_0    pypi
[conda] executorch                0.8.0a0+a27dd42          pypi_0    pypi
[conda] gpytorch                  1.14                     pypi_0    pypi
[conda] intel-cmplr-lib-ur        2025.1.1                 pypi_0    pypi
[conda] intel-openmp              2025.1.1                 pypi_0    pypi
[conda] magma-cuda124             2.6.1                         1    pytorch
[conda] mkl                       2025.2.0                 pypi_0    pypi
[conda] mkl-include               2025.1.0                 pypi_0    pypi
[conda] mkl-static                2025.1.0                 pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.6.4.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.6.80                  pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.6.77                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.6.77                  pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.10.2.21                pypi_0    pypi
[conda] nvidia-cufft-cu12         11.3.0.4                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.7.77                pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.7.1.2                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.5.4.2                 pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.7.1                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.28.9                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.6.85                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.6.77                  pypi_0    pypi
[conda] optimum-executorch        0.0.0.dev0               pypi_0    pypi
[conda] optree                    0.17.0                   pypi_0    pypi
[conda] pytorch-lightning         2.5.1.post0              pypi_0    pypi
[conda] pytorch-sphinx-theme      0.0.24                   pypi_0    pypi
[conda] pytorch-tokenizers        0.1.0                    pypi_0    pypi
[conda] pytorch-triton            3.4.0+gitf7888497          pypi_0    pypi
[conda] tbb                       2022.1.0                 pypi_0    pypi
[conda] tbb-devel                 2022.1.0                 pypi_0    pypi
[conda] tcmlib                    1.3.0                    pypi_0    pypi
[conda] torch                     2.9.0a0+gitc157cf6          pypi_0    pypi
[conda] torch-geometric           2.4.0                    pypi_0    pypi
[conda] torch-mlir                20250607.491             pypi_0    pypi
[conda] torchao                   0.15.0+git01374eb58           dev_0    <develop>
[conda] torchaudio                2.6.0a0+d60ce09          pypi_0    pypi
[conda] torchdata                 0.11.0                   pypi_0    pypi
[conda] torchdiffeq               0.2.5                     dev_0    <develop>
[conda] torchfix                  0.4.0                    pypi_0    pypi
[conda] torchmetrics              1.0.3                    pypi_0    pypi
[conda] torchmultimodal-nightly   2024.4.1                 pypi_0    pypi
[conda] torchrl                   0.7.2                    pypi_0    pypi
[conda] torchsr                   1.0.4                    pypi_0    pypi
[conda] torchtune                 0.6.1                    pypi_0    pypi
[conda] torchvision               0.25.0a0+e3b5d3a           dev_0    <develop>
[conda] torchx                    0.7.0                    pypi_0    pypi
[conda] triton                    3.6.0+git9844da95          pypi_0    pypi
[conda] tritonbench               0.0.1                    pypi_0    pypi
[conda] umf                       0.10.0                   pypi_0    pypi

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan

Fix Plan

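1. Modify the forward_backward_step method in trainer.py so that out_ is made contiguous with .contiguous() and dout is built from it, so both tensors reach the cuDNN backward op with identical strides. Note that the failing assertion compares the attention output and its gradient inside _templated_ring_attention_backward (see the traceback), so forcing contiguity at the trainer level may not change the strides seen there.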

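import torch

def forward_backward_step(self, data_iterator):
    # ...
    # forward_step / backward_step are placeholder helpers for this sketch
    out_ = self.forward_step(data_iterator)
    # Force a contiguous layout on the output
    out_ = out_.contiguous()
    # zeros_like preserves out_'s (now contiguous) strides, so dout matches
    # out_ without a separate .contiguous() call
    dout = torch.zeros_like(out_)
    loss = self.backward_step(out_, dout)
    # ...
    return loss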

2. Wrap the step in the sdpa_kernel context manager so the cuDNN SDPA backend is selected explicitly, keeping the contiguous tensors from step 1.

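import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def forward_backward_step(self, data_iterator):
    # ...
    # sdpa_kernel controls which kernel the forward pass selects; the backward
    # pass follows that choice
    with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
        out_ = self.forward_step(data_iterator)  # placeholder helper
        out_ = out_.contiguous()
        dout = torch.zeros_like(out_)  # inherits out_'s contiguous strides
        loss = self.backward_step(out_, dout)  # placeholder helper
    # ...
    return loss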

3. Verify that the fix works by running the script with the updated forward_backward_step method and checking that the error is resolved.

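import torch.distributed as dist
rank = dist.get_rank() if dist.is_initialized() else 0  # rank 0 if not distributed
if rank == 0:
    print("SUCCESS - no stride mismatch (bug may be fixed!)")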
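Every step above targets one invariant: cuDNN's SDPA backward asserts that grad_output.strides() == output.strides(). The following CPU-only sketch shows how a permuted attention output and a plain contiguous gradient end up with different strides, and two ways to reconcile them. The tensor shapes and the helper name match_strides are illustrative assumptions, not part of the fix plan or of PyTorch.

```python
import torch


def match_strides(grad: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return `grad` laid out with exactly `ref`'s strides."""
    if grad.stride() == ref.stride():
        return grad
    # Allocate storage with ref's size and strides, then copy grad's values into it.
    buf = torch.empty_strided(ref.size(), ref.stride(), dtype=grad.dtype, device=grad.device)
    return buf.copy_(grad)


# An attention output stored as (batch, seq, heads, dim) but viewed as
# (batch, heads, seq, dim) is a non-contiguous permutation.
attn_out = torch.randn(2, 16, 4, 8).transpose(1, 2)  # shape (2, 4, 16, 8)
# A gradient produced upstream is often a plain contiguous (2, 4, 16, 8) tensor.
grad_out = torch.randn(2, 4, 16, 8)

print(attn_out.stride())  # (512, 8, 32, 1)
print(grad_out.stride())  # (512, 128, 8, 1): differs, which same_strides(o, dO) rejects

# Option 1 (the fix plan's approach): make both sides contiguous.
a, g = attn_out.contiguous(), grad_out.contiguous()
print(a.stride() == g.stride())  # True

# Option 2: leave the output alone and copy the gradient into its layout.
g2 = match_strides(grad_out, attn_out)
print(g2.stride() == attn_out.stride())  # True
```

torch.zeros_like, torch.ones_like and torch.empty_like default to memory_format=torch.preserve_format, so a gradient built from a dense output with them already shares the output's strides.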

Verification

  1. Run the script with the updated forward_backward_step method.
  2. Check that the error is resolved and the script completes successfully.
  3. Verify that the outputs and gradients are numerically consistent, for example by comparing against a run that uses a different SDPA backend.

Extra Tips

  • Make sure to update the forward_backward_step method in all relevant places in the codebase.
  • Verify that the fix does not introduce any new issues or regressions.
  • Consider asserting that the attention output and its gradient have identical strides before the cuDNN backward op runs, since that is the invariant the failing check enforces.
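One way to add such an assertion without touching the attention code is a tensor hook: Tensor.register_hook runs a callback with the incoming gradient during backward. The sketch below is a debugging aid built on that API; check_grad_strides is a hypothetical helper, and where to register it in a real model is left as an assumption.

```python
import torch


def check_grad_strides(out: torch.Tensor, name: str = "attn_out") -> None:
    """Hypothetical debugging helper: during backward, fail fast if the gradient
    arriving at `out` has different strides than `out` itself."""
    expected = out.stride()

    def hook(grad: torch.Tensor) -> None:
        if grad.stride() != expected:
            raise RuntimeError(
                f"{name}: grad strides {grad.stride()} != output strides {expected}"
            )
        # Returning None leaves the gradient unchanged.

    out.register_hook(hook)


x = torch.randn(2, 16, 4, 8, requires_grad=True)
out = x.transpose(1, 2)  # non-contiguous view, strides (512, 8, 32, 1)
check_grad_strides(out)
out.backward(torch.ones_like(out))  # passes: ones_like keeps out's strides

y = x.transpose(1, 2)
check_grad_strides(y)
try:
    y.backward(torch.ones(y.shape))  # contiguous grad, strides (512, 128, 8, 1)
except RuntimeError as err:
    print("caught:", err)
```

Raising inside the hook surfaces the mismatch at the exact tensor where it first appears, rather than deep inside the cuDNN backward op.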
