pytorch - ✅(Solved) Fix [inductor] Backward pass 9% slower in 2.11 vs 2.9.1 due to over-fusion of rms_norm_backward [1 pull requests, 5 comments, 3 participants]

GitHub stats
pytorch/pytorch#179423, fetched 2026-04-08 02:51:48

The same transformer model compiled with torch.compile(model, dynamic=False, fullgraph=True) has a backward pass that is ~9% slower on PyTorch 2.11 compared to 2.9.1, despite identical forward pass performance.

PR fix notes

PR #179494: [inductor] Fix mix_order_reduction over-fusion via load count check

Description (problem / solution / changelog)

Fixes https://github.com/pytorch/pytorch/issues/179423

Problem

FusedMixOrderReductions.sub_node_can_fuse() absorbs additional nodes into mixed-order reduction kernels without checking the resulting load count. This creates Triton kernels with excessive tl.load() calls in the RSPLIT loop, causing register spills and a +6.3ms/step regression on H100.

Model

The regression was found training a small transformer for the Parameter Golf competition. The exact model:

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def forward(self, x):
        return F.rms_norm(x, (x.size(-1),))

class MLP(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        hidden = int(dim * mult)
        self.fc = nn.Linear(dim, hidden, bias=False)
        self.proj = nn.Linear(hidden, dim, bias=False)
    def forward(self, x):
        return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square())

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, num_kv_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        self.c_q = nn.Linear(dim, dim, bias=False)
        self.c_k = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)
        self.c_v = nn.Linear(dim, num_kv_heads * self.head_dim, bias=False)
        self.proj = nn.Linear(dim, dim, bias=False)
    def forward(self, x):
        B, T, D = x.shape
        q = self.c_q(x).reshape(B, T, self.num_heads, self.head_dim)
        k = self.c_k(x).reshape(B, T, self.num_kv_heads, self.head_dim)
        v = self.c_v(x).reshape(B, T, self.num_kv_heads, self.head_dim)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        q = q.transpose(1, 2)
        k = k.transpose(1, 2).repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        v = v.transpose(1, 2).repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.proj(y.transpose(1, 2).reshape(B, T, D))

class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = Attention(dim)
        self.mlp = MLP(dim)
    def forward(self, x):
        x = x + self.attn(self.attn_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x

class Model(nn.Module):
    def __init__(self, vocab_size=4096, dim=512, num_layers=11):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.blocks = nn.ModuleList([Block(dim) for _ in range(num_layers)])
        self.norm = RMSNorm()
        self.head = nn.Linear(dim, vocab_size, bias=False)
    def forward(self, x, y):
        h = self.tok_emb(x)
        h = F.rms_norm(h, (h.size(-1),))
        for block in self.blocks:
            h = block(h)
        h = self.norm(h)
        logits = self.head(h)
        return F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

model = Model().cuda().bfloat16()
compiled = torch.compile(model, dynamic=False, fullgraph=True)
x = torch.randint(0, 4096, (32, 2048), device='cuda')
y = torch.randint(0, 4096, (32, 2048), device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    compiled(x, y).backward()

Key properties: dim=512, 11 transformer blocks, GQA attention with QK-norm, squared leaky-relu MLP, bf16 autocast, fullgraph=True.

Root Cause

During the backward pass, each block produces:

  • rms_norm backward: inner reduction over ncol=512, xnumel=98304 (batch×seq = 48×2048)
  • weight gradient sums: outer reduction over xnumel=98304, keeping ncol=512

mix_order_reduction fuses these two reductions (different iteration orders, same data) into a single kernel. Then sub_node_can_fuse absorbs surrounding pointwise ops (residual connections, dtype casts, scaling) without any check on the resulting read count.
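To make "different iteration orders, same data" concrete, here is a minimal pure-Python sketch, not Inductor code: a single pass over a small 2-D array produces both a per-row (inner) reduction and a per-column (outer) reduction. The function name `mixed_order_sums` and the toy data are illustrative assumptions.

```python
def mixed_order_sums(rows):
    """Return (row_sums, col_sums) computed in one pass over `rows`."""
    ncol = len(rows[0])
    row_sums = []             # inner reduction: one value per row
    col_sums = [0.0] * ncol   # outer reduction: accumulators kept across rows
    for row in rows:
        row_sums.append(sum(row))
        for j, value in enumerate(row):
            col_sums[j] += value
    return row_sums, col_sums


data = [[1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0]]
row_sums, col_sums = mixed_order_sums(data)
print(row_sums)  # [6.0, 15.0]
print(col_sums)  # [5.0, 7.0, 9.0]
```

In this sketch `col_sums` plays the role of the persistent accumulators in the fused kernel: it has to stay live for the whole loop, so every extra value loaded per iteration competes with it for registers.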

The fused kernel uses persistent reduction with R0_BLOCK = ncol = 512 threads per block and an RSPLIT loop that iterates over chunks of the x-dimension. On H100 (65536 regs/SM), 512 threads/block gives 128 regs/thread. That is the register budget.

Each external read buffer becomes a tl.load() inside the RSPLIT loop. Every additional load adds register pressure. The unfused kernel (7 reads) barely fits in 128 regs. The over-fused kernel (11+ reads, plus persistent accumulator arrays) overflows and spills to local memory.

The spill penalty (~100 cycles per access vs 0 for register) is paid every RSPLIT loop iteration (64 iterations per block, 1536 blocks total), producing the 6.3ms regression.
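The numbers quoted above can be checked with a few lines of arithmetic. This is only a worked calculation from the figures in the PR text. It assumes a single 512-thread block occupies the SM's register file, and that each RSPLIT iteration covers one row (an assumption about XBLOCK that the PR does not state).

```python
REGS_PER_SM = 65536        # H100 register file per SM, as quoted above
THREADS_PER_BLOCK = 512    # R0_BLOCK = ncol
ITERS_PER_BLOCK = 64       # RSPLIT loop iterations per block
NUM_BLOCKS = 1536          # total thread blocks

# Register budget per thread when one block owns the register file.
regs_per_thread = REGS_PER_SM // THREADS_PER_BLOCK

# Total loop iterations across the grid; matches xnumel if one row per iteration.
rows_covered = ITERS_PER_BLOCK * NUM_BLOCKS

print(regs_per_thread)  # 128
print(rows_covered)     # 98304
```

The second result equals the xnumel=98304 reported for the reduction, so any spill inside the loop body is paid once per row of the input.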

Kernel comparison

2.9.1 — unfused rms_norm backward (kernel_8, Grid1D, 7 loads, 1 reduction):

# No loop, no accumulators, no workspace — one thread block per row
def triton_per_fused__fused_rms_norm_backward_8(
    in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5,
    out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK):
    # xnumel=98304, r0_numel=512, persistent R0_BLOCK=512
    # 7 loads → ~120 regs/thread, fits in 128 budget
    tmp0 = tl.load(in_ptr0 + ...)   # [98304, 512] rsqrt * Hessian
    tmp1 = tl.load(in_out_ptr0 + ...)  # [98304, 512] upstream grad
    tmp9 = tl.load(in_ptr1 + ...)   # [98304, 512] residual
    tmp10 = tl.load(in_ptr2 + ...)  # scalar: mix coefficient
    tmp20 = tl.load(in_ptr3 + ...)  # [98304, 1] rsqrt
    tmp25 = tl.load(in_ptr4 + ...)  # [512] norm weight 1
    tmp28 = tl.load(in_ptr5 + ...)  # [512] norm weight 2
    # ... 1 inner reduction (sum over 512), 3 stores

2.11 — over-fused rms_norm backward + weight grad sums (kernel_3, MixOrderReductionGrid, 11 loads, 3 reductions):

# RSPLIT loop with persistent accumulators and workspace memory
def triton_per_fused__fused_rms_norm_backward__to_copy_..._3(
    in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4,
    in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9,
    out_ptr0, out_ptr1, out_ptr3, out_ptr4, ws_ptr,
    xnumel, r0_numel, XBLOCK, RSPLIT_SIZE, NUM_STAGES):
    # 11 loads + 2 accumulators → ~200 regs/thread, EXCEEDS 128 budget
    accum0 = tl.full([R0_BLOCK], 0, tl.float32)  # [512] persists across loop
    accum1 = tl.full([R0_BLOCK], 0, tl.float32)  # [512] persists across loop
    for _ in tl.range(0, split_size, XBLOCK):
        tmp0 = tl.load(in_ptr0 + ...)     # [98304, 512]
        tmp1 = tl.load(in_ptr1 + ...)     # [98304, 512]
        tmp7 = tl.load(in_ptr2 + ...)     # [98304, 512]
        tmp13 = tl.load(in_ptr3 + ...)    # [98304, 512]
        tmp14 = tl.load(in_out_ptr0 + ...)# [98304, 512]
        tmp22 = tl.load(in_ptr4 + ...)    # scalar
        tmp32 = tl.load(in_ptr5 + ...)    # [98304, 1]
        tmp37 = tl.load(in_ptr6 + ...)    # [512]
        tmp40 = tl.load(in_ptr7 + ...)    # [512]
        tmp43 = tl.load(in_ptr8 + ...)    # [98304, 512]
        tmp46 = tl.load(in_ptr9 + ...)    # [98304, 512]
        # 3 inner reductions + 5 stores + 2 accumulator updates per iter
        # Spilled regs hit local memory EVERY iteration
    tl.store(ws_ptr + ..., accum0, ...)  # workspace for inter-block reduction
    tl.store(ws_ptr + ..., accum1, ...)

This same over-fusion pattern repeats across the backward pass, producing 9 MixOrderReductionGrid kernels with 6-19 loads each. The worst (kernel_34) has 19 loads.
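One way to audit a build for this pattern is to count `tl.load(` calls per kernel in the generated code (for example, text captured with `TORCH_LOGS=output_code`). The helper below is a hypothetical sketch, not part of PyTorch: it assumes kernels appear as unindented `def triton_...(` definitions and it ignores text after a `#` comment marker.

```python
import re


def count_loads_per_kernel(source: str) -> dict[str, int]:
    """Map each unindented `def triton_*` function in `source` to its tl.load( count."""
    counts: dict[str, int] = {}
    current = None
    for line in source.splitlines():
        match = re.match(r"def (triton_\w+)\(", line)
        if match:
            current = match.group(1)
            counts[current] = 0
        elif current is not None:
            code = line.split("#", 1)[0]  # drop trailing comments
            counts[current] += code.count("tl.load(")
    return counts


sample = '''
def triton_per_small(in_ptr0, in_ptr1, out_ptr0):
    a = tl.load(in_ptr0 + x0)
    b = tl.load(in_ptr1 + x0)  # a comment mentioning tl.load( is ignored
    tl.store(out_ptr0 + x0, a + b)

def triton_per_large(in_ptr0, in_ptr1, in_ptr2, out_ptr0):
    a = tl.load(in_ptr0 + x0)
    b = tl.load(in_ptr1 + x0)
    c = tl.load(in_ptr2 + x0)
    tl.store(out_ptr0 + x0, a + b + c)
'''
counts = count_loads_per_kernel(sample)
print(counts)  # {'triton_per_small': 2, 'triton_per_large': 3}
```

Sorting the resulting dictionary by value gives a quick list of the kernels most likely to exceed the register budget. The number of keys is also the Triton kernel count, which can be compared between PyTorch versions.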

Profiler data

H100 80GB SXM, torch.profiler:

Config | Triton kernels | Self CUDA Time | Delta
2.11, mix_order=1 (default) | 65 | 105.764ms | +6.3ms
2.11, mix_order=0 | 71 | 99.471ms | baseline
2.9.1 (no mix_order) | 71 | 99.5ms | baseline
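The PR text does not say how the `mix_order=0` configuration was selected. A minimal sketch, assuming Inductor reads an environment variable named `TORCHINDUCTOR_MIX_ORDER_REDUCTION` when it is imported (verify the name against torch/_inductor/config.py for your version before relying on it):

```python
import os

# Assumption: Inductor reads this variable at import time to set its
# mix-order-reduction flag; the name is not confirmed by the PR text.
os.environ["TORCHINDUCTOR_MIX_ORDER_REDUCTION"] = "0"

# Import torch only after the variable is set, then compile and profile as usual.
# import torch
```

If the variable is honored, the Triton kernel count for this model should return to 71, which is a quick way to confirm the setting took effect.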

Fix

Count unique read buffers across all subnodes in FusedMixOrderReductions.can_fuse_with. If the count exceeds mix_order_reduction_max_reads (default 10), reject the fusion:

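# Sketch: `subnodes` stands for every node that would end up in the fused kernel.
all_reads = {dep.name for n in subnodes for dep in n.read_writes.reads if isinstance(dep, MemoryDep)}
if len(all_reads) > max_reads:
    return False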

The check uses all_reads rather than all_reads - all_writes because mutated buffers (in_out_ptr) are both read and written, so they still become tl.load() calls. Each unique read maps 1:1 to a tl.load() in the generated RSPLIT loop. The check runs at scheduling time with zero compilation cost: it just counts buffer names from the existing read_writes dependency data.
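The counting logic can be exercised outside the scheduler with stand-in types. The sketch below is self-contained and illustrative only: `MemoryDep` and `OtherDep` here are hypothetical dataclasses rather than the Inductor classes, and `within_read_budget` is an invented name. It does not reproduce the code in the PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MemoryDep:
    """Stand-in for a buffer read dependency (hypothetical, not the Inductor class)."""
    name: str


@dataclass(frozen=True)
class OtherDep:
    """Stand-in for a non-memory dependency that should not count as a load."""
    name: str


def within_read_budget(subnode_reads, max_reads=10):
    """Return False when fusion would need more than `max_reads` unique buffer loads."""
    all_reads = {
        dep.name
        for reads in subnode_reads
        for dep in reads
        if isinstance(dep, MemoryDep)
    }
    return len(all_reads) <= max_reads


# One node reading 7 buffers, then two more nodes that push the union to 11.
seven = [[MemoryDep(f"buf{i}") for i in range(7)]]
eleven = seven + [[MemoryDep(f"buf{i}") for i in range(5, 11)], [OtherDep("sym")]]
print(within_read_budget(seven))   # True
print(within_read_budget(eleven))  # False
```

Because the names go into a set, a buffer read by several subnodes is counted once, which mirrors the "unique read buffers" wording of the fix.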

Test Plan

  • OverFusionTest.test_max_reads_limits_fusion — 3-block transformer backward, verifies correctness and that mix_order_reduction still fires (not fully disabled)
  • Existing MixOrderReductionTest and NoMixOrderReductionTest suites unaffected
  • Verified on H100: step time with this fix matches 2.9.1 / mix_order=0 baseline

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • test/inductor/test_mix_order_reduction.py (modified, +96/-0)
  • torch/_inductor/config.py (modified, +4/-0)
  • torch/_inductor/metrics.py (modified, +3/-0)
  • torch/_inductor/scheduler.py (modified, +15/-0)
Original issue text

Summary

The same transformer model compiled with torch.compile(model, dynamic=False, fullgraph=True) has a backward pass that is ~9% slower on PyTorch 2.11 compared to 2.9.1, despite identical forward pass performance.

Profiler Data

Component | PyTorch 2.9.1 | PyTorch 2.11.0 | Delta
Backward compiled graph | 67.28ms | 73.21ms | +5.93ms (+8.8%)
Forward compiled graph | 34.40ms | 34.47ms | +0.07ms
aten::mm | 33.17ms | 33.13ms | identical
FA3 backward | 20.13ms | 20.11ms | identical

Root Cause: Inductor Over-Fusion

Inductor generates fewer but larger fused Triton kernels in 2.11:

Metric | PyTorch 2.9.1 | PyTorch 2.11.0
Triton kernel functions | 71 | 65
Largest backward kernel | 11,292 lines | 11,855 lines

Key difference: 2.11 fuses _fused_rms_norm_backward into adjacent kernels. 2.9.1 keeps them separate. The larger fused kernels run slower.

Isolation

  • Not Triton: Swapping Triton 3.5.1 into PyTorch 2.11 has no effect
  • Not autocast: Gap persists without autocast
  • Not cuDNN/cuBLAS: Forcing backends has no effect
  • Forward is identical: Only the backward compiled graph is slower

Environment

  • GPU: NVIDIA H100 80GB HBM3 SXM, Driver 570.148.08, CUDA 12.8
  • Model: 34.4M param transformer, 11 layers, d=512, RMSNorm, depth recurrence, parallel residuals
  • max_fusion_size=64 and aggressive_fusion=False (defaults) - over-fusion happens within these limits

Impact

For time-budgeted training (600s), the regression costs ~57 training steps (~1% of the total).

cc @jerryzh168 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

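The 9% slower backward pass in PyTorch 2.11 comes from mix_order_reduction over-fusing rms_norm backward kernels. PR #179494 fixes it with a load-count check, and the profiler data shows that running with mix_order=0 already restores the 2.9.1 baseline.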

Guidance

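  • Confirm the diagnosis by comparing backward-pass time with mix_order=1 (the 2.11 default) against mix_order=0; in the PR's profiler data, mix_order=0 matches the 2.9.1 baseline.
  • Inspect the generated backward kernels for MixOrderReductionGrid kernels with many tl.load() calls in the RSPLIT loop; the PR identifies these as the kernels that spill registers.
  • Where the fix from PR #179494 is available, mix_order_reduction_max_reads (default 10) is the setting that bounds how many reads a fused kernel may have.
  • Do not expect max_fusion_size or aggressive_fusion to help: the issue reports that the over-fusion happens within their default limits (max_fusion_size=64, aggressive_fusion=False), and aggressive_fusion=True would allow more fusion, not less.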

Example

The code change is the load-count check from PR #179494, shown in the Fix section above.

Notes

The issue seems to be specific to the combination of PyTorch 2.11, the transformer model, and the GPU architecture. The fix may not apply to other models or environments.

Recommendation

Apply the fix from PR #179494 where it is available. Until then, disabling mix-order reduction (mix_order=0) is the workaround that matches the 2.9.1 baseline in the profiler data; max_fusion_size and aggressive_fusion do not address this over-fusion.
