pytorch - ✅(Solved) Fix [MPS] scaled_dot_product_attention returns wrong values for non-contiguous (permute-produced) q/k/v [2 pull requests, 2 comments, 2 participants]

GitHub stats
pytorch/pytorch#181133 (fetched 2026-04-23 07:22:29)
Comments: 2 · Participants: 2 · Timeline events: 25 · Reactions: 0 · Assignees: none listed
Timeline (top): mentioned ×8, subscribed ×8, labeled ×3, commented ×2

Fix Action

Workaround

Call `.contiguous()` on q/k/v after the permute and before calling SDPA.

PR fix notes

PR #888: [PRI-277] Fix v2.6 low accuracy on MPS

Description (problem / solution / changelog)

Summary

On MPS, the default v2.6 classifier/regressor produced near-random outputs (e.g. Iris multiclass accuracy 0.48 vs 0.98 on CPU). Root cause is a PyTorch MPS bug: F.scaled_dot_product_attention silently returns wrong values when given non-contiguous q/k/v produced by a permute. Every v2.6 attention block passes permute views straight into SDPA, so the bug fires on every forward.

v2.5 isn't affected because its default checkpoint loads via the older PerFeatureTransformer, which uses a different attention implementation.

Verified against both task types — the fix resolves the accuracy regression for classification and regression (details below).

Fix

Two-line guard in _batched_scaled_dot_product_attention that calls .contiguous() on the permuted q/k/v only when the device is MPS. CUDA and CPU paths are untouched.
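A minimal sketch of such a device-gated guard follows. The helper name `sdpa_with_mps_guard` is hypothetical and this is not the code from PR #888; it assumes q/k/v arrive already permuted to `(batch, heads, seq_len, head_dim)`. On CPU the guard is a no-op, so the usage lines run anywhere.

```python
import torch
import torch.nn.functional as F


def sdpa_with_mps_guard(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper: make q/k/v contiguous only on MPS.

    Works around pytorch/pytorch#181133, where MPS SDPA can return wrong
    values for permute-produced (non-contiguous) inputs. CPU/CUDA inputs
    are passed through unchanged, so they pay no extra copy.
    """
    if q.device.type == "mps":
        # .contiguous() returns the tensor itself if already contiguous,
        # otherwise a row-major copy.
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    return F.scaled_dot_product_attention(q, k, v)


# Usage on CPU (guard inactive): inputs are (B, S, H, D), permuted to (B, H, S, D).
B, S, H, D = 2, 4, 3, 64
x = torch.randn(B, S, H, D)
q = k = v = x.permute(0, 2, 1, 3)
out = sdpa_with_mps_guard(q, k, v)
print(out.shape)  # torch.Size([2, 3, 4, 64])
```

Gating on `q.device.type` keeps the extra copies off the CUDA and CPU paths, matching the PR's stated intent of leaving those paths untouched.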

Upstream PyTorch bug with minimal reproducer: pytorch/pytorch#181133. Once fixed upstream, we can drop this guard (gated on torch/macOS version).
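One way to express that version gate is sketched below. Everything here is hypothetical: `needs_mps_sdpa_guard` and `parse_torch_version` are illustrative names, and `fixed_in` is a placeholder because the first torch release containing the upstream fix is not stated here. With `fixed_in=None` the guard stays on for every MPS call. This sketch checks only the torch version; a macOS check could be layered on with `platform.mac_ver()`.

```python
from typing import Optional, Tuple

import torch


def parse_torch_version(version: str) -> Tuple[int, int]:
    """Return (major, minor) from strings like '2.11.0', '2.11.0+cpu', '2.12.0.dev20260101'."""
    major, minor = version.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def needs_mps_sdpa_guard(
    device_type: str,
    torch_version: str = torch.__version__,
    fixed_in: Optional[Tuple[int, int]] = None,
) -> bool:
    """Hypothetical gate for the .contiguous() workaround.

    fixed_in is an ASSUMED placeholder for the first torch (major, minor)
    release that includes the upstream fix; None means "unknown", so the
    guard is always applied on MPS.
    """
    if device_type != "mps":
        return False  # CPU/CUDA are not affected by this bug
    if fixed_in is None:
        return True
    return parse_torch_version(torch_version) < fixed_in


print(needs_mps_sdpa_guard("mps", "2.11.0"))  # True: no fixed release configured
```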

Resolves PRI-277.

Verification

Both examples were run on MPS (patched) and compared against CPU (TABPFN_EXCLUDE_DEVICES=mps):

  • Classification (examples/tabpfn_for_multiclass_classification.py) on MPS: accuracy 0.48 → 0.98, matching CPU.
  • Regression (examples/tabpfn_for_regression.py) on MPS: R² and all MAE values match CPU to 4+ decimal places (pre-fix: MSE ~10% worse on MPS, per the ticket).
  • tests/test_model/test_attention.py — all pass.

Test plan

  • CI green on macOS runner (if one exists)
  • Spot-check Iris / regression example on a reviewer's Mac

🤖 Generated with Claude Code

Changed files

  • changelog/888.fixed.md (added, +1/-0)
  • src/tabpfn/architectures/tabpfn_v2_6.py (modified, +18/-3)

PR #181151: [MPS] Fix SDPA wrong output for permuted q/k/v with B > 1

Description (problem / solution / changelog)

Fixes #181133

Changed files

  • aten/src/ATen/native/mps/kernels/Attention.metal (modified, +53/-37)
  • aten/src/ATen/native/mps/operations/Attention.mm (modified, +15/-14)
  • test/test_mps.py (modified, +22/-2)


🐛 Describe the bug

On MPS, F.scaled_dot_product_attention returns numerically incorrect results when given non-contiguous tensors produced by permute(0, 2, 1, 3). Inserting .contiguous() after the permute restores correct output.

Distinct from #179352, which requires total elements > 2³². This fires at tiny shapes (~5M elements).
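To see what "non-contiguous" means here: `permute` returns a view that reorders strides without moving any data, so the permuted tensor is no longer laid out in row-major order. The snippet below uses the reproducer's shape, runs on CPU, and needs no MPS device; it only illustrates the memory layout, not the bug itself.

```python
import torch

B, S, H, D = 84, 3, 3, 64
x = torch.randn(B, S, H, D)             # row-major (B, S, H, D)
print(x.stride())                       # (576, 192, 64, 1)

xp = x.permute(0, 2, 1, 3)              # view as (B, H, S, D), no data copied
print(xp.stride(), xp.is_contiguous())  # (576, 64, 192, 1) False
print(xp.data_ptr() == x.data_ptr())    # True: same storage

xc = xp.contiguous()                    # copy into row-major (B, H, S, D)
print(xc.stride(), xc.is_contiguous())  # (576, 192, 64, 1) True
print(torch.equal(xc, xp))              # True: same values, new layout
```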

Minimal reproduction

```python
import torch

assert torch.backends.mps.is_available()
torch.manual_seed(0)

# Inputs laid out as (batch, seq_len, heads, head_dim)
B, S, H, D = 84, 3, 3, 64
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
v = torch.randn(B, S, H, D)


def call(q, k, v, make_contiguous: bool):
    # Permute to the (batch, heads, seq_len, head_dim) layout SDPA expects;
    # permute() returns non-contiguous views.
    q = q.permute(0, 2, 1, 3)
    k = k.permute(0, 2, 1, 3)
    v = v.permute(0, 2, 1, 3)
    if make_contiguous:
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


# Compare MPS against the CPU reference, with and without .contiguous()
for mk in (False, True):
    o_cpu = call(q, k, v, mk)
    o_mps = call(q.to("mps"), k.to("mps"), v.to("mps"), mk).cpu()
    print(f"contiguous={mk}: max abs err = {(o_cpu - o_mps).abs().max().item():.4g}")
```

Output:

```
contiguous=False: max abs err = 3.618
contiguous=True:  max abs err = 7.749e-07
```

Bug regime (from a dimension sweep)

Triggers when all of:

  • `head_dim ≥ 64`
  • `seq_len ∈ [2, 8]` (works at `S=1`; works again at `S ≥ 16`)
  • `batch ≥ 2`

Smaller `head_dim` values, or sequence lengths outside this band, produce correct output with the same permute pattern.
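The regime above can be encoded as a predicate, together with a small sweep for checking it on your own machine. This is a sketch: `in_reported_regime` and `max_err_mps_vs_cpu` are hypothetical names, the sweep values (B=4, H=3, D in {32, 64}, S in {1, 2, 4, 8, 16}) are illustrative choices rather than the sweep used in the issue, and sequence lengths 9 to 15 are not characterized by the report. The sweep only runs when an MPS device is present, and no particular error values are claimed here.

```python
import torch
import torch.nn.functional as F


def in_reported_regime(batch: int, seq_len: int, head_dim: int) -> bool:
    """True if the shape matches the failing regime reported in #181133.

    Sequence lengths 9..15 were not characterized in the report.
    """
    return head_dim >= 64 and 2 <= seq_len <= 8 and batch >= 2


def to_bhsd(t: torch.Tensor) -> torch.Tensor:
    # (B, S, H, D) -> non-contiguous (B, H, S, D) view, as in the reproducer
    return t.permute(0, 2, 1, 3)


def max_err_mps_vs_cpu(B: int, S: int, H: int, D: int) -> float:
    """Max abs difference between CPU and MPS SDPA on permuted inputs."""
    q, k, v = (torch.randn(B, S, H, D) for _ in range(3))
    ref = F.scaled_dot_product_attention(to_bhsd(q), to_bhsd(k), to_bhsd(v))
    out = F.scaled_dot_product_attention(
        to_bhsd(q.to("mps")), to_bhsd(k.to("mps")), to_bhsd(v.to("mps"))
    ).cpu()
    return (ref - out).abs().max().item()


if torch.backends.mps.is_available():
    torch.manual_seed(0)
    for D in (32, 64):
        for S in (1, 2, 4, 8, 16):
            err = max_err_mps_vs_cpu(B=4, S=S, H=3, D=D)
            flag = in_reported_regime(4, S, D)
            print(f"D={D:3d} S={S:2d} in_regime={flag!s:5} err={err:.3g}")
```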

Workaround

Call `.contiguous()` on q/k/v after the permute and before calling SDPA.

Versions

  • `torch==2.11.0`
  • macOS 26.3.1 / arm64

Related

  • #179352 — large batch × seq case (different regime, same op/backend)
  • #179294 — MPS SDPA improvements umbrella

cc @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk @drisspg @liangel-02 @howardzhang-cv

extent analysis

TL;DR

Inserting .contiguous() after the permute operation can fix the numerically incorrect results issue with F.scaled_dot_product_attention on MPS.

Guidance

  • The issue is triggered by non-contiguous q/k/v produced by permute; making them contiguous with the .contiguous() method resolves it.
  • To verify the fix, compare the output of F.scaled_dot_product_attention with and without the .contiguous() call, as shown in the minimal reproduction code.
  • The bug has only been observed when head_dim ≥ 64, seq_len ∈ [2, 8], and batch ≥ 2. Outside that regime the permuted inputs already give correct results, so the workaround is only strictly needed there; applying it elsewhere does not change the output.
  • The issue is distinct from #179352, which requires total elements > 2³², and this fix may not apply to that case.

Example

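```python
import torch.nn.functional as F
# q, k, v: (batch, seq_len, heads, head_dim) tensors on the "mps" device
q = q.permute(0, 2, 1, 3).contiguous()
k = k.permute(0, 2, 1, 3).contiguous()
v = v.permute(0, 2, 1, 3).contiguous()
out = F.scaled_dot_product_attention(q, k, v)
```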

Notes

  • The issue is specific to the MPS backend and may not occur on other backends.
  • The fix may have performance implications, as making tensors contiguous can involve copying data.
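The copy cost can be checked directly: `.contiguous()` returns the existing storage for a tensor that is already contiguous, but allocates a new buffer and copies the data for a permuted view. The snippet uses the reproducer's shape with the default float32 dtype and runs on CPU; it shows the extra memory traffic per tensor, not a measured runtime impact, which will depend on the model.

```python
import torch

x = torch.randn(84, 3, 3, 64)  # contiguous float32 tensor

# Already contiguous: same storage, no copy.
print(x.contiguous().data_ptr() == x.data_ptr())  # True

# Permuted view: .contiguous() allocates new storage and copies the data.
xp = x.permute(0, 2, 1, 3)
xc = xp.contiguous()
print(xc.data_ptr() != xp.data_ptr())   # True: new buffer
print(xc.numel() * xc.element_size())   # 193536 bytes copied for this tensor
```

In an attention block this copy happens once each for q, k, and v per call.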

Recommendation

Apply workaround: Inserting .contiguous() after the permute operation can fix the issue, and it is a simple and targeted solution that does not require upgrading the torch version.

