pytorch - 💡(How to fix) `make_fx` generates broken input codegen for a function with a `(tuple, dict)` positional signature [1 pull request]


Error Message

RuntimeError: aten::add() Expected a value of type 'Tensor' for argument 'self' but instead found type 'tuple'.
Position: 0
Value: (tensor([-1.1176, -1.2213, -0.8301]),)
Declaration: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
Cast error details: Unable to cast (tensor([-1.1176, -1.2213, -0.8301]),) to Tensor

Root Cause

Root cause: _PyTreeCodeGen.gen_var_bindings (torch/fx/graph.py) treats any in_spec of shape tuple[tuple, dict] as an (args, kwargs) call and splits the inputs into ([positionals], {kwargs}). A genuine two-argument signature f(tuple, dict) produces the same spec, so its arguments are misread as args plus kwargs. Only the tuple-then-dict order triggers the bug: with a list or tensor as the first argument, two tuples, or the dict first, the codegen correctly emits [a, d] and the traced module works.

Fix Action

Fixed


🐛 Describe the bug

When a traced function's first positional argument is a tuple and a later positional argument is a dict, make_fx generates a wrong input-reconstruction template in the resulting forward, so calling the traced module raises.

MWE:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a, d):                      # a: tuple, d: dict
    return a[0] + d["x"]

gm = make_fx(f)((torch.randn(3),), {"x": torch.randn(3)})
print(gm.code)                    # see broken template below
gm((torch.randn(3),), {"x": torch.randn(3)})   # raises

The generated forward is:

def forward(self, a, d):
    a_1, d_1, = fx_pytree.tree_flatten_spec(([a], {'x':d}), self._in_spec)
    ...

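The template should be [a, d]. Instead, a is wrapped in a list and d is wrapped in a dict under the key 'x', so flattening against the spec yields the containers themselves rather than their tensor leaves. a_1 is bound to the tuple a, which is then passed where a tensor is expected: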
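To see why the wrong template hands the tuple itself to aten::add, here is a minimal pure-Python model of the flattening step. It assumes a simplified spec encoding and a hypothetical `flatten_spec` helper standing in for `fx_pytree.tree_flatten_spec`; strings stand in for tensors. It is a sketch of the indexing behavior, not the torch.fx implementation.

```python
# Hypothetical model of the input-flattening step in the generated forward.
# Specs are encoded as: "leaf", ("tuple", [children]), or ("dict", keys, [children]).
# flatten_spec is an illustrative stand-in for fx_pytree.tree_flatten_spec.

def flatten_spec(value, spec):
    """Return the leaves of `value`, indexing into it as `spec` describes."""
    if spec == "leaf":
        return [value]
    kind = spec[0]
    if kind == "tuple":
        children = spec[1]
        # Index positionally, so a list is accepted where a tuple is expected.
        parts = [value[i] for i in range(len(children))]
    elif kind == "dict":
        keys, children = spec[1], spec[2]
        parts = [value[k] for k in keys]
    else:
        raise ValueError(f"unknown spec kind: {kind!r}")
    leaves = []
    for part, child in zip(parts, children):
        leaves.extend(flatten_spec(part, child))
    return leaves


# in_spec for f(a, d) with a = (tensor,) and d = {"x": tensor}
IN_SPEC = ("tuple", [("tuple", ["leaf"]), ("dict", ["x"], ["leaf"])])

a = ("T0",)        # stand-in for (torch.randn(3),)
d = {"x": "T1"}    # stand-in for {"x": torch.randn(3)}

correct = flatten_spec([a, d], IN_SPEC)          # template [a, d]
broken = flatten_spec(([a], {"x": d}), IN_SPEC)  # template ([a], {'x':d})

print(correct)  # ['T0', 'T1']
print(broken)   # [('T0',), {'x': 'T1'}]
```

With the correct template the leaves are the two tensors; with the generated one, the first "leaf" is the tuple a and the second is the dict d, which matches the traceback.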

RuntimeError: aten::add() Expected a value of type 'Tensor' for argument 'self' but instead found type 'tuple'.
Position: 0
Value: (tensor([-1.1176, -1.2213, -0.8301]),)
Declaration: aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
Cast error details: Unable to cast (tensor([-1.1176, -1.2213, -0.8301]),) to Tensor

Root cause: _PyTreeCodeGen.gen_var_bindings (torch/fx/graph.py) treats any in_spec of shape tuple[tuple, dict] as an (args, kwargs) call and splits the inputs into ([positionals], {kwargs}). A genuine two-argument signature f(tuple, dict) produces the same spec, so its arguments are misread as args plus kwargs. Only the tuple-then-dict order triggers the bug: with a list or tensor as the first argument, two tuples, or the dict first, the codegen correctly emits [a, d] and the traced module works.
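The ambiguity can be reproduced without torch. The sketch below uses a simplified spec builder (`spec_of`) and a hypothetical `heuristic_template` that mimics the shape check described in the root cause; neither is the actual `_PyTreeCodeGen` code. It shows that the positional arguments of `f(a, d)` and the `(args, kwargs)` pair of a call like `g(t, x=...)` produce the same spec, and that only the tuple-then-dict shape takes the args/kwargs branch.

```python
# Hypothetical model of the codegen heuristic; not torch.fx code.

def spec_of(value):
    """Simplified pytree spec: containers recorded by type, anything else is a leaf."""
    if isinstance(value, tuple):
        return ("tuple", [spec_of(v) for v in value])
    if isinstance(value, list):
        return ("list", [spec_of(v) for v in value])
    if isinstance(value, dict):
        return ("dict", list(value), [spec_of(v) for v in value.values()])
    return "leaf"


def looks_like_args_kwargs(spec):
    # Shape-only check: a 2-tuple whose children are a tuple and a dict.
    if spec == "leaf" or spec[0] != "tuple" or len(spec[1]) != 2:
        return False
    first, second = spec[1]
    return (first != "leaf" and first[0] == "tuple"
            and second != "leaf" and second[0] == "dict")


def heuristic_template(fn_args, spec):
    """Render the input template the way the shape check would decide."""
    if looks_like_args_kwargs(spec):
        n_pos = len(spec[1][0][1])   # number of "positional" entries
        keys = spec[1][1][1]         # dict keys treated as kwarg names
        pos = ", ".join(fn_args[:n_pos])
        kw = ", ".join(f"'{k}':{name}" for k, name in zip(keys, fn_args[n_pos:]))
        return f"([{pos}], {{{kw}}})"
    return "[" + ", ".join(fn_args) + "]"


t0, t1 = "T0", "T1"  # stand-ins for tensors

# f(a, d): the in_spec describes the tuple of positional arguments.
spec_f = spec_of(((t0,), {"x": t1}))
# g(t, x=t1): the in_spec describes the (args, kwargs) pair.
args, kwargs = (t0,), {"x": t1}
spec_g = spec_of((args, kwargs))

print(spec_f == spec_g)                                          # True
print(heuristic_template(["a", "d"], spec_f))                    # ([a], {'x':d})
print(heuristic_template(["a", "d"], spec_of(([t0], {"x": t1}))))  # [a, d]
print(heuristic_template(["a", "d"], spec_of((t0, {"x": t1}))))    # [a, d]
```

Because the two specs are equal, a decision made from spec shape alone cannot distinguish the two calls; the sketch does not model how the linked PR resolves this.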

Versions

PyTorch version: 2.12.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.5 (arm64)
GCC version: Could not collect
Clang version: 21.0.0 (clang-2100.1.1.101)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.6 | packaged by Anaconda, Inc. | (main, Oct  3 2024, 02:26:31) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-26.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M5

Versions of relevant libraries:
[pip3] backpack-for-pytorch==1.7.1
[pip3] jet-for-pytorch==0.0.2.dev133+gf171ce318
[pip3] numpy==2.4.6
[pip3] torch==2.12.0
[pip3] torchvision==0.27.0
[conda] Could not collect

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @chauhang @penguinwu @bdhirsh @bobrenjc93 @aorenste
