[inductor] Scheduler phases scale poorly on lowered aten.lstm/aten.gru graphs

Problem

When torch._dynamo.config.allow_rnn = True is set, dynamo traces through nn.LSTM / nn.GRU successfully and produces a small FX graph (12–32 nodes). Inductor then decomposes the single aten.lstm op into a much larger lowered graph. The scheduler phases operating on this expanded graph scale poorly: compile_fx takes about 7s for a 2-layer LSTM and 80s+ for 6-layer models, which makes cold compilation slow for common RNN configurations.

This is a one-time compilation cost (cached runs are fast), but it's high enough that users hitting #158007 who try allow_rnn=True as suggested by the error message will likely give up thinking it's hung.

Measured data

All measurements from isolated subprocesses with TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 on PyTorch 2.13.0a0+git795be92, NVIDIA H200.
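One way to get isolated measurements of this kind is to launch each configuration in a fresh Python process with the cache-disabling variable set in its environment. The following is a minimal sketch, not the harness used for the numbers in this issue: `run_isolated` and its parameters are hypothetical names, and only the environment variable comes from the issue.

```python
import os
import subprocess
import sys
import time


def run_isolated(script, extra_env=None):
    """Run `script` in a fresh interpreter and return (stdout, wall seconds).

    Hypothetical helper: a new process per measurement means no in-memory
    compile state carries over from a previously measured configuration.
    """
    env = dict(os.environ)
    # Variable taken from the issue; it disables inductor caches.
    env["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"
    env.update(extra_env or {})
    t0 = time.perf_counter()
    proc = subprocess.run(
        [sys.executable, "-c", script],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return proc.stdout, time.perf_counter() - t0


if __name__ == "__main__":
    # Stand-in payload; swap in the reproduction script to time a real compile.
    out, seconds = run_isolated(
        "import os; print(os.environ['TORCHINDUCTOR_FORCE_DISABLE_CACHES'])"
    )
    print(out.strip(), f"{seconds:.2f}s")
```

The wall time returned here includes interpreter startup and imports, so the per-phase numbers from compile_times() remain the better source for a breakdown.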

Lowered graph expansion — inductor decomposes the aten.lstm op into many buffers:

| Config | FX nodes (dynamo output) | Lowered buffers (inductor) | Scheduler nodes |
| --- | --- | --- | --- |
| LSTM 1L, h=5 | 12 | 15 | 10 |
| LSTM 1L, h=64 | 12 | 67 | 36 |
| LSTM 2L, h=64 | 16 | 134 | 72 |
| GRU 1L, h=64 | 10 | 67 | 36 |

12 FX nodes become 67 lowered buffers for a 1-layer LSTM with h=64. Adding a second layer doubles this to 134.

Where the time goes — phase breakdown from PyTorch's own compile_times():

| Phase | LSTM 1L h=5 | LSTM 1L h=64 | LSTM 2L h=64 | GRU 1L h=64 |
| --- | --- | --- | --- | --- |
| GraphLowering.run | 0.14s | 0.71s | 0.89s | 0.49s |
| Scheduler.__init__ | 0.27s | 0.74s | 1.44s | 0.71s |
| Scheduler.fused_nodes | 0.03s | 0.43s | 0.93s | 0.47s |
| Scheduler.codegen | 0.18s | 0.74s | 1.40s | 0.64s |
| joint_graph_passes.pass_pattern_0 | 0.47s | 0.49s | 1.06s | 0.62s |
| compile_fx total | 2.56s | 4.30s | 6.81s | 4.39s |

The dominant cost is the scheduler: Scheduler.__init__ + Scheduler.fused_nodes + Scheduler.codegen together account for ~55% of compile_fx time for the 2-layer LSTM (3.77s of 6.81s) and roughly double when the buffer count doubles (1L→2L, 1.91s to 3.77s).
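The scheduler share can be re-derived from the phase table. The sketch below only re-adds the reported numbers; the dictionary layout and the `scheduler_seconds` helper are illustrative names, not part of PyTorch.

```python
# Phase timings in seconds, copied from the compile_times() table above.
phases = {
    "LSTM 1L h=5":  {"init": 0.27, "fused_nodes": 0.03, "codegen": 0.18, "total": 2.56},
    "LSTM 1L h=64": {"init": 0.74, "fused_nodes": 0.43, "codegen": 0.74, "total": 4.30},
    "LSTM 2L h=64": {"init": 1.44, "fused_nodes": 0.93, "codegen": 1.40, "total": 6.81},
    "GRU 1L h=64":  {"init": 0.71, "fused_nodes": 0.47, "codegen": 0.64, "total": 4.39},
}


def scheduler_seconds(row):
    """Sum of the three scheduler phases for one configuration."""
    return row["init"] + row["fused_nodes"] + row["codegen"]


for name, row in phases.items():
    sched = scheduler_seconds(row)
    share = sched / row["total"]
    print(f"{name}: scheduler {sched:.2f}s = {share:.0%} of compile_fx")

# How scheduler time grows when the lowered buffer count doubles (67 -> 134).
ratio = scheduler_seconds(phases["LSTM 2L h=64"]) / scheduler_seconds(phases["LSTM 1L h=64"])
print(f"1L -> 2L scheduler time ratio: {ratio:.2f}x")
```

The share varies by configuration: it is highest for the 2-layer LSTM and much lower for the h=5 case, where the lowered graph is small.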

For larger models (6-layer, h=256), earlier wall-clock measurements showed compile_fx taking 80+ seconds.

Reproduction

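import os
import time

# Disable inductor caches so every run measures a cold compile. Set before
# importing torch so the variable is already in place when inductor reads it.
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"

import torch
import torch.nn as nn
from torch._dynamo.utils import compile_times

# Let dynamo trace through nn.LSTM / nn.GRU.
torch._dynamo.config.allow_rnn = True

model = nn.LSTM(32, 64, num_layers=2, batch_first=True).cuda().eval()
x = torch.randn(2, 16, 32, device="cuda")

def timing_backend(gm, example_inputs):
    # Custom backend: report the FX graph size, then time inductor's compile_fx.
    from torch._inductor.compile_fx import compile_fx
    print(f"FX graph: {len(list(gm.graph.nodes))} nodes")
    t0 = time.perf_counter()
    result = compile_fx(gm, example_inputs)
    print(f"compile_fx: {time.perf_counter() - t0:.1f}s")
    return result

compiled = torch.compile(model, fullgraph=True, backend=timing_backend)
with torch.no_grad():
    out = compiled(x)

# Per-phase breakdown collected by dynamo during the compile above.
print(compile_times(repr="str"))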

Expected output: compile_fx: ~7s, with Scheduler.__init__ and Scheduler.codegen as the top phases.

Context

  • Dynamo tracing works correctly with allow_rnn=True — compiled outputs match eager within float32 precision
  • torch.export is unaffected since it doesn't use inductor
  • Cached compilations are fast (~0.05s), so this is purely a cold-compile cost
  • The scheduler scaling is likely not LSTM-specific — any op that decomposes into 100+ lowered buffers would likely hit similar costs

Environment

PyTorch: 2.13.0a0+git795be92
CUDA: 13.0
GPU: NVIDIA H200
OS: RHEL 9 (kernel 5.14.0-615.el9.x86_64)
Python: 3.12

cc @jerryzh168 @mikaylagawarecki @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @oulgen @jamesjwu @aorenste @anijain2305 @laithsakka @masnesral @aditvenk @chenyang78
