pytorch - (Solved) Fix [MPS] `scaled_dot_product_attention` (SDPA) improvements [1 pull request, 4 comments, 4 participants]

GitHub stats
pytorch/pytorch#179294 (fetched 2026-04-08 02:43:37)
Comments: 4 · Participants: 4 · Timeline events: 56 · Reactions: 0
Timeline (top): mentioned ×19, subscribed ×19, labeled ×8, commented ×4

Fix Action

Fixed

PR fix notes

PR #179309: [MPS] Improve call site of _scaled_dot_product_attention_math_mps

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #179309

Part of #179294

Changed files

  • aten/src/ATen/native/mps/operations/Attention.mm (modified, +32/-3)
  • aten/src/ATen/native/native_functions.yaml (modified, +1/-1)
  • aten/src/ATen/native/transformers/attention.cpp (modified, +3/-43)
  • torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h (modified, +1/-1)

I've been looking into the current implementation of scaled_dot_product_attention for MPS, and I'd like to lay out some of its problems and my plans for improving it. The items below are not necessarily listed in the order in which they will be completed.

Problem - Messy call site of _scaled_dot_product_attention_math_for_mps

At the call site of the MPS impl, there is a block of extra code that should be moved into the MPS impl itself: https://github.com/pytorch/pytorch/blob/0cc07dbe82d222bcf97266aa978542e2733814f3/aten/src/ATen/native/transformers/attention.cpp#L795-L825

Invoking the MPS impl should be as simple as calling one function.

Problem - No backward pass

Currently we don't have a dedicated MPS impl of the backward pass. Instead, the backward pass is implemented using implicit composite autograd from the device-agnostic "math" backend function _scaled_dot_product_attention_math. The forward pass has to call _scaled_dot_product_attention_math in order to build the graph for the backward pass. This means that if someone wants to run a backward pass of SDPA for MPS, they can't use the dedicated MPS impl of the forward pass. So not only is the backward pass slower than it could be, but the forward pass is also slower, since both are using a device-agnostic impl.

So a dedicated MPS backward pass needs to be implemented.

Problem - Unused second output tensor

As some issues (#176730, #175873) have mentioned, _scaled_dot_product_attention_math_for_mps returns a second output tensor that is currently not used by anything. #173943 proposes completely removing the code that generates the second output and just returning an empty tensor, to improve performance. But this is not the right solution.

The second output is required for calculating the backward pass. So when we have the backward implemented in the future, we will need this second output.
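To illustrate why the backward pass depends on the second output, here is a minimal plain-Python sketch of single-head SDPA (no mask, no dropout, lists of rows instead of tensors). It is not the MPS implementation: it assumes the second output is the matrix of softmax attention weights `p`, and the helper names are hypothetical. The backward applies the standard attention gradients, and every one of them (`dq`, `dk`, `dv`) is computed from the saved `p`.

```python
import math

# Plain-Python model of single-head SDPA; matrices are lists of rows.
# All helpers are hypothetical illustrations, not PyTorch APIs.

def matmul(a, b):
    # (n x m) @ (m x p) -> (n x p)
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def scale_matrix(a, s):
    return [[s * x for x in row] for row in a]

def softmax_rows(s):
    out = []
    for row in s:
        m = max(row)  # subtract the row max for numerical stability
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        out.append([x / z for x in e])
    return out

def sdpa_forward(q, k, v, scale):
    """Return (output, p), where p is the 'second output' the backward needs."""
    p = softmax_rows(scale_matrix(matmul(q, transpose(k)), scale))
    return matmul(p, v), p

def sdpa_backward(d_out, q, k, v, p, scale):
    """Given upstream gradient d_out, return (dq, dk, dv) reusing the saved p."""
    dv = matmul(transpose(p), d_out)  # dV = P^T dO
    dp = matmul(d_out, transpose(v))  # dP = dO V^T
    ds = []                           # softmax backward, one row at a time
    for p_row, dp_row in zip(p, dp):
        dot = sum(a * b for a, b in zip(p_row, dp_row))
        ds.append([pi * (dpi - dot) for pi, dpi in zip(p_row, dp_row)])
    dq = scale_matrix(matmul(ds, k), scale)             # dQ = scale * dS K
    dk = scale_matrix(matmul(transpose(ds), q), scale)  # dK = scale * dS^T Q
    return dq, dk, dv
```

Without a saved `p`, the backward would have to recompute the softmax from `q` and `k`.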

I propose that we conditionally calculate the second output only if one of the input tensors requires grad. In the case where we don't need to calculate it, we should return a null tensor (at::Tensor()) so we can avoid the overhead of creating an empty tensor.
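A minimal plain-Python sketch of the proposed conditional second output, under stated assumptions: `Tensor` is a hypothetical stand-in for `at::Tensor` carrying only data and a `requires_grad` flag, the second output is assumed to be the attention-weight matrix, and `None` plays the role of the null `at::Tensor()`. The weights are computed one query row at a time, so the full L x S matrix is only materialized when some input requires grad.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

Matrix = List[List[float]]

@dataclass
class Tensor:
    """Hypothetical stand-in for at::Tensor: data plus a requires_grad flag."""
    data: Matrix
    requires_grad: bool = False

def sdpa_forward(q: Tensor, k: Tensor, v: Tensor, scale: float) -> Tuple[Matrix, Optional[Matrix]]:
    # Decide up front whether autograd will need the attention weights.
    needs_grad = q.requires_grad or k.requires_grad or v.requires_grad
    out: Matrix = []
    attn: Optional[Matrix] = [] if needs_grad else None  # None ~ at::Tensor()
    for q_row in q.data:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k.data]
        m = max(scores)  # subtract the row max for numerical stability
        e = [math.exp(s - m) for s in scores]
        z = sum(e)
        w = [x / z for x in e]  # one row of attention weights
        out.append([sum(wj * v_row[c] for wj, v_row in zip(w, v.data))
                    for c in range(len(v.data[0]))])
        if attn is not None:
            attn.append(w)  # keep the row only when a backward pass will need it
    return out, attn
```

The primary output is identical either way; only the extra allocation and storage of the weights are skipped in the no-grad case.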

Problem - Math vs. flash attention

While there isn't really documentation about this, it appears that the "math" backend is supposed to just call the generic device-agnostic _scaled_dot_product_attention_math function. Then if the user wants to use a dedicated GPU implementation, they should select one of the other backends, such as the "flash" backend. This is true on CUDA platforms. But for MPS platforms, the dedicated MPS impl gets called if the "math" backend is selected and the inputs don't require grad and some other odd conditions: https://github.com/pytorch/pytorch/blob/19cda9b00702cd65bc9c46d467f0f5e928991f9f/aten/src/ATen/native/transformers/attention.cpp#L811-L814

If my assumption that the "math" backend is supposed to always call the device-agnostic impl is correct, then we should move the dedicated MPS impl to the "flash" backend.
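The proposed routing can be sketched as a small dispatch function. This is a hypothetical Python model, not PyTorch's dispatcher: `Backend` is a stand-in for the user-selected SDPA backend, the two implementation names come from the issue, and only the MPS case of the "flash" backend is modeled.

```python
from enum import Enum, auto

class Backend(Enum):
    """Hypothetical stand-in for the user-selected SDPA backend."""
    MATH = auto()
    FLASH = auto()

# Names taken from the issue; the routing below sketches the proposal only.
GENERIC_MATH = "_scaled_dot_product_attention_math"
MPS_DEDICATED = "_scaled_dot_product_attention_math_for_mps"

def select_sdpa_impl(backend: Backend, device: str) -> str:
    if backend is Backend.MATH:
        # Proposed rule: "math" always means the device-agnostic impl,
        # regardless of device or requires_grad.
        return GENERIC_MATH
    if backend is Backend.FLASH and device == "mps":
        # Proposed rule: the dedicated MPS impl sits behind "flash".
        return MPS_DEDICATED
    raise NotImplementedError(f"{backend.name} on {device} is outside this sketch")
```

Under this rule, selecting "math" on MPS would behave the way the issue describes for CUDA platforms.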

Problem - Generic Metal implementation is never called

There appears to be a Metal kernel implementation of SDPA in sdpa_full_attention_mps, but it isn't called at all.

Presumably, a Metal kernel is more reliable and potentially faster than the current MPSGraph impl, so this codepath should be revived and fixed if it is broken. The MPSGraph impl can be removed if it is categorically worse.

cc @jerryzh168 @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk @drisspg @liangel-02 @howardzhang-cv @manuelcandales

extent analysis

TL;DR

Refactor the scaled_dot_product_attention implementation for MPS to simplify the call site, implement a dedicated backward pass, compute the second output tensor only when it is needed, and revisit backend selection and the unused Metal kernel.

Guidance

  • Move the extra code from the call site of _scaled_dot_product_attention_math_for_mps into the MPS implementation to simplify the invocation.
  • Implement a dedicated MPS backward pass to improve performance and allow for the use of the dedicated MPS forward pass.
  • Conditionally calculate the second output tensor only if one of the input tensors requires grad to avoid unnecessary overhead.
  • Consider moving the dedicated MPS implementation to the "flash" backend to align with the expected behavior of the "math" backend.
  • Investigate and potentially revive the generic Metal implementation of SDPA in sdpa_full_attention_mps to improve reliability and performance.

Example

No specific code example is provided, as the issue requires a broader refactoring and implementation of new functionality.

Notes

The proposed changes aim to improve the performance, reliability, and maintainability of the scaled_dot_product_attention implementation for MPS. However, the implementation details and potential interactions with other components of the PyTorch framework should be carefully considered to ensure correctness and compatibility.

Recommendation

Apply the proposed refactorings and implementations to improve the scaled_dot_product_attention implementation for MPS, starting with simplifying the call site and implementing a dedicated backward pass.
