pytorch - ✅(Solved) Fix MPS: scaled_dot_product_attention produces incorrect results for large batch × sequence length combinations [1 pull request, 4 comments, 2 participants]

GitHub stats
pytorch/pytorch#179352 · Fetched 2026-04-08 02:43:32
Comments: 4 · Participants: 2 · Timeline: 31 · Reactions: 0
Timeline (top): mentioned ×9, subscribed ×9, labeled ×6, commented ×4

PR fix notes

PR #179592: Fix MPS SDPA correctness when score matrix exceeds 2^32 elements

Description (problem / solution / changelog)

Summary

Fixes #179352. Apple's MPSGraph corrupts SDPA outputs when the score matrix B*H*Nq*Nkv exceeds 2^32 total elements: at the canary shape (B=1, H=8, Nq=16384, Nkv=65536, D=64), output diverges from CPU with cosine similarity ~0.44 and max relative error ~1.2. The threshold is independent of dtype, of is_causal, and of whether an explicit attn_mask is supplied — the bug is in the matmul/softmax sequence underneath all four MPSGraph codepaths.

The bug is reachable via sdpa_general_mps in aten/src/ATen/native/mps/operations/Attention.mm, which builds an MPSGraph of the form matrixMultiplicationWithPrimaryTensor (Q·Kᵀ) → optional causal/additive mask → softMaxWithTensor → matrixMultiplicationWithPrimaryTensor (·V). The threshold shape (element-counted, not byte-counted, exactly at 2^32) is consistent with a 32-bit flat element index somewhere inside Apple's matmul or softmax kernel, but the bug has not been confirmed against Apple source. PyTorch's own custom Metal kernels in aten/src/ATen/native/mps/kernels/Attention.metal are not involved for shapes that hit this bug: _scaled_dot_product_attention_math_mps only dispatches to those when query_seq_len <= 8, and the affected shapes have much larger Q.
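
To make the 32-bit index hypothesis concrete, the sketch below shows what truncating a flat element index to 32 bits would do at the canary shape. The row-major layout and the helper names `flat_index` and `wrap32` are assumptions made for illustration only; nothing here reflects confirmed MPSGraph internals.

```python
# Hypothetical illustration of a 32-bit flat index overflowing.
# Assumes a row-major (B, H, Nq, Nkv) score matrix; the real layout
# inside MPSGraph is unknown.
MASK32 = 0xFFFFFFFF


def flat_index(b, h, i, j, H, Nq, Nkv):
    """Row-major offset of score element (b, h, i, j)."""
    return ((b * H + h) * Nq + i) * Nkv + j


def wrap32(index):
    """Value the index would take if stored in an unsigned 32-bit integer."""
    return index & MASK32


# Canary shape from the PR description: B=1, H=8, Nq=16384, Nkv=65536.
H, Nq, Nkv = 8, 16384, 65536
total = 1 * H * Nq * Nkv
first_of_head_4 = flat_index(0, 4, 0, 0, H, Nq, Nkv)
last = flat_index(0, H - 1, Nq - 1, Nkv - 1, H, Nq, Nkv)

print(f"total elements      : {total} (2^33 = {2**33})")
print(f"head 4, first index : {first_of_head_4} -> wraps to {wrap32(first_of_head_4)}")
print(f"last index          : {last} -> wraps to {wrap32(last)}")
```

Under these assumptions, the first element of head 4 would alias the very first element of the matrix, so the upper half of the heads would read or overwrite data belonging to the lower half. That would fit a sharp correctness cliff at 2^32 elements, but it remains a guess.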

Note that Attention.mm also defines a third path, sdpa_full_attention_mps, that wires up the custom attention Metal kernel and would not go through MPSGraph at all — but it is dead code on main, never called from the dispatcher (the test_fast_full_attention_* tests are correspondingly all skipped with "Full attention fast kernel not implemented yet"). Enabling it was therefore not a viable fix for this PR.

Since SDPA's softmax is computed independently for each query row (it normalizes over the KV axis), splitting Q into chunks and concatenating the outputs is mathematically exact. sdpa_general_mps now factors its existing graph build/run into sdpa_general_mps_impl and wraps it: when the score matrix is small enough (≤ 2^32 elements) it calls the impl once, preserving the original codepath byte-for-byte; otherwise it loops over Q slices of size max(1, 2^32 / (B*H*Nkv)) and stitches the outputs back together. For is_causal=True, the wrapper synthesizes a per-chunk additive causal mask from the chunk's Q offset and passes it to the impl with is_causal=false (the in-graph causal path assumes the chunk starts at row 0). User-supplied attn_mask is narrowed along the Q axis. A TORCH_WARN_ONCE guards the unrealistic edge case where B*H*Nkv alone exceeds 2^32.
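
The wrapper logic described above can be approximated in user-level Python. This is a minimal sketch, not the Objective-C++ code from the PR: the function name `chunked_sdpa` and the `max_score_elems` parameter are hypothetical, inputs are assumed to be 4-D (B, H, N, D), and the causal mask is built as a boolean mask using the top-left alignment documented for `is_causal=True`, whereas the PR builds an additive mask.

```python
import torch
import torch.nn.functional as F


def chunked_sdpa(q, k, v, attn_mask=None, is_causal=False, max_score_elems=2**32):
    """Run SDPA in query-axis chunks so each score matrix stays within the limit."""
    B, H, Nq, _ = q.shape
    Nkv = k.shape[-2]
    if B * H * Nq * Nkv <= max_score_elems:
        # Small enough: single call, unchanged behaviour.
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=is_causal
        )

    # Same chunk-size rule as in the PR description.
    chunk = max(1, max_score_elems // (B * H * Nkv))
    kv_idx = torch.arange(Nkv, device=q.device)
    outs = []
    for q0 in range(0, Nq, chunk):
        q1 = min(q0 + chunk, Nq)
        mask = attn_mask
        if is_causal:
            # Global row q0 + i may attend to keys 0..q0 + i (True = attend).
            rows = torch.arange(q0, q1, device=q.device)
            mask = kv_idx[None, :] <= rows[:, None]
        elif mask is not None and mask.dim() >= 2 and mask.size(-2) != 1:
            # Narrow a user-supplied mask along the query axis.
            mask = mask[..., q0:q1, :]
        outs.append(
            F.scaled_dot_product_attention(q[:, :, q0:q1], k, v, attn_mask=mask)
        )
    return torch.cat(outs, dim=2)
```

Passing a small `max_score_elems` on CPU is a cheap way to check that the chunked result matches a single call before relying on it at sizes that only fit on a large machine.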

Suggested review order: read the new sdpa_general_mps_impl first (it's the unchanged old body), then the new sdpa_general_mps wrapper, then the test.

Test plan

  • `python test/test_mps.py -k attention -v` — 96 tests, all pass (skipped=24, expected failures=10, identical to baseline)
  • `python test/test_mps.py -k test_sdpa_large_score_matrix_179352 -v` — 4/4 pass on a 64 GB Mac (3 variants + meta)
  • Empirical sweep across dtypes (fp32/fp16/bf16), `is_causal=True/False`, and bool/additive masks — all formerly failing shapes now match CPU to ~1e-5 max relative error
  • Lint clean (`spin lint`)
  • Skipped on CI by design (`total_memory >= 64 GB` + `IS_CI` gates)

Authored with the assistance of Claude (Anthropic).

Changed files

  • aten/src/ATen/native/mps/operations/Attention.mm (modified, +86/-12)
  • test/test_mps.py (modified, +38/-0)


🐛 Describe the bug

torch.nn.functional.scaled_dot_product_attention on the MPS backend produces incorrect results when the combination of batch size and sequence length is large. Small batches or short sequences work correctly. The bug appears to be a threshold effect — likely a buffer size or indexing overflow in the MPS SDPA kernel.

Observed behavior:

  • B=1, seq=10240×20480: correct (cosine similarity ~1.0 vs CPU)
  • B=4, seq=10240×20480: wrong (cosine similarity ~0.78 vs CPU)
  • B=16, seq=10240×20480: wrong (cosine similarity ~0.49 vs CPU)
  • B=16, seq=4096×8192: correct
  • B=16, seq=5120×10240: wrong

Expected behavior: MPS SDPA should produce the same results as CPU SDPA regardless of batch/sequence size.

Impact: This causes severe quality degradation in video diffusion models (e.g., ToonCrafter, DynamiCrafter) where the VAE decoder uses cross-attention with spatially-flattened frame tokens (10240 = 80×128 spatial resolution, 20480 = 2× reference frames). The model produces visibly degraded output on MPS while CPU produces correct results.

Minimal reproduction

import torch
import torch.nn.functional as F

assert torch.backends.mps.is_available(), "MPS not available"
torch.manual_seed(42)

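def compare_sdpa(B, H, N_q, N_kv, D, label=""):
    # Random inputs are created on the CPU so both backends see identical data.
    q = torch.randn(B, H, N_q, D)
    k = torch.randn(B, H, N_kv, D)
    v = torch.randn(B, H, N_kv, D)

    # CPU result is the reference; the MPS result is copied back for comparison.
    out_cpu = F.scaled_dot_product_attention(q, k, v)
    out_mps = F.scaled_dot_product_attention(
        q.to("mps"), k.to("mps"), v.to("mps")
    ).cpu()

    # Largest absolute difference and cosine similarity of the flattened outputs.
    max_diff = (out_cpu - out_mps).abs().max().item()
    cos = F.cosine_similarity(
        out_cpu.flatten().unsqueeze(0),
        out_mps.flatten().unsqueeze(0),
    ).item()

    # A cosine similarity of 0.999 or lower is reported as a mismatch.
    status = "PASS" if cos > 0.999 else "FAIL"
    print(f"[{status}] {label:<35} max_diff={max_diff:>8.4f}  cos_sim={cos:.6f}")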

print(f"PyTorch version: {torch.__version__}")
print()

# Small sequences — PASS
compare_sdpa(16, 8, 1024, 2048, 64, "B=16, seq=1024x2048")
compare_sdpa(16, 8, 4096, 8192, 64, "B=16, seq=4096x8192")

# Large sequences — FAIL
compare_sdpa(16, 8, 5120, 10240, 64, "B=16, seq=5120x10240")
compare_sdpa(16, 8, 10240, 20480, 64, "B=16, seq=10240x20480")

# Batch size sweep at large seq — shows threshold
print()
compare_sdpa(1, 8, 10240, 20480, 64, "B=1,  seq=10240x20480")
compare_sdpa(2, 8, 10240, 20480, 64, "B=2,  seq=10240x20480")
compare_sdpa(4, 8, 10240, 20480, 64, "B=4,  seq=10240x20480")
compare_sdpa(8, 8, 10240, 20480, 64, "B=8,  seq=10240x20480")
compare_sdpa(16, 8, 10240, 20480, 64, "B=16, seq=10240x20480")

Output

PyTorch version: 2.11.0

[PASS] B=16, seq=1024x2048              max_diff=  0.0000  cos_sim=1.001652
[PASS] B=16, seq=4096x8192              max_diff=  0.0000  cos_sim=1.012778
[FAIL] B=16, seq=5120x10240             max_diff=  0.1821  cos_sim=0.785274
[FAIL] B=16, seq=10240x20480            max_diff=  0.1822  cos_sim=0.488831

[PASS] B=1,  seq=10240x20480            max_diff=  0.0000  cos_sim=1.000803
[PASS] B=2,  seq=10240x20480            max_diff=  0.0000  cos_sim=1.002282
[FAIL] B=4,  seq=10240x20480            max_diff=  0.1011  cos_sim=0.782390
[FAIL] B=8,  seq=10240x20480            max_diff=  0.1545  cos_sim=0.580331
[FAIL] B=16, seq=10240x20480            max_diff=  0.1007  cos_sim=0.485537
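
The PASS/FAIL pattern above can be cross-checked against the 2^32-element threshold reported in PR #179592. The sketch below is plain arithmetic on the shapes from the reproduction (H=8 throughout); the helper names `score_elements` and `predict` are made up for this illustration.

```python
# Compare the observed PASS/FAIL results with the rule
# "fails when B * H * Nq * Nkv exceeds 2**32".
LIMIT = 2**32
H = 8  # number of heads used by every call in the reproduction

# (B, Nq, Nkv, status observed in the issue output)
cases = [
    (16, 1024, 2048, "PASS"),
    (16, 4096, 8192, "PASS"),
    (16, 5120, 10240, "FAIL"),
    (16, 10240, 20480, "FAIL"),
    (1, 10240, 20480, "PASS"),
    (2, 10240, 20480, "PASS"),
    (4, 10240, 20480, "FAIL"),
    (8, 10240, 20480, "FAIL"),
]


def score_elements(B, H, Nq, Nkv):
    """Number of elements in the attention score matrix."""
    return B * H * Nq * Nkv


def predict(B, H, Nq, Nkv, limit=LIMIT):
    """Status the threshold rule predicts for a given shape."""
    return "FAIL" if score_elements(B, H, Nq, Nkv) > limit else "PASS"


for B, Nq, Nkv, observed in cases:
    n = score_elements(B, H, Nq, Nkv)
    print(f"B={B:<2} seq={Nq}x{Nkv:<6} elements={n:>12} "
          f"ratio={n / LIMIT:5.2f} predicted={predict(B, H, Nq, Nkv)} observed={observed}")
```

Every shape in the reproduction agrees with the rule. The B=16, 4096×8192 case is the boundary: its score matrix has exactly 2^32 elements and still passes, which matches the "exceeds 2^32" wording in the PR.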

Versions

PyTorch version: 2.11.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.7.5 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: version 4.2.1
Libc version: N/A

Python version: 3.11.15 (main, Mar 11 2026, 17:14:47) [Clang 20.1.8 ] (64-bit runtime)
Python platform: macOS-15.7.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU: Apple M3 Ultra

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.11.0
[pip3] torchvision==0.26.0

cc @malfet @aditvenk @kulinseth @DenisVieriu97 @jhavukainen @drisspg @liangel-02 @howardzhang-cv

extent analysis

TL;DR

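On the MPS backend, torch.nn.functional.scaled_dot_product_attention returns wrong results once the score matrix B × H × N_q × N_kv exceeds 2^32 elements. Until the fix from PR #179592 is available in your build, the issue can be mitigated by reducing the batch size or sequence length so the score matrix stays at or below that size.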

Guidance

  • Per PR #179592, the corruption starts when the score matrix B*H*Nq*Nkv exceeds 2^32 elements. This is consistent with a 32-bit flat element index inside Apple's MPSGraph matmul or softmax, but it has not been confirmed against Apple source.
  • To verify the issue, run the provided minimal reproduction with different batch sizes and sequence lengths and check that failures begin once B*H*Nq*Nkv passes 2^32.
  • As a temporary workaround, reduce the batch size or sequence length so the score matrix stays at or below 2^32 elements.
  • The permanent fix is PR #179592, which splits the query axis into chunks inside sdpa_general_mps whenever the score matrix would exceed 2^32 elements.

Example

The minimal reproduction in the issue body doubles as the example: it compares MPS output against CPU output across a sweep of batch sizes and sequence lengths.

Notes

The issue is specific to the MPS backend and was reported against PyTorch 2.11.0 on macOS 15.7.5 (Apple M3 Ultra); other versions or backends may not be affected. The root cause inside MPSGraph has not been confirmed against Apple source, and PR #179592 works around it on the PyTorch side.

Recommendation

Apply workaround: until a build that includes the fix from PR #179592 is available, reduce the batch size or sequence length so that B*H*Nq*Nkv stays at or below 2^32 elements. This is the most straightforward way to avoid the threshold effect.
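
One way to reduce the effective batch size without changing the model is to split the call along the batch axis. The following is a rough sketch under stated assumptions: `sdpa_in_batch_groups` and `max_score_elems` are hypothetical names, inputs are 4-D (B, H, N, D), and `attn_mask` is not handled. It only helps when a single sample stays within the limit, as in the issue's 10240×20480 case with H=8.

```python
import torch
import torch.nn.functional as F


def sdpa_in_batch_groups(q, k, v, is_causal=False, max_score_elems=2**32):
    """Run SDPA on groups of samples so each call stays within the element limit."""
    B, H, Nq, _ = q.shape
    Nkv = k.shape[-2]
    per_sample = H * Nq * Nkv
    if per_sample > max_score_elems:
        # Batch splitting cannot help here; the query axis must be split instead.
        raise ValueError("a single sample already exceeds the score-matrix limit")

    group = max_score_elems // per_sample  # samples per call, at least 1
    if group >= B:
        return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

    outs = [
        F.scaled_dot_product_attention(
            q[b:b + group], k[b:b + group], v[b:b + group], is_causal=is_causal
        )
        for b in range(0, B, group)
    ]
    return torch.cat(outs, dim=0)
```

Attention is computed independently per sample, so concatenating the per-group outputs along the batch axis gives the same result as one large call, up to floating-point rounding.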
