pytorch - Sub-optimal heuristics in MixOrderReduction prevent fusion

Summary

MixOrderReduction.can_fuse() has two heuristic gates that are too conservative, preventing profitable fusion at common shapes:

  1. Split reduction takes unconditional priority: is_split_reduction() rejects fusion even when non_strict_mode=True. The split decision is made in IR lowering before the scheduler can consider mix-order fusion.
  2. Strict checks too conservative: In strict mode, can_fuse() enforces three conditions (total size threshold, nrow >= ncol*2, and nrow >= 4096). For the shapes below, the nrow >= ncol*2 ratio is the failing gate, rejecting fusion where it would be profitable.

Reproduction

python rmsnorm_bwd_mix_order_issue.py -v

Repro script

Results (NVIDIA GB200, bf16 RMSNorm backward)

           Shape  strict (us)  SOL%  non-strict (us)  SOL%  Speedup  no-split (us)  SOL%  Speedup
  ──────────────  ─────────── ─────  ─────────────── ─────  ───────  ───────────── ─────  ───────
  [ 4096,  8192]       74.5   34.1%           45.8   55.5%    1.63x       45.9   55.4%    1.62x
  [ 8192,  8192]      167.3   30.4%          167.4   30.3%    1.00x      121.9   41.7%    1.37x
  • strict: default behavior (all strict checks enforced)
  • non-strict: mix_order_reduction_non_strict_mode=True (skips the three strict-only checks: total size, nrow >= ncol*2, nrow >= 4096)
  • no-split: mix_order_reduction_non_strict_mode=True + split_reductions=False (bypasses split reduction entirely)
  • Speedup: strict time divided by the time of the configuration to its left

[4096, 8192]: non-strict fuses successfully. 3 kernels → 2 kernels, 1.63x speedup.

[8192, 8192]: non-strict alone has no effect (1.00x). Disabling split reductions suggests fusion is profitable at this shape: 3 kernels → 2 kernels, 1.37x speedup (30% → 42% SOL).
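The SOL% and Speedup columns can be sanity-checked with a few lines of arithmetic. The sketch below is not taken from the repro script: it assumes SOL% is achieved DRAM bandwidth divided by the 7928 GB/s peak listed under Environment, and that the minimum traffic is three full-size bf16 tensors (read x, read dy, write dx), ignoring the small weight and per-row tensors. The helper names are hypothetical. Under those assumptions it reproduces the table to within about 0.1 percentage point, with the residual plausibly due to rounding of the reported times.

```python
# Hypothetical helpers; the traffic model (3 full-size tensors) is an assumption.
PEAK_GBPS = 7928.0   # peak memory bandwidth from the Environment section
BYTES_PER_ELEM = 2   # bf16


def sol_percent(nrow, ncol, time_us, n_tensors=3):
    """Achieved bandwidth as a percentage of peak, for n_tensors full-size tensors."""
    total_bytes = n_tensors * nrow * ncol * BYTES_PER_ELEM
    achieved_gbps = total_bytes / (time_us * 1e-6) / 1e9
    return 100.0 * achieved_gbps / PEAK_GBPS


def speedup(strict_us, candidate_us):
    """Speedup of a configuration relative to the strict (default) timing."""
    return strict_us / candidate_us


# Timings (us) copied from the results table: strict, non-strict, no-split.
table = {
    (4096, 8192): (74.5, 45.8, 45.9),
    (8192, 8192): (167.3, 167.4, 121.9),
}

for (nrow, ncol), times in table.items():
    strict_us = times[0]
    for label, t in zip(("strict", "non-strict", "no-split"), times):
        print(f"[{nrow}, {ncol}] {label:>10}: "
              f"SOL {sol_percent(nrow, ncol, t):.1f}%  "
              f"speedup {speedup(strict_us, t):.2f}x")
```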

Root Cause

Two sub-optimal heuristics in MixOrderReduction.can_fuse() (scheduler.py):

Heuristic 1: is_split_reduction() veto

if MixOrderReduction.is_split_reduction(contiguous_node):
    return False

Split reduction is decided during IR lowering without consulting mix-order non-strict mode. By the time the scheduler calls can_fuse(), the node is already marked as split and rejected.

Heuristic 2: nrow >= ncol*2 ratio

if not V.graph.sizevars.guard_or_true(sympy.Ge(nrow, ncol * 2)):
    return False

This fixed ratio rejects shapes where fusion is profitable: both tested shapes fail it (4096 < 2*8192 and 8192 < 2*8192). Note: existing tests (test_dimension_too_close) encode the rationale for this ratio, so any relaxation needs broader shape-sweep data to avoid regressions.
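To make the interaction of the two gates concrete, here is a toy model of the decision as this issue describes it. It is a sketch, not the real can_fuse(): the Candidate class, the can_fuse_model name, the order of the strict checks, and the size_threshold parameter (whose real value is not stated here, so it defaults to disabled) are all assumptions. The is_split values for the two shapes are inferred from the benchmark results above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Candidate:
    nrow: int
    ncol: int
    is_split: bool  # decided during IR lowering, before the scheduler runs


def can_fuse_model(c, non_strict=False, size_threshold=0):
    """Toy model of the gates described above. Returns (fused, reason)."""
    # Heuristic 1: the split-reduction veto applies even in non-strict mode.
    if c.is_split:
        return False, "split reduction"
    # Strict-only checks (order assumed; size_threshold is a placeholder).
    if not non_strict:
        if c.nrow * c.ncol < size_threshold:
            return False, "total size"
        # Heuristic 2: fixed row/column ratio.
        if c.nrow < c.ncol * 2:
            return False, "nrow >= ncol*2"
        if c.nrow < 4096:
            return False, "nrow >= 4096"
    return True, "fused"


# is_split inferred from the results: non-strict fuses [4096, 8192] but not
# [8192, 8192], which only fuses once split reductions are disabled.
small = Candidate(4096, 8192, is_split=False)
large = Candidate(8192, 8192, is_split=True)
large_no_split = Candidate(8192, 8192, is_split=False)

print(can_fuse_model(small))                           # strict
print(can_fuse_model(small, non_strict=True))          # non-strict
print(can_fuse_model(large, non_strict=True))          # non-strict
print(can_fuse_model(large_no_split, non_strict=True)) # no-split
```

In this model, non-strict mode cannot help the [8192, 8192] case because the veto runs before the mode is consulted, which matches the 1.00x result in the table.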

Kernel breakdown

Unfused (3 kernels) — reads x and dy from DRAM twice:

  1. Outer reduction for dw partials (reads x, dy)
  2. Final dw sum
  3. Inner reduction for dx (reads x, dy again)

Fused (2 kernels) — reads x and dy once:

  1. Mix-order kernel: inner reduction (dx) + outer reduction partials (dw) in a single pass
  2. Final dw sum
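The kernel breakdown above implies a simple DRAM traffic estimate. The sketch below counts only the full-size bf16 tensors (x, dy, dx) and assumes nothing is served from cache, in line with the L2 flush described in the benchmark methodology; the weight, the dw partials, and any per-row statistics are ignored as small. The dram_bytes helper is hypothetical.

```python
BYTES_PER_ELEM = 2  # bf16


def dram_bytes(nrow, ncol, fused):
    """Estimated DRAM traffic in bytes for RMSNorm backward (large tensors only)."""
    tensor = nrow * ncol * BYTES_PER_ELEM
    # x and dy are read once when fused, twice when the two reductions are separate.
    reads = (2 if fused else 4) * tensor
    writes = tensor  # dx
    return reads + writes


for shape in [(4096, 8192), (8192, 8192)]:
    unfused = dram_bytes(*shape, fused=False)
    fused = dram_bytes(*shape, fused=True)
    print(f"{list(shape)}: unfused {unfused / 2**20:.0f} MiB, "
          f"fused {fused / 2**20:.0f} MiB, ratio {unfused / fused:.2f}x")
```

Under these assumptions the traffic ratio is 5/3, so roughly 1.67x is the most that fusion could gain if both versions ran at the same fraction of peak bandwidth. The measured 1.63x and 1.37x speedups fall below that bound.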

Suggested Fixes

Fix 1: When non-strict mode is enabled, mix-order fusion should take priority over split reduction

Currently is_split_reduction() unconditionally rejects fusion, even when non_strict_mode=True. The [8192, 8192] no-split result suggests a 1.37x opportunity (30% → 42% SOL).

Fix 2: Adjust the strict checks to allow fusion at more shapes

The current nrow >= ncol*2 ratio is too conservative for the shapes tested: [4096, 8192] shows a 1.63x speedup when fused despite failing the strict check. A broader shape sweep is needed to determine the right threshold or a replacement condition.
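One way to organize such a sweep is to flag every shape that the ratio check rejects even though the measured fused time is clearly better. The sketch below is only a triage helper under stated assumptions: the function names and the min_speedup cutoff of 1.1 are hypothetical, and the only measurements included are the two rows from the results table (strict time versus the best fused time). A real sweep would fill the dictionary from benchmark runs.

```python
def passes_ratio(nrow, ncol, factor=2):
    """The strict ratio gate from can_fuse(): nrow >= ncol * factor."""
    return nrow >= ncol * factor


def rejected_but_profitable(measurements, min_speedup=1.1):
    """Shapes that fail the ratio gate yet speed up by min_speedup or more when fused.

    measurements maps (nrow, ncol) -> (unfused_us, fused_us).
    """
    flagged = []
    for (nrow, ncol), (unfused_us, fused_us) in measurements.items():
        gain = unfused_us / fused_us
        if not passes_ratio(nrow, ncol) and gain >= min_speedup:
            flagged.append(((nrow, ncol), round(gain, 2)))
    return flagged


# Only the two data points from this issue; fused_us is the best fused timing.
measurements = {
    (4096, 8192): (74.5, 45.8),    # strict vs non-strict
    (8192, 8192): (167.3, 121.9),  # strict vs no-split
}

for shape, gain in rejected_but_profitable(measurements):
    print(f"{list(shape)} fails nrow >= ncol*2 but fuses at {gain}x")
```

Shapes that pass the gate but regress when fused (the case test_dimension_too_close guards against) would need the mirror-image check before any threshold is changed.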

Note on benchmark methodology

The benchmark reports the sum of CUDA kernel durations (measured with the torch profiler), on the assumption that kernel launch overhead is negligible in real applications that use CUDA graphs. The L2 cache is flushed between iterations, on the assumption that in practice it is polluted by neighboring ops.

Environment

  • PyTorch: 2.13.0a0+git9661ae6 (commit 9661ae6e5416b89d274c2a348889264ac5816d68)
  • Triton: 3.7.0
  • CUDA: 13.2
  • GPU: NVIDIA GB200
  • Peak memory bandwidth: 7928 GB/s

Created with Claude

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo
