pytorch: [MPS] nn.MultiheadAttention is ~9x slower than direct F.scaled_dot_product_attention with bit-identical output (B=1)

pytorch/pytorch#181725 · Fetched 2026-04-29 06:11:12

🐛 Describe the bug

On MPS, nn.MultiheadAttention(..., batch_first=True) is ~9× slower than calling F.scaled_dot_product_attention directly with the same projection weights, even with need_weights=False (i.e. when MultiheadAttention already routes to SDPA internally).

The two paths produce bit-identical outputs (max|diff| = 0.00e+00), so this is purely overhead in the wrapper / functional layer, not a difference in the math. The same comparison on CPU shows ~1× (no gap).

This is not a duplicate of #100347. That issue is about need_weights=True falling off the SDPA path entirely. This issue is about the residual ~9× gap that remains on MPS after setting need_weights=False and reaching the SDPA branch.

Reproducer

import time, statistics, torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
device = "mps"
B, Lq, Lkv, D, H = 1, 32, 256, 512, 8
hd = D // H

mha = nn.MultiheadAttention(D, H, batch_first=True).to(device).eval()
q = torch.randn(B, Lq, D, device=device)
mem = torch.randn(B, Lkv, D, device=device)

with torch.no_grad():
    Wq, Wk, Wv = mha.in_proj_weight[:D], mha.in_proj_weight[D:2*D], mha.in_proj_weight[2*D:]
    bq, bk, bv = mha.in_proj_bias[:D], mha.in_proj_bias[D:2*D], mha.in_proj_bias[2*D:]
    Wo, bo = mha.out_proj.weight, mha.out_proj.bias

def via_mha():
    return mha(q, mem, mem, need_weights=False)[0]

def via_sdpa():
    Q = F.linear(q, Wq, bq).view(B, Lq, H, hd).transpose(1, 2)
    K = F.linear(mem, Wk, bk).view(B, Lkv, H, hd).transpose(1, 2)
    V = F.linear(mem, Wv, bv).view(B, Lkv, H, hd).transpose(1, 2)
    out = F.scaled_dot_product_attention(Q, K, V).transpose(1, 2).contiguous().view(B, Lq, D)
    return F.linear(out, Wo, bo)

with torch.no_grad():
    print(f"max|MHA - SDPA| = {(via_mha() - via_sdpa()).abs().max().item():.2e}")

def bench(fn, iters=300, warmup=30):
    with torch.no_grad():
        for _ in range(warmup): fn()
        torch.mps.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters): fn()
        torch.mps.synchronize()
    return (time.perf_counter() - t0) / iters * 1e6

t_mha = statistics.median([bench(via_mha) for _ in range(3)])
t_sdpa = statistics.median([bench(via_sdpa) for _ in range(3)])
print(f"nn.MultiheadAttention: {t_mha:7.1f} us")
print(f"F.scaled_dot_product_attention: {t_sdpa:7.1f} us")
print(f"Slowdown: {t_mha/t_sdpa:.1f}x")

Output (PyTorch nightly 2.13.0.dev20260427, M4 Pro, macOS 26.3)

max|MHA - SDPA| = 0.00e+00
nn.MultiheadAttention: 1199.1 us
F.scaled_dot_product_attention:  139.1 us
Slowdown: 8.6x

Sweep across shapes (MPS, all B=1 unless noted)

| Shape | MHA (μs) | SDPA (μs) | Gap |
|---|---|---|---|
| Lq=32, Lkv=256, D=512, H=8 | 1199 | 139 | 8.6× |
| Lq=64, Lkv=512, D=512, H=8 | 2419 | 217 | 11.2× |
| Lq=128, Lkv=1024, D=512, H=8 | 4873 | 545 | 9.0× |
| Lq=64, Lkv=64, D=768, H=12 (BERT-base-ish) | 966 | 157 | 6.2× |
| Same Lq=32, Lkv=256, but B=4 | 477 | 364 | 1.3× |
| Same Lq=32, Lkv=256, on CPU instead of MPS | 364 | 390 | 0.9× |

max|MHA - SDPA| = 0 for every row above (no semantic difference between the two paths).
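
A sweep like the one above could be scripted roughly as follows. This is a minimal sketch, not the script used for the table: the helper names (pick_device, build_paths, bench_us, sweep) are hypothetical, the iteration counts are arbitrary, and it falls back to CPU when MPS is unavailable, so absolute numbers will differ from those reported.

```python
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

def pick_device():
    # Fall back to CPU when MPS is unavailable so the sketch still runs.
    return "mps" if torch.backends.mps.is_available() else "cpu"

def sync(device):
    # Only MPS needs an explicit sync before reading the wall clock.
    if device == "mps":
        torch.mps.synchronize()

def build_paths(B, Lq, Lkv, D, H, device):
    """Return (via_mha, via_sdpa) closures that share one set of weights."""
    mha = nn.MultiheadAttention(D, H, batch_first=True).to(device).eval()
    q = torch.randn(B, Lq, D, device=device)
    mem = torch.randn(B, Lkv, D, device=device)
    Wq, Wk, Wv = mha.in_proj_weight.detach().chunk(3, dim=0)
    bq, bk, bv = mha.in_proj_bias.detach().chunk(3, dim=0)
    Wo, bo = mha.out_proj.weight.detach(), mha.out_proj.bias.detach()
    hd = D // H

    def via_mha():
        return mha(q, mem, mem, need_weights=False)[0]

    def via_sdpa():
        Qh = F.linear(q, Wq, bq).view(B, Lq, H, hd).transpose(1, 2)
        Kh = F.linear(mem, Wk, bk).view(B, Lkv, H, hd).transpose(1, 2)
        Vh = F.linear(mem, Wv, bv).view(B, Lkv, H, hd).transpose(1, 2)
        out = F.scaled_dot_product_attention(Qh, Kh, Vh)
        out = out.transpose(1, 2).reshape(B, Lq, D)
        return F.linear(out, Wo, bo)

    return via_mha, via_sdpa

def bench_us(fn, device, iters=50, warmup=5):
    """Mean time per call in microseconds."""
    with torch.no_grad():
        for _ in range(warmup):
            fn()
        sync(device)
        t0 = time.perf_counter()
        for _ in range(iters):
            fn()
        sync(device)
    return (time.perf_counter() - t0) / iters * 1e6

def sweep(shapes, device=None, iters=50):
    """shapes: iterable of (B, Lq, Lkv, D, H) tuples."""
    device = device or pick_device()
    rows = []
    for B, Lq, Lkv, D, H in shapes:
        via_mha, via_sdpa = build_paths(B, Lq, Lkv, D, H, device)
        with torch.no_grad():
            max_diff = (via_mha() - via_sdpa()).abs().max().item()
        rows.append({
            "shape": (B, Lq, Lkv, D, H),
            "mha_us": bench_us(via_mha, device, iters),
            "sdpa_us": bench_us(via_sdpa, device, iters),
            "max_diff": max_diff,
        })
    return rows

if __name__ == "__main__":
    for row in sweep([(1, 32, 256, 512, 8), (4, 32, 256, 512, 8)]):
        gap = row["mha_us"] / row["sdpa_us"]
        print(row["shape"], f"gap={gap:.1f}x", f"max_diff={row['max_diff']:.2e}")
```

Passing device="cpu" explicitly gives the CPU comparison row; the other rows only need different shape tuples.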

Observations

  • The gap is specific to MPS (CPU shows ~1×).
  • The gap is largest at B=1 and shrinks as batch grows (1.3× at B=4) — consistent with per-launch dispatch overhead.
  • The gap is not specific to cross-attention; self-attention shows the same pattern (6.6× at B=1 BERT-ish self-attn).
  • Both paths produce bit-identical outputs, so an optimization would be behavior-preserving.

Root cause

The slow path is triggered by nn.MultiheadAttention.forward calling query.transpose(1, 0) when batch_first=True (activation.py:1465). At B=1 this is a stride-only view: is_contiguous() reports True because the contiguity check ignores the stride of a size-1 dimension, but the stride pattern becomes (D, L*D, 1) instead of the native (B*D, D, 1), which is (D, D, 1) at B=1. The subsequent F.linear calls (4 per forward: the Q, K, V projections plus the output projection) then hit a slow path on MPS for this stride pattern.

Isolating F.linear alone with B=1, L=32, D=512 on MPS:

| Input | Stride | F.linear time | Note |
|---|---|---|---|
| (B,L,D) native | (L*D, D, 1) | 30 μs | |
| (L,B,D) native | (B*D, D, 1) | 32 μs | |
| (B,L,D).transpose(0,1) view | (D, L*D, 1) | 103 μs | |
| (B,L,D).transpose(0,1).contiguous() | (D, L*D, 1) | 99 μs | contiguous() is a no-op at B=1 |
| (B,L,D).transpose(0,1).reshape(L,B,D) | (B*D, D, 1) | 31 μs | stride normalized |

contiguous() is a no-op here because is_contiguous() already returns True even though the stride is the slow one — only reshape (or an equivalent stride-normalizing op) recovers the fast layout.
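
The layout claims above can be inspected without an MPS device, since strides are a property of the tensor view rather than of the backend. A minimal sketch, assuming only torch; the helper name layout_report is hypothetical and the sketch measures no timings:

```python
import torch

def layout_report(B=1, L=32, D=512):
    """Map each layout to (stride, is_contiguous, shares storage with the input)."""
    x = torch.randn(B, L, D)      # (B, L, D) native
    t = x.transpose(0, 1)         # stride-only view with shape (L, B, D)
    candidates = {
        "transpose view": t,
        "transpose + contiguous()": t.contiguous(),
        "transpose + reshape(L, B, D)": t.reshape(L, B, D),
        "(L, B, D) native": torch.randn(L, B, D),
    }
    return {
        name: (tuple(c.stride()), c.is_contiguous(), c.data_ptr() == x.data_ptr())
        for name, c in candidates.items()
    }

if __name__ == "__main__":
    for name, (stride, contig, shared) in layout_report().items():
        print(f"{name:30s} stride={stride} contiguous={contig} shares_storage={shared}")
```

If the behaviour described above holds, the contiguous() row keeps the transposed stride while the reshape row matches the native one and still shares storage with the input.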

Candidate fix

Inserting .reshape(L, B, D) after the transpose at activation.py:1465 normalizes the stride at B=1 without copying memory, and closes the gap to ~1× across all the shapes I tested, with bit-identical output:

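| Shape | orig/hand | fix/hand | max abs diff |
|---|---|---|---|
| small cross (B=1) | 7.8× | 1.0× | 0 |
| medium cross (B=1) | 11.1× | 1.0× | 0 |
| large cross (B=1) | 10.1× | 1.0× | 0 |
| BERT-ish cross (B=1) | 6.8× | 1.1× | 0 |
| BERT-ish self (B=1) | 7.0× | 1.0× | 0 |
| small self (B=1) | 1.8× | 1.1× | 0 |
| B=4 cross | 1.3× | 1.1× | 0 |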

The deeper question is whether the right fix is in activation.py / functional.py (workaround-style) or in the MPS F.linear backend (root cause). Happy to send a PR for either direction once there is guidance.
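
Until there is an upstream decision, a caller-side workaround may be possible: construct the module with batch_first=False and hand it inputs whose stride has already been normalized with reshape, so the internal transpose is never taken. The sketch below is an assumption based on the F.linear measurements above and has not been benchmarked on MPS; SeqFirstMHA is a hypothetical wrapper name, not a PyTorch class.

```python
import torch
import torch.nn as nn

class SeqFirstMHA(nn.Module):
    """Hypothetical wrapper: batch-first interface over a batch_first=False MHA."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=False)

    @staticmethod
    def _to_seq_first(x):
        # (B, L, D) -> (L, B, D). Reshaping to the same shape is the step the
        # issue reports as normalizing the stride at B=1.
        t = x.transpose(0, 1)
        return t.reshape(t.shape)

    def forward(self, query, key, value):
        q = self._to_seq_first(query)
        k = self._to_seq_first(key)
        # Preserve "key is value" so the module can keep using its packed K/V projection.
        v = k if value is key else self._to_seq_first(value)
        out, _ = self.mha(q, k, v, need_weights=False)
        return out.transpose(0, 1)  # back to (B, L, D)

if __name__ == "__main__":
    ref = nn.MultiheadAttention(512, 8, batch_first=True).eval()
    wrapped = SeqFirstMHA(512, 8).eval()
    wrapped.mha.load_state_dict(ref.state_dict())  # share weights for comparison
    q, mem = torch.randn(1, 32, 512), torch.randn(1, 256, 512)
    with torch.no_grad():
        diff = (ref(q, mem, mem, need_weights=False)[0] - wrapped(q, mem, mem)).abs().max()
    print(f"max abs diff vs batch_first=True: {diff.item():.2e}")
```

The weights are interchangeable because batch_first only changes how inputs and outputs are laid out, not the parameters, so an existing checkpoint can be loaded into the wrapped module.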

Versions

  • PyTorch: 2.13.0.dev20260427 (nightly, also reproduces on 2.11.0)
  • Build: e538a0b59d1d2f0cc0526c9ed8bde7df8112572f
  • Device: MPS, Apple M4 Pro, 24 GB
  • macOS: 26.3 (build 25D125)
  • Python: 3.12

extent analysis

TL;DR

Inserting .reshape(L, B, D) after the transpose in nn.MultiheadAttention.forward normalizes the stride and closes the performance gap to ~1×.

Guidance

  • The slowdown is caused by the stride pattern becoming non-native after query.transpose(1, 0) when batch_first=True.
  • Isolating F.linear alone shows that the slow path is hit for the non-native stride pattern.
  • Using .reshape(L, B, D) after the transpose normalizes the stride without copying memory.
  • The fix can be applied in activation.py or functional.py as a workaround, or in the MPS F.linear backend to address the root cause.

Example

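# Sketch of the change at activation.py:1465 (batch_first=True path); L, B, D are the transposed dimensions
query = query.transpose(1, 0)    # (B, L, D) -> (L, B, D), stride-only view
L, B, D = query.shape
query = query.reshape(L, B, D)   # added: normalizes the stride at B=1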

Notes

The fix assumes that the issue is specific to the MPS backend and the batch_first=True case. Further testing may be needed to confirm that this fix does not introduce any regressions.
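
One cheap regression check is to confirm that the extra reshape never changes values or copies memory, for B=1 as well as for larger batches. A minimal sketch, assuming only torch; the function name reshape_is_free is hypothetical, and it covers tensor layout only, not the full nn.MultiheadAttention forward:

```python
import torch

def reshape_is_free(B, L, D):
    """Check that transpose(1, 0) followed by reshape keeps values and storage."""
    x = torch.randn(B, L, D)
    t = x.transpose(1, 0)       # what forward does when batch_first=True
    r = t.reshape(L, B, D)      # the proposed extra step
    same_values = torch.equal(r, t)
    same_storage = r.data_ptr() == x.data_ptr()
    return same_values, same_storage, tuple(t.stride()), tuple(r.stride())

if __name__ == "__main__":
    for B in (1, 2, 4):
        values_ok, storage_ok, before, after = reshape_is_free(B, 32, 512)
        print(f"B={B}: values={values_ok} storage={storage_ok} stride {before} -> {after}")
```

Printing the two stride tuples shows, for a given B, whether the reshape changed the layout or returned the view as it was.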

Recommendation

Apply the workaround by inserting .reshape(L, B, D) after the transpose in nn.MultiheadAttention.forward, as this is a behavior-preserving optimization that closes the performance gap to ~1×.
