pytorch: `flex_attention` backward is 10-30× slower than SDPA backward under `torch.compile(backend="aot_eager")` (math fallback materializes O(S²) and runs 114 ATen ops/layer)

pytorch/pytorch#182611 (fetched 2026-05-07 03:31:11)

Under torch.compile(backend="aot_eager"), flex_attention backward lowers to the sdpa_dense_backward math fallback in torch/_higher_order_ops/flex_attention.py. That fallback dispatches 114 ATen ops per layer versus 13 for SDPA (an 8.8× ratio that does not change with seqlen or head_dim) and runs as an unfused computation that materializes the full [B, H, S, S] attention matrix. On GB200, this combination produces a backward bubble (flex bwd minus sdpa bwd, same model shape) of:

  • +109 ms / step at S=4096, head_dim=128, n_layers=16 (DSV3-class bf16 training shape)
  • +217 ms / step at the same shape with n_layers=27 (DSV3 16B layer count)
  • +298 ms / step with n_layers=43 (V4-Flash layer count)
  • +410 ms / step at S=8192, head_dim=128, n_layers=16 (bubble grows superlinearly in S)

The bubble is aot_eager-specific. At the same S=4096, head_dim=128, n_layers=16 shape, inductor lowers the math fallback into a single fused Triton kernel and the bubble drops from +109 ms to +7 ms (flex bwd 117 ms → 12 ms).

Symptom in the wild

pytorch/torchtitan#2089 (filed 2025-11-26 by @ruisizhang123) is the original report from a real DSV3 16B + SimpleFSDP + EP training stack run with --compile.backend "aot_eager". The screenshots in that issue make the bubble visible directly:

  • With FlexAttention: a sparse band of CPU-bound mutation ops (aten.empty_strided, aten.add) during backward and a corresponding GPU-idle gap.
  • Without FlexAttention (same stack, attention removed): the bubble disappears and backward execution is dense.

Versions

  • Verified on torch 2.13.0a0+gitdbeb1a8 (2026-05-05 nightly).
  • Hardware for wall-clock: 1× GB200 sm_100, driver 580.65.06. Dispatch count is hardware-independent.

Minimal repro for the dispatch-count claim (CPU only)

import torch
import torch.nn.functional as F
from torch._higher_order_ops.flex_attention import sdpa_dense_backward
from torch.utils._python_dispatch import TorchDispatchMode

# Counts every ATen op that reaches the Python dispatcher while the mode is active.
class Count(TorchDispatchMode):
    def __init__(self):
        super().__init__()
        self.n = 0
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.n += 1
        return func(*args, **(kwargs or {}))

B, H, S, D = 2, 4, 128, 16
torch.manual_seed(0)
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v)
lse = torch.zeros(B, H, S)          # placeholder logsumexp (only the op count matters here)
go = torch.randn_like(out)          # upstream gradient w.r.t. the attention output
score_mod = lambda s, b, h, m, n: s     # identity score_mod
mask_mod  = lambda b, h, q, kv: q >= 0  # always True, so nothing is masked
# Minimal hand-built block-mask tuple that carries the always-True mask_mod.
block_mask = (
    torch.ones(1,1,1,    dtype=torch.int32),
    torch.ones(1,1,1,1,  dtype=torch.int32),
    torch.ones(1,1,1,    dtype=torch.int32),
    torch.ones(1,1,1,1,  dtype=torch.int32),
    None, None, None, None, S, S, mask_mod,
)
opts = dict(PRESCALE_QK=False, ROWS_GUARANTEED_SAFE=False,
            BLOCKS_ARE_CONTIGUOUS=False, WRITE_DQ=True,
            OUTPUT_LOGSUMEXP=False, OUTPUT_MAX=False)

# Flex side: count only the dense backward fallback.
flex = Count()
with flex:
    sdpa_dense_backward(q, k, v, out, lse, go, None,
                        score_mod, None, block_mask,
                        1.0 / D**0.5, opts, (), ())

# SDPA side: this counts forward, sum, and backward together.
sdpa = Count()
q2, k2, v2 = (t.clone().requires_grad_(True) for t in (q, k, v))
with sdpa:
    F.scaled_dot_product_attention(q2, k2, v2).sum().backward()

print(f"flex_dense_bwd: {flex.n}")
print(f"sdpa_bwd:       {sdpa.n}")
print(f"ratio:          {flex.n/sdpa.n:.1f}x")

Output (deterministic, shape-invariant):

flex_dense_bwd: 114
sdpa_bwd:       13
ratio:          8.8x

Verified at S ∈ {128, 512, 1024, 2048, 4096} × head_dim ∈ {16, 64, 128} — all produce the same 114 / 13.

Wall-clock impact

GB200 sm_100, bf16, B=2, H=8, backend=aot_eager. Wall-clock is measured with torch.cuda.Event over 5 warmup + 20 timed iterations. The table below reports flex fwd and flex bwd as separate columns so the bubble can be read against full step time at each shape.
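The measurement protocol described above can be sketched roughly as follows. This is not the harness used for the numbers in this issue: the helper name `time_ms` is hypothetical, it assumes the reported figures are per-iteration means, and the CPU fallback exists only so the sketch runs on machines without a GPU.

```python
import time
import torch

def time_ms(fn, warmup=5, iters=20):
    """Mean wall-clock per call in milliseconds, after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()          # drain pending work before timing
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()          # wait until `end` has been reached
        return start.elapsed_time(end) / iters  # elapsed_time() returns milliseconds
    # CPU fallback (assumption: only here so the sketch is runnable anywhere).
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) * 1e3 / iters

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(256, 256, device=device)
print(f"matmul: {time_ms(lambda: x @ x):.3f} ms/iter")
```

To time a backward pass with such a helper, `fn` would need to run the forward and backward of the compiled model and the forward time would be measured separately and subtracted, or the backward wrapped on its own.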

At a production-typical shape (S=4096, head_dim=128, H=8)

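| n_layers | flex fwd | flex bwd | sdpa bwd | bubble (flex_bwd − sdpa_bwd) | bubble / flex step |
| --- | --- | --- | --- | --- | --- |
| 16 | 89 ms | 116 ms | 6 ms | +109 ms | 53% |
| 27 (DSV3 16B) | 147 ms | 227 ms | 11 ms | +217 ms | 58% |
| 43 (V4-Flash) | 232 ms | 316 ms | 18 ms | +298 ms | 54% |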
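The last two columns of this table can be recomputed from the raw timings: the bubble is flex bwd minus sdpa bwd, and the share column divides the bubble by flex fwd + flex bwd. The sketch below is illustrative only; the values are copied from the table, and the names `rows` and `bubble_share` are hypothetical. A recomputed bubble can differ from the listed one by about 1 ms, presumably because the published timings are rounded.

```python
# (n_layers, flex_fwd_ms, flex_bwd_ms, sdpa_bwd_ms, listed_bubble_ms), copied from the table.
rows = [
    (16, 89, 116, 6, 109),
    (27, 147, 227, 11, 217),
    (43, 232, 316, 18, 298),
]

def bubble_share(flex_fwd_ms, flex_bwd_ms, bubble_ms):
    """Bubble as a fraction of flex_fwd + flex_bwd."""
    return bubble_ms / (flex_fwd_ms + flex_bwd_ms)

for n_layers, fwd, bwd, sdpa_bwd, listed in rows:
    recomputed = bwd - sdpa_bwd  # from rounded timings, so it may be ~1 ms off the listed value
    share = bubble_share(fwd, bwd, listed)
    print(f"n_layers={n_layers}: bubble listed {listed} ms, "
          f"recomputed {recomputed} ms, share {share:.0%}")
```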

"bubble / flex step" is bubble divided by flex_fwd + flex_bwd. Roughly half of flex_attention's entire training-step time at these shapes is the unnecessary bwd overhead vs. SDPA. Production training adds MLP / norm / projection work on top; the bubble's share of total training step time is < 60% (likely 10-30% for a complete transformer step, depending on the rest of the layer).

Shape sensitivity at fixed n_layers=16, aot_eager

| Shape | flex bwd | sdpa bwd | bubble |
| --- | --- | --- | --- |
| S=1024 H=8 hd=64 | 52 ms | 6 ms | +46 ms |
| S=4096 H=8 hd=64 | 113 ms | 7 ms | +107 ms |
| S=4096 H=8 hd=128 | 117 ms | 8 ms | +109 ms |
| S=8192 H=8 hd=128 | 423 ms | 14 ms | +410 ms |

The bubble grows superlinearly in S — going from S=4096 to S=8192 multiplies it by ~3.8× at the same n_layers. This is not shape-invariant CPU dispatch overhead alone; the math fallback's GPU work also scales as O(S²) and is unfused. (See "Why is the bubble so big?" below.)
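One way to read the scaling out of the shape table is to fit an exponent p in bubble ∝ S^p between adjacent rows. The sketch below does that with two-point estimates, which is a crude approximation rather than a proper fit; the helper name `scaling_exponent` is hypothetical, and the two pairs use different head_dim values (the table suggests head_dim has little effect at S=4096).

```python
import math

# (S, bubble_ms) pairs copied from the shape-sensitivity table (n_layers=16, H=8).
hd64 = [(1024, 46), (4096, 107)]
hd128 = [(4096, 109), (8192, 410)]

def scaling_exponent(points):
    """Exponent p in bubble ~ S**p implied by two (S, bubble_ms) measurements."""
    (s0, b0), (s1, b1) = points
    return math.log(b1 / b0) / math.log(s1 / s0)

print(f"S=1024 -> 4096 (hd=64):  p = {scaling_exponent(hd64):.2f}")
print(f"S=4096 -> 8192 (hd=128): p = {scaling_exponent(hd128):.2f}")
```

The first exponent comes out well below 1 and the second close to 2, which fits a roughly constant dispatch floor at small S giving way to O(S²) GPU work at large S.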

Inductor reference (S=4096, H=8, head_dim=128, n_layers=16)

| backend | flex bwd | sdpa bwd | bubble |
| --- | --- | --- | --- |
| aot_eager | 117 ms | 6 ms | +109 ms |
| inductor | 12 ms | 5 ms | +7 ms |

Inductor lowers the math fallback into a single fused Triton kernel and beats aot_eager by roughly 10× on flex bwd (117 ms → 12 ms). The bubble's compute structure is solvable; it just isn't being solved on the aot_eager path.

Why is the bubble so big?

Two compounding causes:

  1. CPU dispatch overhead. sdpa_dense_backward runs 114 ATen ops/layer (vs SDPA's 13). At small / debugging shapes each op's GPU work is small and CPU dispatch latency dominates. Shape-invariant — sets a floor on the bubble.

  2. Unfused O(S²) GPU work at production shapes. sdpa_dense_backward does five 4-D matmuls per layer (1 in _math_attention_inner's forward recompute, 4 in the bwd computing grad_value / grad_softmax_scores / grad_query / grad_key). Each materializes the full [B, H, S, S] attention matrix; then a separate softmax / _vmap_for_bhqkv-vectorized score_mod / mask_mod pass runs over it. SDPA's bwd is one fused FlashAttention-style kernel that never materializes [B, H, S, S]. This is what scales the bubble superlinearly in S.

Cause (1) dominates at S ≤ 1024; cause (2) dominates at S ≥ 4096. Both contribute at every shape.
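To give a sense of scale for cause (2), the size of one dense [B, H, S, S] tensor at the benchmark shapes can be estimated directly. This is a back-of-the-envelope sketch: the helper `attn_matrix_bytes` is hypothetical, it assumes 2 bytes per element (bf16), and the fallback may hold several such tensors at once or upcast some of them, so real peak memory is likely higher.

```python
def attn_matrix_bytes(B, H, S, bytes_per_elem=2):
    """Bytes needed for one dense [B, H, S, S] tensor (bf16 assumed: 2 bytes/element)."""
    return B * H * S * S * bytes_per_elem

# B=2, H=8 as in the wall-clock benchmarks above.
for S in (1024, 4096, 8192):
    mib = attn_matrix_bytes(2, 8, S) / 2**20
    print(f"S={S}: {mib:,.0f} MiB per [B, H, S, S] tensor")
```

Doubling S quadruples this figure, and every unfused elementwise pass over the tensor reads and writes all of it.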

For (1), each operator-form @ on 4-D tensors decomposes (in C++ via aten.matmul) into:

expand × 2  +  view × 2  +  transpose × 1  +  bmm × 1  +  _unsafe_view × 1   = 7 dispatches

A direct (B, H, M, K) → (B*H, M, K) → bmm → (B, H, M, N) rewrite is 5 dispatches (no expand, no _unsafe_view; one extra reshape). Net: 2 dispatches saved per matmul × 5 matmul sites = 10 dispatches/layer.
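The rewrite described above can be sketched as follows. This is a minimal illustration, not the patch itself: `bmm_4d` and `OpLog` are hypothetical names, the sketch assumes both operands already share the same B and H (no broadcasting), and the exact op lists it prints depend on the PyTorch version.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class OpLog(TorchDispatchMode):
    """Records the name of every ATen op dispatched while the mode is active."""
    def __init__(self):
        super().__init__()
        self.ops = []
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))
        return func(*args, **(kwargs or {}))

def bmm_4d(a, b):
    """(B, H, M, K) @ (B, H, K, N) via one bmm on 3-D reshaped operands."""
    B, H, M, K = a.shape
    N = b.shape[-1]
    out = torch.bmm(a.reshape(B * H, M, K), b.reshape(B * H, K, N))
    return out.reshape(B, H, M, N)

torch.manual_seed(0)
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)

with OpLog() as log_matmul:
    ref = q @ k.transpose(-2, -1)          # operator-form 4-D matmul
with OpLog() as log_bmm:
    alt = bmm_4d(q, k.transpose(-2, -1))   # direct bmm on reshaped operands

print(f"matmul path: {len(log_matmul.ops)} dispatches: {log_matmul.ops}")
print(f"bmm path:    {len(log_bmm.ops)} dispatches: {log_bmm.ops}")
print("results match:", torch.allclose(ref, alt))
```

Both paths end in the same bmm call, so the saving is purely in dispatcher traffic, not in GPU work.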

The other ~104 dispatches/layer (114 minus the 10 that rewrite would save) include the 25 matmul dispatches that remain, plus softmax recompute, _vmap_for_bhqkv broadcasting for score_mod / mask_mod, gradient accumulation aten.add_ / aten.zeros_like / aten.empty_strided, and GQA repeat_interleave / expand. Each is small individually; together they make up the bulk of the 8.8× count.

Possible directions

The two causes need different responses, and only the second matters at production scale:

  • Cause (1) (CPU dispatch overhead) is mechanically reducible — e.g. replacing the five 4-D @ sites in sdpa_dense_backward with direct bmm-on-reshaped-3-D saves 2 dispatches per matmul × 5 sites = 10 dispatches/layer (verified bit-identical at fp32). But that's ~2 ms/step at n_layers=43, < 1% of the production-shape bubble. Worth picking up only as part of a broader rewrite, not in isolation.

  • Cause (2) (unfused O(S²) GPU work) is the headline. Two paths plausibly close it:

    • Skip forward softmax recompute in _math_attention_inner by plumbing attn_lse (and possibly attn_max) from the forward HOP through to the backward, eliminating the recompute pass entirely. Activation-memory cost is substantial at training seqlens (saving full [B, Hq, M, KV] softmax intermediates would add ~100s of GB on a DSV3-16B-class model) — needs a design discussion before any opt-in flag.
    • Fuse the math fallback into a single kernel for non-inductor backends — essentially make the inductor lowering of flex bwd the default. This is what actually eliminates [B, H, S, S] materialization and therefore the O(S²) scaling.

Workaround for affected users today

For training stacks that compile under aot_eager (e.g., torchtitan's experiments/graph_trainer mode), use regional_inductor to compile FlexAttention HOPs into fused Triton kernels while keeping the bulk of the graph under aot_eager:

torch.compile(model, backend="aot_eager")
# + annotate flex_attention HOPs with compile_with_inductor (regional_inductor pass)

torchtitan moved to this default in pytorch/torchtitan#2869 (2026-04-13). At the S=4096, head_dim=128, n_layers=16 shape above, this brings the bubble from +109 ms / step to +7 ms / step.

Cross-references

  • pytorch/torchtitan#2089: the symptom report linked above. It was worked around in torchtitan via #2869 (default regional_inductor for FlexAttention) and #3132; this issue covers the underlying core-PyTorch overhead so that future users of aot_eager + flex_attention don't have to rediscover it.

cc @jerryzh168 @chauhang @penguinwu @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv
