pytorch - 💡(How to fix) Fix [Inductor][Bucketing] Nested bucket traces should ignore ambient pending unbacked symbols [1 pull requests]


Inductor collective bucketing can fail when its nested make_fx trace observes pending unbacked symbols from an outer fake/dynamic-shape trace.

Bucketing traces small collective rewrite helpers while reusing the surrounding FakeTensorMode, and therefore that mode's ShapeEnv. If the shared ShapeEnv already holds pending fresh unbacked symbols from an unrelated outer trace, compute_unbacked_bindings can require those ambient symbols to be bound by the bucket trace outputs. The bucket helper itself may operate only on static tensors. In that case the ambient symbols are unrelated to it and cannot be bound by the nested trace.
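To make the failure mode concrete, here is a toy Python model of the mechanism described above. It does not use PyTorch: ToyShapeEnv, toy_compute_unbacked_bindings, toy_nested_trace and PendingSymbolNotBound are hypothetical stand-ins, and the binding check is a simplification of what compute_unbacked_bindings actually does. It only illustrates why a nested trace that inherits ambient pending symbols cannot succeed when its outputs carry none of them.

```python
# Toy model (hypothetical names, not PyTorch's API): a shared shape environment
# that records "fresh" unbacked symbols until some traced output binds them.
class ToyShapeEnv:
    def __init__(self):
        self.pending_fresh_unbacked_symbols = []
        self._counter = 0

    def create_unbacked_symbol(self):
        name = f"u{self._counter}"
        self._counter += 1
        self.pending_fresh_unbacked_symbols.append(name)
        return name


class PendingSymbolNotBound(RuntimeError):
    """Stand-in for PyTorch's PendingUnbackedSymbolNotFound."""


def toy_compute_unbacked_bindings(env, output_symbols):
    # In this toy, every pending symbol must be found in the traced outputs.
    missing = [s for s in env.pending_fresh_unbacked_symbols if s not in output_symbols]
    if missing:
        raise PendingSymbolNotBound(f"pending symbols not bound by outputs: {missing}")
    env.pending_fresh_unbacked_symbols.clear()
    return set(output_symbols)


def toy_nested_trace(env, fn, static_inputs):
    # The helper only sees static inputs, so its outputs carry no symbols.
    fn(*static_inputs)
    return toy_compute_unbacked_bindings(env, output_symbols=set())


env = ToyShapeEnv()
ambient = env.create_unbacked_symbol()  # left behind by an unrelated outer trace

try:
    toy_nested_trace(env, lambda x: x + 1, (1,))
except PendingSymbolNotBound as exc:
    print("nested trace failed:", exc)
```

Running it prints `nested trace failed: pending symbols not bound by outputs: ['u0']`: the static helper is blamed for a symbol it never created.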

Fix Action

Fixed

Code Example

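import torch
from torch._C import FileCheck
from torch._inductor.fx_passes.bucketing import _trace as bucketing_trace
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv


fake_mode = FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv())
# Static fake input: the bucket helper sees no symbolic sizes.
x = fake_mode.from_tensor(torch.randn(2, 2), static_shapes=True)

# Simulate an unrelated outer trace that left a fresh unbacked symbol
# pending on the shared ShapeEnv.
ambient = fake_mode.shape_env.create_unbacked_symint().node.expr
assert ambient in fake_mode.shape_env.pending_fresh_unbacked_symbols

# Nested bucketing trace of a helper that never touches `ambient`.
# Without the fix, this can raise PendingUnbackedSymbolNotFound.
gm = bucketing_trace(lambda x: x + 1, (x,))

# The ambient symbol must still be pending for the outer trace.
assert ambient in fake_mode.shape_env.pending_fresh_unbacked_symbols
FileCheck().check("aten.add").run(gm.code)
print(gm.code)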


Expected behavior

The nested bucketing trace should only be responsible for symbols created by the bucketing trace itself. Ambient pending unbacked symbols from the surrounding trace should remain pending for the outer trace and should not be required in the bucket helper outputs.

Actual behavior

The nested make_fx trace can fail with PendingUnbackedSymbolNotFound for an ambient unbacked symbol that is unrelated to the bucket inputs and outputs.

Notes

The bucketing trace should snapshot and clear pending/ignorable fresh unbacked symbols around the nested make_fx call, then restore the ambient state afterward.
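A minimal sketch of the snapshot/clear/restore approach from the note above, again on a toy model rather than PyTorch itself. ToyShapeEnv, isolate_pending_unbacked_symbols and toy_nested_trace are hypothetical names; the toy tracks only the pending list and leaves out the "ignorable" state the note also mentions. In PyTorch, the equivalent would presumably wrap the nested make_fx call inside the bucketing trace.

```python
import contextlib


class ToyShapeEnv:
    """Hypothetical stand-in for a shared ShapeEnv (not PyTorch's API)."""

    def __init__(self):
        self.pending_fresh_unbacked_symbols = []
        self._counter = 0

    def create_unbacked_symbol(self):
        name = f"u{self._counter}"
        self._counter += 1
        self.pending_fresh_unbacked_symbols.append(name)
        return name


@contextlib.contextmanager
def isolate_pending_unbacked_symbols(env):
    # Snapshot the ambient pending symbols and hide them from the nested trace.
    saved = list(env.pending_fresh_unbacked_symbols)
    env.pending_fresh_unbacked_symbols.clear()
    try:
        yield
    finally:
        # Restore the ambient symbols, keeping anything the nested trace
        # created but left unbound after them.
        leftover = list(env.pending_fresh_unbacked_symbols)
        env.pending_fresh_unbacked_symbols[:] = saved + leftover


def toy_nested_trace(env, fn, static_inputs):
    # Static-only helper: its outputs bind no symbols, so any pending one fails.
    fn(*static_inputs)
    if env.pending_fresh_unbacked_symbols:
        raise RuntimeError(f"unbound pending symbols: {env.pending_fresh_unbacked_symbols}")
    return "graph"


env = ToyShapeEnv()
ambient = env.create_unbacked_symbol()  # from an unrelated outer trace

with isolate_pending_unbacked_symbols(env):
    result = toy_nested_trace(env, lambda x: x + 1, (1,))

assert ambient in env.pending_fresh_unbacked_symbols  # still pending for the outer trace
print(result, env.pending_fresh_unbacked_symbols)
```

The try/finally restores the ambient symbols even if the nested trace raises. Symbols the nested trace creates but does not bind are appended after the restored ambient ones rather than dropped, so they still surface to whoever inspects the pending list next; restoring the snapshot exactly is the other reasonable choice.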

cc @chauhang @penguinwu @eellison @aorenste @ezyang @bobrenjc93 @aditvenk @laithsakka @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @bdhirsh


