transformers - ✅(Solved) Fix [MPS] Upstream correctness issue in attention when value head dim differs from query [1 pull requests, 7 comments, 3 participants]

GitHub stats
huggingface/transformers#44554 (fetched 2026-04-08 00:27:44)
Comments: 7 · Participants: 3 · Timeline: 20 · Reactions: 0
Timeline (top): commented ×7, mentioned ×5, subscribed ×5, cross-referenced ×1


PR fix notes

PyTorch PR #176843: [MPS] Fix SDPA output shape when value head dim differs

Description (problem / solution / changelog)

This fixes MPS SDPA output shape for cases where value.size(-1) != query.size(-1), so output now follows (..., L, Ev) as expected. I also added guards in Metal kernel paths that assume equal qkv head dims.

Added the updated meta shape inference for the sdpa_general_mps path which seems to have been left out initially.

Added regression coverage in test/test_transformers.py covering the shape semantics, and a similar one in test/test_mps.py that also checks for numerical parity with CPU.

Fixes #176767
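For reference, the shape rule the PR restores can be written out in plain PyTorch. The sketch below is a minimal math-only version of SDPA (no mask, no dropout, default scale 1/sqrt(E)); `reference_sdpa` is a hypothetical helper, not a PyTorch API. It shows why the output takes its last dim from `value`: the (L, S) attention weights are multiplied by an (S, Ev) value matrix.

```python
import math

import torch
import torch.nn.functional as F


def reference_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Plain math form of SDPA: no mask, no dropout, default scale.

    q: (..., L, E), k: (..., S, E), v: (..., S, Ev) -> output: (..., L, Ev)
    """
    scale = 1.0 / math.sqrt(q.size(-1))  # default scale uses the query head dim E
    weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)  # (..., L, S)
    return weights @ v  # (..., L, S) @ (..., S, Ev) -> (..., L, Ev)


q = torch.rand(1, 1, 8, 4)
k = torch.rand(1, 1, 8, 4)
v = torch.rand(1, 1, 8, 2)

ref = reference_sdpa(q, k, v)
out = F.scaled_dot_product_attention(q, k, v)  # CPU path, follows (..., L, Ev)

print(ref.shape, out.shape)  # torch.Size([1, 1, 8, 2]) torch.Size([1, 1, 8, 2])
torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-5)
```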

Changed files

  • aten/src/ATen/native/mps/operations/Attention.mm (modified, +25/-8)
  • test/test_mps.py (modified, +20/-0)
  • test/test_transformers.py (modified, +19/-0)
  • torch/_meta_registrations.py (modified, +42/-4)


System Info

There is a correctness issue with attention in PyTorch on the MPS backend when the value head dim differs from the query head dim (see https://github.com/pytorch/pytorch/issues/176767).

Consider the following reproducer:

import torch
import torch.nn.functional as F

q = torch.rand(1, 1, 8, 4, device="mps")
k = torch.rand(1, 1, 8, 4, device="mps")
v = torch.rand(1, 1, 8, 2, device="mps")

y_mps1 = F.scaled_dot_product_attention(q, k, v)
y_mps2 = F.scaled_dot_product_attention(q, k, v)

print(f"{y_mps1-y_mps2 = }")

y_cpu1 = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu())
y_cpu2 = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu())

print(f"{y_cpu1-y_cpu2 = }")

print(f"{y_mps1.shape = }")
print(f"{y_cpu1.shape = }")

# Output:
# y_mps1-y_mps2 = tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000],
#           [ 0.0000,  0.0000,  0.0000,  0.0000],
#           [-0.0957, -0.1705, -0.1087, -0.1648],
#           [-0.1155, -0.1248, -0.0946, -0.1255],
#           [-0.1089, -0.1662, -0.1261, -0.1559],
#           [-0.0904, -0.1262, -0.0927, -0.1335],
#           [-0.1080, -0.1644, -0.1226, -0.1611],
#           [-0.0849, -0.1324, -0.0984, -0.1281]]]], device='mps:0')
# y_cpu1-y_cpu2 = tensor([[[[0., 0.],
#           [0., 0.],
#           [0., 0.],
#           [0., 0.],
#           [0., 0.],
#           [0., 0.],
#           [0., 0.],
#           [0., 0.]]]])
# y_mps1.shape = torch.Size([1, 1, 8, 4])
# y_cpu1.shape = torch.Size([1, 1, 8, 2])

The run-to-run mismatch on MPS above is caused by uninitialized memory being returned as part of the output tensor. Note also that the MPS output shape, (1, 1, 8, 4), differs from the expected (1, 1, 8, 2) that CPU returns.

The issue has been fixed in PyTorch (https://github.com/pytorch/pytorch/pull/176843), but the fix won't be released until PyTorch 2.12.

I don't know whether differing query and value head dims are a common use case for Hugging Face models, but I wanted to flag the issue regardless. I'm happy to provide a workaround if you think it should be addressed in transformers. A potential workaround is to resize the output tensor and adjust its strides, since the computed values are correct for the first (expected) numels.
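A different workaround from the stride wrangling proposed above is to zero-pad `value` up to the query head dim, so SDPA takes its ordinary equal-dim path, and slice the result back to Ev. This is a swapped-in technique, not the author's proposal; `sdpa_pad_value` is a hypothetical name, and the sketch assumes the equal-head-dim MPS path is itself correct. It costs some extra compute and memory for the padded columns.

```python
import torch
import torch.nn.functional as F


def sdpa_pad_value(q, k, v, **sdpa_kwargs):
    """Hypothetical helper: run SDPA with v zero-padded to q's head dim, then slice.

    Only the value tensor is padded, so the attention weights (from q, k and the
    default scale 1/sqrt(q.size(-1))) are unchanged; the extra output columns are
    weights @ 0 == 0 and are dropped.
    """
    e, ev = q.size(-1), v.size(-1)
    if ev == e:
        return F.scaled_dot_product_attention(q, k, v, **sdpa_kwargs)
    if ev > e:
        # Padding q/k instead would change the default scale; not handled here.
        raise ValueError("this sketch only handles value head dim < query head dim")
    v_padded = F.pad(v, (0, e - ev))  # zero-pad the last dim on the right
    out = F.scaled_dot_product_attention(q, k, v_padded, **sdpa_kwargs)
    return out[..., :ev].contiguous()


device = "mps" if torch.backends.mps.is_available() else "cpu"
q = torch.rand(1, 1, 8, 4, device=device)
k = torch.rand(1, 1, 8, 4, device=device)
v = torch.rand(1, 1, 8, 2, device=device)

y = sdpa_pad_value(q, k, v)
y_cpu = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu())
print(f"{y.shape = }")  # torch.Size([1, 1, 8, 2])
print(f"{(y.cpu() - y_cpu).abs().max() = }")
```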

The issue is similar to #44247, another correctness issue in MPS SDPA that needed a workaround for specific PyTorch versions.
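Since the bug is version specific, a workaround would typically be gated so it only runs where it is needed. A minimal sketch, assuming (as stated above) that the fix first ships in PyTorch 2.12; `major_minor` and `needs_mps_sdpa_workaround` are hypothetical helpers, and pre-release `2.12.0a0` nightlies are treated as fixed, which may not hold for every nightly.

```python
import torch


def major_minor(version: str) -> tuple:
    """Parse 'X.Y...' into (X, Y); suffixes such as 'a0+git...' or '+cpu' are ignored."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def needs_mps_sdpa_workaround(q: torch.Tensor, v: torch.Tensor, fixed_in=(2, 12)) -> bool:
    """Hypothetical gate: True only when the MPS value-head-dim bug may apply.

    Assumes, as the issue states, that the upstream fix first ships in PyTorch 2.12.
    """
    return (
        q.device.type == "mps"
        and v.size(-1) != q.size(-1)
        and major_minor(torch.__version__) < tuple(fixed_in)
    )


q = torch.rand(1, 1, 8, 4)
v = torch.rand(1, 1, 8, 2)
print(major_minor(torch.__version__), needs_mps_sdpa_workaround(q, v))
```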

Who can help?

@vasqu @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

See the reproducer above. I'm not sure whether this is an issue in transformers itself.

Expected behavior

See reproducer above.

Extent analysis

Fix Plan

Workaround

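Until the fix ships in PyTorch 2.12, the MPS output can be repaired by resizing it and adjusting its strides. This relies on the issue author's observation that the first (expected) numels of the returned tensor are correct, so compare the result against CPU before relying on it. Here's a step-by-step solution: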

1. Resize the output tensor

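import math

import torch
import torch.nn.functional as F

q = torch.rand(1, 1, 8, 4, device="mps")
k = torch.rand(1, 1, 8, 4, device="mps")
v = torch.rand(1, 1, 8, 2, device="mps")

y_mps1 = F.scaled_dot_product_attention(q, k, v)

# Expected shape is (..., L, Ev): the query's leading dims with the value's head dim
expected_shape = (*q.shape[:-1], v.size(-1))
n = math.prod(expected_shape)

# Keep only the first n elements of the (contiguous) output buffer.
# Slicing the last dim (y_mps1[..., :2]) would instead pick 2 elements from every row.
y_mps1_flat = y_mps1.reshape(-1)[:n]

print(f"{expected_shape = }, {y_mps1_flat.numel() = }")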

2. Wrangle the strides

import math

import torch
import torch.nn.functional as F

q = torch.rand(1, 1, 8, 4, device="mps")
k = torch.rand(1, 1, 8, 4, device="mps")
v = torch.rand(1, 1, 8, 2, device="mps")

y_mps1 = F.scaled_dot_product_attention(q, k, v)
expected_shape = (*q.shape[:-1], v.size(-1))

# Contiguous strides for the expected shape, e.g. (16, 16, 2, 1) for (1, 1, 8, 2)
strides = [math.prod(expected_shape[i + 1:]) for i in range(len(expected_shape))]

# Reinterpret the start of the original buffer with the expected shape and strides
# (Tensor.stride_ does not exist; as_strided is the real API), then clone so the
# result owns its memory instead of aliasing the oversized buffer.
y_mps1_fixed = y_mps1.as_strided(expected_shape, strides).clone()

print(f"{y_mps1_fixed.shape = }")
print(f"{y_mps1_fixed.stride() = }")
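The resize-and-restride idea can be wrapped in a single guarded helper that leaves correctly shaped outputs alone. A sketch under the issue author's assumption that the first expected numels of a contiguous output are correct; `repair_sdpa_output` is hypothetical, and the bug is simulated on CPU here so the logic can be checked without an MPS device.

```python
import math

import torch


def repair_sdpa_output(out: torch.Tensor, q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reinterpret a mis-shaped SDPA output as (..., L, Ev).

    Assumes (per the issue author, unverified here) that `out` is contiguous and that
    its first prod(expected_shape) elements in memory order hold the correct result.
    """
    expected_shape = (*q.shape[:-1], v.size(-1))
    if tuple(out.shape) == expected_shape:
        return out  # already correct, e.g. on CPU or a fixed PyTorch build
    if not out.is_contiguous():
        raise ValueError("expected a contiguous output buffer")
    n = math.prod(expected_shape)
    if out.numel() < n:
        raise ValueError("output buffer is smaller than the expected result")
    # First n elements of the buffer, viewed with contiguous strides; clone to own memory
    return out.reshape(-1)[:n].view(expected_shape).clone()


# CPU simulation of the bug: the true (1, 1, 8, 2) result sits at the start of a
# (1, 1, 8, 4) buffer whose remaining elements are garbage (NaN here).
q = torch.rand(1, 1, 8, 4)
v = torch.rand(1, 1, 8, 2)
true_result = torch.rand(1, 1, 8, 2)
buggy = torch.full((1, 1, 8, 4), float("nan"))
buggy.view(-1)[: true_result.numel()] = true_result.reshape(-1)

fixed = repair_sdpa_output(buggy, q, v)
print(fixed.shape, torch.equal(fixed, true_result))  # torch.Size([1, 1, 8, 2]) True
```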

Verification

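To verify that the workaround worked, compare the repaired MPS output with the CPU result for the same inputs: the shapes should match and the values should agree within floating-point tolerance.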
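A minimal check along those lines, using `torch.testing.assert_close` (which also compares shapes); `check_against_cpu` is a hypothetical helper and the 1e-4 tolerances are an assumption, loose enough to absorb ordinary float32 differences between MPS and CPU.

```python
import torch
import torch.nn.functional as F


def check_against_cpu(q, k, v, out, rtol=1e-4, atol=1e-4):
    """Raise AssertionError unless `out` matches the CPU SDPA result for q, k, v."""
    expected = F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu())
    # assert_close also fails if the shapes differ
    torch.testing.assert_close(out.cpu(), expected, rtol=rtol, atol=atol)


device = "mps" if torch.backends.mps.is_available() else "cpu"
q = torch.rand(1, 1, 8, 4, device=device)
k = torch.rand(1, 1, 8, 4, device=device)
v = torch.rand(1, 1, 8, 2, device=device)

out = F.scaled_dot_product_attention(q, k, v)
# On an affected MPS build, apply the workaround to `out` here before checking.
try:
    check_against_cpu(q, k, v, out)
    print("matches CPU")
except AssertionError as err:
    print(f"does not match CPU: {err}")
```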
