pytorch - ✅(Solved) Fix [MPS] Meta kernel for _scaled_dot_product_attention_math_for_mps guards on dynamic seq_len, breaking torch.export [2 pull requests, 2 comments, 3 participants]

GitHub stats
pytorch/pytorch#177603, fetched 2026-04-08 00:47:16
Comments: 2 · Participants: 3 · Timeline events: 32 · Reactions: 2
Timeline (top): mentioned ×10, subscribed ×10, referenced ×4, labeled ×3

Error Message

torch._dynamo.exc.UserError: Constraints violated (seq)!

  • Not all values of seq = L['q'].size()[2] in the specified range seq <= 64 satisfy the generated guard 2 <= L['q'].size()[2] and L['q'].size()[2] <= 8

Suggested fixes:
  seq = Dim('seq', max=8)

Root Cause

In torch/_meta_registrations.py, meta__scaled_dot_product_attention_math_for_mps has:

query_seq_len = q_.size(2)
supports_sdpa_vector = (
    (query_seq_len <= 8)  # <-- data-dependent guard on dynamic dim
    and ...
)

The three code paths (sdpa_general_mps, sdpa_vector_2pass_mps, sdpa_vector_fast_mps) return different output shapes. During tracing, the branch on query_seq_len specializes, creating guards that reject all other seq_len values.
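
To make the specialization mechanism concrete, here is a toy model in plain Python. It is not how torch implements symbolic shapes: `ToySymInt`, `branching_meta`, and `branch_free_meta` are hypothetical names invented for this illustration. The only point is that a tracer has to resolve a Python `if` using the example input, and it records the outcome as a guard.

```python
class ToySymInt:
    """Toy stand-in for a symbolic dimension (hypothetical; not torch.SymInt)."""

    def __init__(self, name, hint, guards):
        self.name = name      # symbol name, e.g. "seq"
        self.hint = hint      # concrete size of the example input
        self.guards = guards  # shared list collecting recorded guards

    def __le__(self, bound):
        # A Python `if` needs a real bool, so the comparison is resolved with
        # the example value and the outcome is recorded as a guard.
        result = self.hint <= bound
        op = "<=" if result else ">"
        self.guards.append(f"{self.name} {op} {bound}")
        return result


def branching_meta(seq):
    # Same pattern as the problematic check: branch on the dynamic dim.
    if seq <= 8:
        return ("B", "H", seq.name, "D")
    return ("B", "H", seq.name, "Dv")


def branch_free_meta(seq):
    # The shape mentions seq only symbolically; nothing is compared.
    return ("B", "H", seq.name, "Dv")


guards_with_branch = []
branching_meta(ToySymInt("seq", hint=4, guards=guards_with_branch))

guards_without_branch = []
branch_free_meta(ToySymInt("seq", hint=4, guards=guards_without_branch))

print(guards_with_branch)     # ['seq <= 8']
print(guards_without_branch)  # []
```

With an example length of 4, the branching version records `seq <= 8`, which conflicts with a declared range up to 64. The branch-free version records nothing, so the dimension stays dynamic.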

Fix Action

Fix / Workaround

Always return sdpa_general_mps shapes from the meta kernel — (B, H, q_size, value_head_size) for output, (B, H, q_size, max_seq_length) for attention. The runtime kernel still dispatches to the optimal implementation; the meta kernel only needs correct shapes for tracing.

TestSDPAMetaDispatchMode compares both return values against the real kernel and will need updating for the 2pass case: the 2pass runtime path returns a 5D intermediate (B,H,q,32,D) vs the meta kernel's 4D (B,H,q,max_seq_length). Since the second return value is universally discarded and already inconsistent across paths (general→4D attn weights, 2pass→5D intermediate, full_attention→duplicate of output), the test should only compare the first return value.

PR fix notes

PR #177620: [MPS] Fix SDPA meta shapes to avoid dynamic seq guards

Description (problem / solution / changelog)

Fix proposed by @mergennachin in #177603. The issue was introduced in #176843.

Remove data-dependent branching in the MPS SDPA meta kernel so export supports dynamic seq.

Update meta-dispatch test to compare only the first output and add an export regression test.

@angelayi, you wrote the original meta registration and tests in #159695. Does this LGTY?

Fixes #177603

Changed files

  • test/test_mps.py (modified, +37/-2)
  • torch/_meta_registrations.py (modified, +5/-50)

PR #177686: [MPS] expand the current export unit test for SDPA

Description (problem / solution / changelog)

Summary

Move the SDPA dynamic seq len export test from test_mps.py to test_aot_inductor.py so it runs on all devices (cpu, gpu, mps), using F.scaled_dot_product_attention instead of the MPS-specific op.

Also clean up the meta registration for _scaled_dot_product_attention_math_for_mps: inline the single-use sdpa_general_mps() helper and remove the dead ensure_4d(key) call, both left over from the branching logic removed in #177620.

Regression test for #177603.

Test plan

  • python test/inductor/test_aot_inductor.py -k test_sdpa_dynamic_seq_len

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @hvaara

Changed files

  • test/inductor/test_aot_inductor.py (modified, +30/-0)
  • test/test_mps.py (modified, +0/-33)
  • torch/_meta_registrations.py (modified, +12/-18)

Code Example

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("k", torch.zeros(1, 4, 64, 128))
        self.register_buffer("v", torch.zeros(1, 4, 64, 128))

    def forward(self, q, mask):
        out, _ = torch.ops.aten._scaled_dot_product_attention_math_for_mps(
            q, self.k, self.v, mask, 0.0, False, None
        )
        return out

model = M().to(torch.bfloat16).eval()
seq = Dim("seq", min=1, max=64)

q = torch.randn(1, 4, 4, 128, dtype=torch.bfloat16)
mask = torch.zeros(4, 64, dtype=torch.bfloat16)

# Fails: guard on seq_len <= 8 specializes the trace
ep = export(model, (q, mask), dynamic_shapes={"q": {2: seq}, "mask": {0: seq}}, strict=True)


🐛 Describe the bug

The meta kernel for aten._scaled_dot_product_attention_math_for_mps branches on query_seq_len <= 8, which creates a data-dependent guard that prevents torch.export(strict=True) from supporting dynamic sequence lengths.

Introduced in #176843 (eccc7be).

Minimal reproduction

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("k", torch.zeros(1, 4, 64, 128))
        self.register_buffer("v", torch.zeros(1, 4, 64, 128))

    def forward(self, q, mask):
        out, _ = torch.ops.aten._scaled_dot_product_attention_math_for_mps(
            q, self.k, self.v, mask, 0.0, False, None
        )
        return out

model = M().to(torch.bfloat16).eval()
seq = Dim("seq", min=1, max=64)

q = torch.randn(1, 4, 4, 128, dtype=torch.bfloat16)
mask = torch.zeros(4, 64, dtype=torch.bfloat16)

# Fails: guard on seq_len <= 8 specializes the trace
ep = export(model, (q, mask), dynamic_shapes={"q": {2: seq}, "mask": {0: seq}}, strict=True)

Error:

torch._dynamo.exc.UserError: Constraints violated (seq)!
  - Not all values of seq = L['q'].size()[2] in the specified range seq <= 64
    satisfy the generated guard 2 <= L['q'].size()[2] and L['q'].size()[2] <= 8
Suggested fixes:
  seq = Dim('seq', max=8)

Expected behavior

Export succeeds with seq as a fully dynamic dimension covering [1, 64].

Root cause

In torch/_meta_registrations.py, meta__scaled_dot_product_attention_math_for_mps has:

query_seq_len = q_.size(2)
supports_sdpa_vector = (
    (query_seq_len <= 8)  # <-- data-dependent guard on dynamic dim
    and ...
)

The three code paths (sdpa_general_mps, sdpa_vector_2pass_mps, sdpa_vector_fast_mps) return different output shapes. During tracing, the branch on query_seq_len specializes, creating guards that reject all other seq_len values.

Proposed fix

Always return sdpa_general_mps shapes from the meta kernel — (B, H, q_size, value_head_size) for output, (B, H, q_size, max_seq_length) for attention. The runtime kernel still dispatches to the optimal implementation; the meta kernel only needs correct shapes for tracing.

This is safe for the first return value (the actual output) because the vector paths require query_head_dim == value_head_dim, so value_head_size is always correct. The second return value is never consumed — all C++ callsites use std::get<0>(...).

TestSDPAMetaDispatchMode compares both return values against the real kernel and will need updating for the 2pass case: the 2pass runtime path returns a 5D intermediate (B,H,q,32,D) vs the meta kernel's 4D (B,H,q,max_seq_length). Since the second return value is universally discarded and already inconsistent across paths (general→4D attn weights, 2pass→5D intermediate, full_attention→duplicate of output), the test should only compare the first return value.
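
The shape argument above can be checked with plain tuples. The sketch below models only what this issue states about each path's return shapes; the function names are hypothetical, and the output shape for the full_attention path is an assumption (the issue only says its second value duplicates the output). The dimensions are those of the minimal reproduction.

```python
def general_shapes(B, H, q, kv_len, D, Dv):
    # sdpa_general_mps: output plus 4D attention weights
    return (B, H, q, Dv), (B, H, q, kv_len)


def two_pass_shapes(B, H, q, kv_len, D, Dv, blocks=32):
    # 2pass path: output plus a 5D intermediate (B, H, q, 32, D)
    assert D == Dv, "vector paths require query_head_dim == value_head_dim"
    return (B, H, q, D), (B, H, q, blocks, D)


def full_attention_shapes(B, H, q, kv_len, D, Dv):
    # full_attention: the second value is a duplicate of the output
    # (the output shape used here is an assumption)
    out = (B, H, q, Dv)
    return out, out


dims = dict(B=1, H=4, q=4, kv_len=64, D=128, Dv=128)
paths = (general_shapes, two_pass_shapes, full_attention_shapes)
results = [path(**dims) for path in paths]

first_shapes = {r[0] for r in results}
second_shapes = [r[1] for r in results]

print(first_shapes)   # {(1, 4, 4, 128)}
print(second_shapes)  # [(1, 4, 4, 64), (1, 4, 4, 32, 128), (1, 4, 4, 128)]
```

The first return value has one shape across all paths, so the general-path shape is a safe choice for the meta kernel. The second value differs per path, which is why comparing it in a test is not meaningful.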

Versions

Regression bisected to eccc7be10726ebabb8f064e8e9990984a83226ec (#176843).

Tested on torch 2.12.0.dev20260312.

cc @chauhang @penguinwu @avikchaudhuri @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @malfet @hvaara

extent analysis

Fix Plan

To fix the issue, we need to modify the meta__scaled_dot_product_attention_math_for_mps function in torch/_meta_registrations.py to always return the sdpa_general_mps shapes.

Here are the steps:

  • Update the meta__scaled_dot_product_attention_math_for_mps function to return the sdpa_general_mps shapes for both return values.
  • Update the TestSDPAMetaDispatchMode test to only compare the first return value.

Code Changes

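# torch/_meta_registrations.py (simplified sketch; see PR #177620 for the merged change)
def meta__scaled_dot_product_attention_math_for_mps(
    query, key, value, attn_mask=None, dropout_p=0.0,
    is_causal=False, dropout_mask=None, scale=None,
):  # parameter names are illustrative; 4D inputs assumed
    # ...
    batch_size, num_head, q_size, _ = query.shape
    max_seq_length = key.size(2)
    # Always return sdpa_general_mps shapes; never branch on q_size.
    # A meta kernel returns empty tensors of the right shape, not shape tuples.
    out = query.new_empty((batch_size, num_head, q_size, value.size(3)))
    attn = query.new_empty((batch_size, num_head, q_size, max_seq_length))
    return out, attn

# test/test_mps.py
class TestSDPAMetaDispatchMode(TestCase):
    # ...
    def test_sdpa_meta_dispatch_mode(self):
        # ...
        # Only compare the first return value (variable names are illustrative)
        self.assertEqual(meta_output[0].shape, runtime_output[0].shape)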

Verification

To verify the fix, run the TestSDPAMetaDispatchMode test and ensure that it passes. Additionally, test the torch.export function with dynamic sequence lengths to ensure that it succeeds.

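# Regression test sketch. PR #177620 added the export test to test/test_mps.py;
# PR #177686 later moved it to test/inductor/test_aot_inductor.py.
import torch
from torch.export import Dim, export

class TestExport(TestCase):
    # ...
    def test_export_dynamic_sequence_length(self):
        # M is the module defined in the minimal reproduction
        model = M().to(torch.bfloat16).eval()
        seq = Dim("seq", min=1, max=64)
        q = torch.randn(1, 4, 4, 128, dtype=torch.bfloat16)
        mask = torch.zeros(4, 64, dtype=torch.bfloat16)
        # Before the fix this raised UserError: Constraints violated (seq)
        ep = export(model, (q, mask), dynamic_shapes={"q": {2: seq}, "mask": {0: seq}}, strict=True)
        self.assertIsNotNone(ep)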
