pytorch - ✅(Solved) Fix [RFC] Optimize persistent reduction with `recompute from cached inputs` [1 pull request, 2 comments, 3 participants]

GitHub stats
pytorch/pytorch#179711 · Fetched 2026-04-09 07:50:21
Comments: 2 · Participants: 3 · Timeline: 6 · Reactions: 0
Timeline (top): commented ×2, mentioned ×2, subscribed ×2

PR fix notes

PR #179941: [inductor] Add reduction loop peeling for Triton codegen

Description (problem / solution / changelog)

Add loop peeling optimization for non-persistent reductions that splits the reduction loop into an unmasked main loop (vectorizable) and a masked tail loop, gated behind config.triton.loop_peeling (default off).
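The split described above can be illustrated outside Triton. The following is a minimal pure-Python sketch of loop peeling on a 1-D sum, not the PR's generated code; the function names `masked_block_sum` and `peeled_sum` are hypothetical and exist only for this illustration.

```python
def masked_block_sum(row, start, block):
    """Sum one block with a per-element bounds check (the 'masked' form)."""
    total = 0.0
    for i in range(start, start + block):
        if i < len(row):  # mask: positions past the end contribute nothing
            total += row[i]
    return total


def peeled_sum(row, block):
    """Reduce `row` with an unmasked main loop plus a masked tail."""
    n = len(row)
    main_end = (n // block) * block  # end of the region covered by full blocks
    total = 0.0
    # Main loop: every block is full, so no bounds check is needed.
    for start in range(0, main_end, block):
        for i in range(start, start + block):
            total += row[i]
    # Tail: at most one partial block, which keeps the mask.
    if main_end < n:
        total += masked_block_sum(row, main_end, block)
    return total
```

In a real kernel the benefit comes from the main loop carrying no mask, which is what lets the compiler vectorize its loads; this sketch only shows that the two-loop form computes the same result.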

During the single masked codegen pass, load() and reduction() register masked_line→unmasked_line mappings in a dictionary. The peeled loop emitter uses this map to derive unmasked lines without a second codegen pass, making the optimization zero-cost at compile time.
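To make the single-pass idea concrete, here is a toy sketch of an emitter that records a masked-to-unmasked mapping while emitting masked lines, then derives the main-loop body from that map. `ToyEmitter`, its methods, and the emitted strings are all hypothetical stand-ins; they are not Inductor's actual classes or output.

```python
class ToyEmitter:
    """Hypothetical stand-in for a codegen pass; not Inductor's real API."""

    def __init__(self):
        self.masked_lines = []  # lines emitted by the single masked pass
        self.unmasked_for = {}  # masked_line -> unmasked_line

    def load(self, dst, ptr, mask):
        masked = f"{dst} = tl.load({ptr}, mask={mask}, other=0.0)"
        self.masked_lines.append(masked)
        self.unmasked_for[masked] = f"{dst} = tl.load({ptr})"

    def reduction(self, acc, value, mask):
        masked = f"{acc} = tl.where({mask}, {acc} + {value}, {acc})"
        self.masked_lines.append(masked)
        self.unmasked_for[masked] = f"{acc} = {acc} + {value}"

    def emit_peeled(self):
        """Return (main_body, tail_body) without re-running codegen."""
        # Lines with no registered mapping are reused unchanged.
        main = [self.unmasked_for.get(line, line) for line in self.masked_lines]
        tail = list(self.masked_lines)
        return main, tail


emitter = ToyEmitter()
emitter.load("x", "in_ptr + r", "rmask")
emitter.reduction("acc", "x", "rmask")
main_body, tail_body = emitter.emit_peeled()
```

The design point this illustrates is that the unmasked variant is a dictionary lookup over already-emitted lines, so no second codegen pass is needed.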

Fixes issue https://github.com/pytorch/pytorch/issues/148402. The idea was proposed by @shunting314.

Result: performance of bf16 softmax with shape (32768, 50257) on GB200: (a) 1.42x speedup for static shapes; (b) 1.22x speedup for dynamic shapes (no vectorization, because the hint that the stride is divisible by 16 is missing).

See updated design doc for details.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • test/inductor/test_codegen_triton.py (modified, +100/-0)
  • torch/_inductor/codegen/simd.py (modified, +9/-0)
  • torch/_inductor/codegen/triton.py (modified, +272/-67)
  • torch/_inductor/config.py (modified, +4/-0)


🚀 The feature, motivation and pitch

1. Background: Normalization kernels such as RMS Norm read their inputs twice, once for the reduction and once for the normalization. To avoid the second load, Inductor generates a persistent reduction, which loads the inputs into registers once and caches them as float32 for both computations. Because of register pressure, this is only used when the reduction dimension is small, e.g. <= 1024. A recent blog post showed that the current SOTA is about 6000 GB/s, while the peak on GB200 is about 7928 GB/s.

2. Proposal: Optimize persistent reduction with recompute from cached inputs

2.1 Register cache
(a) Instead of caching the inputs as float32, cache them directly as float16 and convert to float32 only at the time of computation.
(b) Issue multiple load instructions at the same time to ensure enough bytes in flight.
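The "bytes in flight" argument follows from Little's law: sustaining a given bandwidth requires bandwidth × latency bytes to be outstanding at any moment. The sketch below works through that arithmetic; the latency, thread count, and load width passed in the usage lines are illustrative placeholders, not measured GB200 figures, and the function names are hypothetical.

```python
import math


def bytes_in_flight_needed(bandwidth_gb_s, latency_ns):
    """Little's law: outstanding bytes = bandwidth * latency.

    GB/s multiplied by ns gives bytes directly (1e9 * 1e-9 cancels).
    """
    return bandwidth_gb_s * latency_ns


def loads_per_thread_needed(bandwidth_gb_s, latency_ns, threads, bytes_per_load):
    """Concurrent loads each thread must keep outstanding, rounded up."""
    total = bytes_in_flight_needed(bandwidth_gb_s, latency_ns)
    return math.ceil(total / (threads * bytes_per_load))


# Illustrative placeholders only: 500 ns latency, 100000 resident threads,
# 16-byte loads. None of these numbers come from the issue.
needed_bytes = bytes_in_flight_needed(7928, 500)
needed_loads = loads_per_thread_needed(7928, 500, 100000, 16)
```

Under these assumed values a single outstanding load per thread would not cover the latency, which is the motivation for issuing several independent loads before the first use.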

The final generated kernel (using 2 unrolled loads as an example) should look like:

import triton
import triton.language as tl


@triton.jit
def rmsnorm_2tile_kernel(x_ptr, w_ptr, y_ptr, M, stride_x, stride_y,
                         EPS: tl.constexpr, TILE: tl.constexpr, BLOCK_M: tl.constexpr):
    # Each program normalizes BLOCK_M rows; every row is exactly 2 * TILE wide.
    pid = tl.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]  # [BLOCK_M, 1]
    mask = rows < M  # guards rows only; columns are not masked
    offs = tl.arange(0, TILE)[None, :]  # [1, TILE]
    base = x_ptr + rows * stride_x
    # Issue both loads up front; the tiles stay cached in the 16-bit input dtype.
    x0 = tl.load(base + offs, mask=mask, other=0.0)
    x1 = tl.load(base + TILE + offs, mask=mask, other=0.0)
    # Upcast to float32 only at the point of computation.
    ss = x0.to(tl.float32) * x0.to(tl.float32)
    ss += x1.to(tl.float32) * x1.to(tl.float32)
    inv_rms = tl.rsqrt(tl.sum(ss, axis=1)[:, None] / (2 * TILE) + EPS)
    ybase = y_ptr + rows * stride_y
    # Recompute from the cached tiles instead of reloading x.
    w0 = tl.load(w_ptr + offs)
    tl.store(ybase + offs, (x0.to(tl.float32) * inv_rms * w0).to(tl.bfloat16), mask=mask)
    w1 = tl.load(w_ptr + TILE + offs)
    tl.store(ybase + TILE + offs, (x1.to(tl.float32) * inv_rms * w1).to(tl.bfloat16), mask=mask)
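For checking a kernel like the one above on small inputs, a CPU-side mirror of its per-row math can be handy. The following is a minimal pure-Python sketch, assuming a row of exactly 2 * tile elements; it works in Python floats and does not emulate 16-bit rounding, and both function names are hypothetical.

```python
import math


def rmsnorm_row_2tile(x, w, tile, eps):
    """Mirror the two-tile kernel's math for one row of length 2 * tile."""
    assert len(x) == 2 * tile and len(w) == 2 * tile
    x0, x1 = x[:tile], x[tile:]  # the two cached tiles
    ss = [a * a for a in x0]  # per-lane squares of tile 0
    ss = [s + b * b for s, b in zip(ss, x1)]  # accumulate tile 1 lane-wise
    inv_rms = 1.0 / math.sqrt(sum(ss) / (2 * tile) + eps)
    # Reuse the cached tiles for the normalization; x is not read again.
    y0 = [a * inv_rms * wi for a, wi in zip(x0, w[:tile])]
    y1 = [b * inv_rms * wi for b, wi in zip(x1, w[tile:])]
    return y0 + y1


def rmsnorm_row_reference(x, w, eps):
    """Textbook RMSNorm over the whole row, for comparison."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * wi for v, wi in zip(x, w)]
```

The two functions differ only in how the sum of squares is accumulated, so they should agree up to floating-point rounding.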

2.2 Shared memory cache with TMA
The register cache still suffers from register pressure when the hidden size is larger than 8192. In that case shared memory can be used instead, which can also benefit from cluster reduction, where multiple SMs cooperate on one reduction row.
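A rough footprint estimate shows why the cache dtype and the hidden size matter. The sketch below is back-of-the-envelope arithmetic under stated assumptions: 32-bit registers, a hypothetical 256 threads cooperating on one row, and only the cached inputs counted (accumulators, pointers, and temporaries are ignored). `cluster_sum_of_squares` illustrates the split-and-combine shape of a cluster reduction on the CPU; none of these helpers come from the issue.

```python
def cached_row_bytes(n, bytes_per_elem):
    """Bytes needed to keep one row of n elements resident."""
    return n * bytes_per_elem


def regs_per_thread(n, bytes_per_elem, threads_per_row):
    """32-bit registers each thread needs for its share of one cached row."""
    total_bytes = cached_row_bytes(n, bytes_per_elem)
    return -(-total_bytes // (threads_per_row * 4))  # ceiling division


def cluster_sum_of_squares(x, num_parts):
    """Each part reduces its own slice; the partial sums are then combined."""
    chunk = max(1, -(-len(x) // num_parts))
    partials = [sum(v * v for v in x[i:i + chunk]) for i in range(0, len(x), chunk)]
    return sum(partials)


# Assumed 256 threads per row; hidden size 8192.
fp32_regs = regs_per_thread(8192, 4, 256)
fp16_regs = regs_per_thread(8192, 2, 256)
```

Under these assumptions the 16-bit cache halves the per-thread register count for the cached row, and splitting a row across several parts divides the per-part footprint accordingly.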

3. Preliminary Results: SOL% = achieved bandwidth / 7928 GB/s peak. Measured with torch.profiler (CUPTI) and L2 cache clearing. The baseline is torch.compile(rms_norm_ref, mode="max-autotune-no-cudagraphs", fullgraph=True).

3.1 Register cache (best gains at common LLM hidden dims):

| Shape (M × N) | Baseline | Reg-tiled | Improvement |
|---|---|---|---|
| 8192 × 8192 | 71.8% | 94.2% | +22.4pp |
| 16384 × 4096 | 80.8% | 97.4% | +16.6pp |
| 32768 × 2048 | 86.4% | 97.3% | +10.9pp |

Reg-tiled dominates for N = 512–8192. The compiled baseline wins at N >= 32768, where the register-tiled kernel spills registers.

3.2 Shared memory cache with TMA (constant ~22 registers, no performance cliff at large N):

| Shape (M × N) | Baseline | TMA-smem | Improvement |
|---|---|---|---|
| 32768 × 4096 | 63.8% | 89.1% | +25.3pp |
| 16384 × 8192 | 76.6% | 88.2% | +11.6pp |
| 32768 × 8192 | 80.2% | 90.7% | +10.5pp |
| 32768 × 32768 | 83.3% | 87.3% | +4.0pp |
| 16384 × 16384 | 90.8% | 88.1% | -2.7pp |
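The SOL% figures can be converted to and from bandwidth with the definition given in section 3. The sketch below does that arithmetic; the traffic model in `rmsnorm_bytes_moved` (read x once, write y once, read the weight once, 2 bytes per element) is an assumption made for illustration, since the issue does not state how bytes were counted, and all function names are hypothetical.

```python
PEAK_GB_S = 7928.0  # peak bandwidth quoted in the issue


def sol_percent(achieved_gb_s, peak_gb_s=PEAK_GB_S):
    """SOL% = achieved bandwidth / peak bandwidth."""
    return 100.0 * achieved_gb_s / peak_gb_s


def achieved_gb_s(sol_pct, peak_gb_s=PEAK_GB_S):
    """Invert sol_percent to recover bandwidth in GB/s."""
    return sol_pct / 100.0 * peak_gb_s


def rmsnorm_bytes_moved(m, n, bytes_per_elem=2):
    """Assumed traffic: read x once, write y once, read the weight once."""
    return (2 * m * n + n) * bytes_per_elem


def bandwidth_from_time(m, n, seconds, bytes_per_elem=2):
    """Achieved GB/s for an (m, n) RMSNorm that took `seconds` to run."""
    return rmsnorm_bytes_moved(m, n, bytes_per_elem) / seconds / 1e9


# The ~6000 GB/s SOTA figure from the background section, as SOL%.
sota_sol = sol_percent(6000.0)
```

With this definition the roughly 6000 GB/s figure quoted in the background corresponds to about three quarters of the 7928 GB/s peak.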

Hi @shunting314 and @eellison, what do you think about this proposal?

extent analysis

TL;DR

In the preliminary results, a float16 register cache gains up to +22.4pp SOL and a shared memory cache with TMA gains up to +25.3pp over the compiled baseline for normalization kernels, with one regression (-2.7pp at 16384 × 16384).

Guidance

  • To optimize persistent reduction, consider caching inputs as float16 instead of float32 to reduce register pressure.
  • Issue multiple load instructions simultaneously to ensure sufficient bytes-in-flight and improve performance.
  • For larger hidden sizes, utilize shared memory cache with TMA to mitigate register spilling and maintain performance.
  • Evaluate the proposed optimizations using torch.profiler and compare the results to the baseline to measure improvements.

Example

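# Fragment of the kernel body: keep the cached tiles in their 16-bit dtype
x0 = tl.load(base + offs, mask=mask, other=0.0)  # no cast after the load
x1 = tl.load(base + TILE + offs, mask=mask, other=0.0)
# Upcast to float32 only where the value is used
ss = x0.to(tl.float32) * x0.to(tl.float32)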

Notes

The proposed optimizations may have varying effects depending on the specific use case and hardware. Further testing and evaluation are necessary to determine the best approach.

Recommendation

Apply the proposed register cache and shared memory cache with TMA optimizations to improve the performance of normalization kernels, as they have shown significant improvements in preliminary results.
