pytorch - 💡(How to fix) Fix `torch.compile` crashes on batched matmul inside `torch.inference_mode()` [1 participant]

GitHub stats: pytorch/pytorch#181512 (fetched 2026-04-27 05:28:50)
Comments: 0 · Participants: 1 · Timeline events: 25 · Reactions: 0
Timeline (top): mentioned ×9, subscribed ×9, labeled ×7


🐛 Describe the bug

torch.compile crashes when compiling a function that performs batched matrix multiplication (3D+ inputs) inside torch.inference_mode().

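The crash occurs in AOT Autograd's metadata-collection pass (run_functionalized_fw_and_collect_metadata), where a weakref to an UntypedStorage apparently cannot be resolved. The FX interpreter re-raises the original error as a RuntimeError whose only argument is that weakref, which is why the message is just a weakref repr.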

Each of the following configurations works on its own; only the combination of a 3D+ matmul, torch.inference_mode(), and torch.compile crashes:

  • 2D matmul + inference_mode + torch.compile → ✅
  • 3D matmul + torch.compile (no inference_mode) → ✅
  • 3D matmul + torch.no_grad() + torch.compile → ✅
  • 3D matmul + inference_mode (eager, no compile) → ✅
  • 3D matmul + inference_mode + torch.compile → ❌ crash

Minimal reproducer

import torch

x = torch.randn(2, 4, 8, device='cuda')
w = torch.randn(8, 8, device='cuda')

def f(x, w):
    with torch.inference_mode():
        return (x @ w).sum()

torch._dynamo.reset()
torch.compile(f)(x, w)
# BackendCompilerFailed: RuntimeError: <weakref at ...; to 'torch.storage.UntypedStorage' at ...>

Also crashes on CPU (reusing f from above):

x = torch.randn(2, 4, 8)  # CPU tensors
w = torch.randn(8, 8)
torch._dynamo.reset()
torch.compile(f)(x, w)  # same crash

Controls (all pass)

import torch

x3d = torch.randn(2, 4, 8, device='cuda')
x2d = torch.randn(4, 8, device='cuda')
w = torch.randn(8, 8, device='cuda')

# Control 1: 2D matmul + inference_mode → OK
torch._dynamo.reset()
print(torch.compile(lambda x, w: torch.inference_mode()(lambda: (x @ w).sum())())(x2d, w))

# Control 2: 3D matmul, no inference_mode → OK
torch._dynamo.reset()
print(torch.compile(lambda x, w: (x @ w).sum())(x3d, w))

# Control 3: 3D matmul + no_grad → OK
torch._dynamo.reset()
def g(x, w):
    with torch.no_grad():
        return (x @ w).sum()
print(torch.compile(g)(x3d, w))

Trigger conditions

| Condition | Required? | Notes |
|---|---|---|
| torch.inference_mode() | Yes | torch.no_grad() does NOT trigger the crash |
| Input ndim | ≥ 3 | 2D matmul works fine |
| Device | Any | Both CUDA and CPU crash |
| dynamic=True | No | Crashes with both static and dynamic shapes |
| Matmul variant | Any | @, torch.bmm, torch.matmul all crash |
| dtype | Any | Tested fp32, fp64 |

Error message

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: <weakref at 0x7fda0b097f60; to 'torch.storage.UntypedStorage' at 0x7fda0b0b4190>

While executing %matmul : [num_users=1] = call_function[target=operator.matmul](args = (%l_x_, %l_w_), kwargs = {})
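Because the top-level message is only a weakref repr, the original exception type is easier to see by walking Python's exception chain (`__cause__` for `raise ... from e`, otherwise `__context__`). The helper below, `describe_chain`, is hypothetical and uses only standard Python exception attributes; it makes no assumption about the internals of `BackendCompilerFailed`.

```python
def describe_chain(exc, max_depth=10):
    """Hypothetical helper: list (type name, first message line) for `exc` and every exception it chains to."""
    chain = []
    seen = set()
    while exc is not None and id(exc) not in seen and len(chain) < max_depth:
        seen.add(id(exc))
        message = str(exc)
        first_line = message.splitlines()[0] if message else ""
        chain.append((type(exc).__name__, first_line))
        # Follow the explicit cause ("raise ... from e") first, else the implicit context.
        exc = exc.__cause__ if exc.__cause__ is not None else exc.__context__
    return chain
```

To use it, wrap the failing `torch.compile(f)(x, w)` call in `try: ... except Exception as exc:` and print each entry of `describe_chain(exc)`; the deepest entries should show the exception that the FX interpreter converted into the RuntimeError.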

Full stack trace

<details> <summary>Click to expand</summary>
Traceback (most recent call last):
  File "repro.py", line 13, in <module>
    torch.compile(f)(x, w)
  File "/pytorch/torch/_dynamo/eval_frame.py", line 1069, in compile_wrapper
    raise e.remove_dynamo_frames() from None
  File "/pytorch/torch/_dynamo/eval_frame.py", line 1052, in compile_wrapper
    result = fn(*args, **kwargs)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 2480, in __call__
    result = self._torchdynamo_orig_backend(
  File "/pytorch/torch/_dynamo/convert_frame.py", line 2198, in __call__
    result = self._inner_convert(
  File "/pytorch/torch/_dynamo/convert_frame.py", line 758, in __call__
    result = _compile(
  File "/pytorch/torch/_dynamo/convert_frame.py", line 1983, in _compile
    guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 1593, in compile_inner
    result = _compile_inner(code, one_graph, hooks)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 1652, in _compile_inner
    dynamo_output = compile_frame(
  File "/pytorch/torch/_dynamo/convert_frame.py", line 1500, in compile_frame
    bytecode, tracer_output = transform_code_object(code, transform)
  File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 1626, in transform_code_object
    tracer_output = transformations(instructions, code_options)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 1472, in transform
    tracer_output = trace_frame(
  File "/pytorch/torch/_dynamo/convert_frame.py", line 933, in trace_frame
    run_tracer()
  File "/pytorch/torch/_dynamo/convert_frame.py", line 914, in run_tracer
    tracer.run()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1839, in run
    while self.step():
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1506, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 5211, in RETURN_VALUE
    self._return(inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 5193, in _return
    all_stack_locals_metadata = self.output.compile_subgraph(
  File "/pytorch/torch/_dynamo/output_graph.py", line 2114, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/pytorch/torch/_dynamo/output_graph.py", line 2730, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm, self.example_inputs())
  File "/pytorch/torch/_dynamo/output_graph.py", line 2897, in call_user_compiler
    return self._call_user_compiler(gm, example_inputs)
  File "/pytorch/torch/_dynamo/output_graph.py", line 2983, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/pytorch/torch/_dynamo/output_graph.py", line 2958, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/pytorch/torch/__init__.py", line 2482, in __call__
    return compile_fx(
  File "/pytorch/torch/_inductor/compile_fx.py", line 2743, in compile_fx
    return _maybe_wrap_and_compile_fx_main(
  File "/pytorch/torch/_inductor/compile_fx.py", line 2824, in _maybe_wrap_and_compile_fx_main
    return _compile_fx_main(
  File "/pytorch/torch/_inductor/compile_fx.py", line 3022, in _compile_fx_main
    return dynamo_common.aot_autograd(
  File "/pytorch/torch/_dynamo/backends/common.py", line 123, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/pytorch/torch/_functorch/aot_autograd.py", line 1223, in aot_module_simplified
    aot_state = create_aot_state(
  File "/pytorch/torch/_functorch/aot_autograd.py", line 582, in create_aot_state
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 220, in inner
    flat_f_outs = f(*flat_f_args)
  File "/pytorch/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1534, in functional_call
    out = PropagateUnbackedSymInts(mod).run(*args)
  File "/pytorch/torch/fx/interpreter.py", line 224, in run
    raise RuntimeError(*e.args) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: <weakref at 0x7fda0b097f60; to 'torch.storage.UntypedStorage' at 0x7fda0b0b4190>

While executing %matmul : [num_users=1] = call_function[target=operator.matmul](args = (%l_x_, %l_w_), kwargs = {})
</details>

Versions

PyTorch: 2.13.0.dev20260425+cu126
Python: 3.11.15
OS: Linux-5.4.0-42-generic-x86_64-with-glibc2.31
CUDA: 12.6
Triton: 3.7.0
GPU: Tesla T4 (sm_75) — also crashes on CPU

cc @chauhang @penguinwu @bdhirsh @bobrenjc93 @aorenste

extent analysis

TL;DR

Until a fix lands, the crash can be avoided by using torch.no_grad() instead of torch.inference_mode() inside functions compiled with torch.compile() that perform batched matrix multiplication (inputs with 3 or more dimensions).

Guidance

  • Confirm that your failure matches this pattern: torch.compile around a function that enters torch.inference_mode() and performs a matmul with an input of 3 or more dimensions.
  • Replace torch.inference_mode() with torch.no_grad() inside the compiled function and check whether the crash persists.
  • Run the function eagerly (without torch.compile()) to confirm it works and to obtain reference outputs.
  • Follow pytorch/pytorch#181512 for a fix. The report was made against a nightly build (2.13.0.dev20260425), so other versions may behave differently.

Example

import torch

x = torch.randn(2, 4, 8, device='cuda')
w = torch.randn(8, 8, device='cuda')

def f(x, w):
    with torch.no_grad():  # Replace torch.inference_mode() with torch.no_grad()
        return (x @ w).sum()

torch._dynamo.reset()
torch.compile(f)(x, w)

Notes

The stack trace places the failure in AOT Autograd's metadata-collection pass (run_functionalized_fw_and_collect_metadata), before Inductor generates any code. The issue does not identify a root cause, so swapping inference_mode for no_grad is a mitigation, not a fix.

Recommendation

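Replace torch.inference_mode() with torch.no_grad() inside the compiled function. The numerical results should be the same: the main difference is that no_grad still performs view tracking and version-counter bumps, which inference_mode skips, so it can be marginally slower, and its outputs are ordinary tensors rather than inference tensors. Verify that the compiled outputs match the eager results before relying on the workaround.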
