transformers - 💡(How to fix) Fix Proposal: add sdpa_memeff attn_implementation for shape combinations no fast backend covers

GitHub stats
huggingface/transformers#45636 · Fetched 2026-04-25 06:03:04
Comments: 0 · Participants: 1 · Timeline: 0 · Reactions: 0


Summary

Proposal to add a new attn_implementation="sdpa_memeff" that pins torch's SDPA dispatcher to SDPBackend.EFFICIENT_ATTENTION (via sdpa_kernel([EFFICIENT_ATTENTION]) wrapping the existing sdpa_attention_forward). Filing as an issue to validate design direction before opening a PR.

Motivation — two independent failure modes that attn_implementation="sdpa" doesn't handle well

1. head_dim > 256

On every CUDA arch I've tested (RTX 5090 sm_120; same library caps apply on H100 sm_90):

| Backend | head_dim=320 (Gemma 4) |
| --- | --- |
| FLASH | REJECT (library cap: head_dim ≤ 256) |
| CUDNN | REJECT (library cap: head_dim ≤ 256) |
| EFFICIENT (CUTLASS mem-eff) | ACCEPT |
| MATH | ACCEPT (but O(seq²) fp32 softmax: 74 GB at seq=32k, OOMs on 96 GB card) |

So for Gemma 4 (google/gemma-4-31B, gemma-4-26B-A4B, gemma-4-E4B, gemma-4-E2B — all head_dim=320) there's only one fast-path backend, and stock sdpa has no way to pin it reliably.
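To make the O(seq²) cost of the MATH path concrete, here is a minimal sketch that estimates the size of the materialized attention-score tensor. The function name `math_sdpa_score_bytes` is hypothetical, and the estimate counts only one fp32 score tensor; it does not try to reproduce the 74 GB figure above, since the batch size and head count behind that number are not stated.

```python
def math_sdpa_score_bytes(batch: int, num_heads: int, q_len: int, kv_len: int,
                          bytes_per_elem: int = 4) -> int:
    """Bytes needed for one dense (batch, heads, q_len, kv_len) score tensor.

    Hypothetical helper for illustration: assumes the scores are fully
    materialized in fp32 (4 bytes per element), as a non-fused path would do.
    """
    return batch * num_heads * q_len * kv_len * bytes_per_elem


# One head, one sequence of 32k tokens: 32768 * 32768 * 4 bytes.
per_head = math_sdpa_score_bytes(batch=1, num_heads=1, q_len=32768, kv_len=32768)
print(per_head / 1024**3)  # 4.0 (GiB per head)
```

Because the cost scales linearly with the number of heads and quadratically with sequence length, even a modest head count at seq=32k reaches tens of GiB.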

2. Input layouts that disqualify FLASH at any head_dim

pytorch/pytorch#44928 — Qwen3.5 RLHF training sees dense 4D mask materialization (from 3D position_ids), which disqualifies FLASH's is_causal=True short-circuit. Dispatcher falls through to MATH and triggers NaN gradients in bf16.

Same root shape of problem as (1), driven by layout rather than head_dim. A manual workaround (wrapping the call in with sdpa_kernel([EFFICIENT_ATTENTION]): ...) exists but requires model-code edits; the process-global torch.backends.cuda.enable_math_sdp(False) toggle has side effects.

Proposed fix

New file src/transformers/integrations/sdpa_memeff.py. Registers "sdpa_memeff" in ALL_ATTENTION_FUNCTIONS. Copies the existing sdpa_attention_forward with two additions:

  1. Wraps the actual SDPA call in with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]): ...
  2. Unconditional GQA repeat_kv (EFFICIENT_ATTENTION rejects dense GQA where num_heads_q != num_heads_kv).

Verified locally on full 8B Gemma-4-E4B-it (bf16, real checkpoint): 100% top-1 agreement vs stock sdpa where stock works, and correct execution where stock falls back to MATH and OOMs.

Questions for maintainers

  1. Shape of the contribution. Dedicated "sdpa_memeff" name vs. a config knob on sdpa (e.g. sdpa_preferred_backends=["EFFICIENT"]) vs. the pipe syntax introduced by PR #39823 ("sdpa|efficient_attention")?

  2. Overlap with pytorch-side fixes. pytorch's SDPA dispatcher is missing sm_120 support for CUDNN head_dim=256 (filed separately). Once that lands, CUDNN handles head_dim ≤ 256 on sm_120 — but head_dim=320 (Gemma 4) is still library-capped on both FLASH and CUDNN, and #44928-class layout-induced FLASH disqualification is unaffected. So sdpa_memeff isn't made redundant.

  3. Test coverage you'd want. Convergence (loss-curve equivalence with stock sdpa on a small model)? Numerical (per-batch logit diff)? Both?

  4. Anything a first-time contributor should know: CLA, preferred test location, maintainer norms for adding to ALL_ATTENTION_FUNCTIONS, etc.?
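For the numerical option in question 3, one possible shape of the check is sketched below. The helper names `top1_agreement` and `max_abs_logit_diff` are hypothetical, and the logits are assumed to be tensors of shape (batch, seq, vocab) produced by the two attention implementations on the same input.

```python
import torch


def top1_agreement(logits_a: torch.Tensor, logits_b: torch.Tensor) -> float:
    """Fraction of positions where both logit tensors pick the same argmax token."""
    same = logits_a.argmax(dim=-1) == logits_b.argmax(dim=-1)
    return same.float().mean().item()


def max_abs_logit_diff(logits_a: torch.Tensor, logits_b: torch.Tensor) -> float:
    """Largest absolute element-wise difference, computed in fp32."""
    return (logits_a.float() - logits_b.float()).abs().max().item()
```

A test could then assert that `top1_agreement` equals 1.0 and that `max_abs_logit_diff` stays below a bf16-appropriate tolerance; the tolerance itself would need to be chosen by the maintainers.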

Happy to open a PR if this direction is right.


🤖 Drafted with Claude Code (Claude Opus 4.7), reviewed and posted by me.

extent analysis

TL;DR

The proposed fix involves adding a new attn_implementation="sdpa_memeff" that pins torch's SDPA dispatcher to SDPBackend.EFFICIENT_ATTENTION to handle cases where attn_implementation="sdpa" fails.

Guidance

  • To address the issue, create a new file src/transformers/integrations/sdpa_memeff.py and register "sdpa_memeff" in ALL_ATTENTION_FUNCTIONS.
  • Wrap the actual SDPA call in with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]): ... to ensure the efficient attention backend is used.
  • Add unconditional GQA repeat_kv to handle cases where num_heads_q != num_heads_kv.
  • Verify the fix by testing convergence and numerical equivalence with the stock sdpa implementation on a small model.

Example

No code snippet is provided as the issue does not contain a specific code example that can be used to illustrate the fix.

Notes

The proposed fix is not made redundant by pytorch-side fixes, as it addresses specific cases where attn_implementation="sdpa" fails, such as head_dim > 256 and input layouts that disqualify FLASH.

Recommendation

Apply the proposed workaround by adding the new attn_implementation="sdpa_memeff" and registering it in ALL_ATTENTION_FUNCTIONS, as it provides a reliable way to pin the SDPA dispatcher to the efficient attention backend.
