pytorch - ✅(Solved) Fix `torch.cond` with tensors created in `[true/false]_fn` fails during `run_decompositions` [1 pull request, 1 comment, 2 participants]

GitHub stats
pytorch/pytorch#180354
Fetched 2026-04-16 06:35:05
Comments: 1
Participants: 2
Timeline: 20
Reactions: 0
Timeline (top): mentioned ×8, subscribed ×8, labeled ×3, commented ×1

Error Message

RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()

While executing %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%any_1, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
Original traceback:
  File "bug.py", line 9, in forward
    return torch.cond(x.any(), branch, branch, (x,))
  File "<eval_with_key>.5", line 9, in forward
    cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_,));  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

Fix Action

Fix / Workaround

Traceback (most recent call last):
  File "bug.py", line 12, in <module>
    ep.run_decompositions()
    ~~~~~~~~~~~~~~~~~~~~~^^
  File "/.../lib/python3.14/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/export/exported_program.py", line 1530, in run_decompositions
    return _decompose_exported_program(
        self,
    ...<3 lines>...
        decompose_custom_triton_ops=decompose_custom_triton_ops,
    )
  File "/.../lib/python3.14/site-packages/torch/export/exported_program.py", line 1005, in _decompose_exported_program
    ) = _decompose_and_get_gm_with_new_signature_constants(
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        ep,
        ^^^
    ...<3 lines>...
        decompose_custom_triton_ops=decompose_custom_triton_ops,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/.../lib/python3.14/site-packages/torch/export/exported_program.py", line 483, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
        patched_mod,
    ...<6 lines>...
        decompose_custom_triton_ops=decompose_custom_triton_ops,
    )
  File "/.../lib/python3.14/site-packages/torch/export/_trace.py", line 1042, in _export_to_aten_ir
    gm, graph_signature = transform(_aot_export_joint_with_descriptors)(
                          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        stack,
        ^^^^^^
    ...<5 lines>...
        _record_nn_module_stack=True,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/.../lib/python3.14/site-packages/torch/export/_trace.py", line 971, in _aot_export_joint_with_descriptors
    joint_with_descriptors = aot_export_joint_with_descriptors(
        stack,
    ...<4 lines>...
        _record_nn_module_stack=_record_nn_module_stack,
    )
  File "/.../lib/python3.14/site-packages/torch/_functorch/aot_autograd.py", line 1414, in aot_export_joint_with_descriptors
    aot_state = create_aot_state(
        stack,
    ...<5 lines>...
        shape_env,
    )
  File "/.../lib/python3.14/site-packages/torch/_functorch/aot_autograd.py", line 582, in create_aot_state
    fw_metadata = run_functionalized_fw_and_collect_metadata(
    ...<4 lines>...
        pre_dispatch=aot_config.pre_dispatch,
    )(*_dup_fake_script_obj(fake_flat_args))
  File "/.../lib/python3.14/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 220, in inner
    flat_f_outs = f(*flat_f_args)
  File "/.../lib/python3.14/site-packages/torch/_functorch/_aot_autograd/utils.py", line 192, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1534, in functional_call
    out = PropagateUnbackedSymInts(mod).run(*args)
  File "/.../lib/python3.14/site-packages/torch/fx/interpreter.py", line 197, in run
    self.env[node] = self.run_node(node)
                     ~~~~~~~~~~~~~^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/experimental/symbolic_shapes.py", line 8537, in run_node
    result = super().run_node(n)
  File "/.../lib/python3.14/site-packages/torch/fx/interpreter.py", line 294, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/interpreter.py", line 377, in call_function
    return target(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/_higher_order_ops/cond.py", line 54, in __call__
    return super().__call__(pred, true_fn, false_fn, operands)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/_ops.py", line 539, in __call__
    return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/_ops.py", line 386, in dispatch
    return kernel(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/_ops.py", line 336, in maybe_run_autograd
    return self(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/_higher_order_ops/cond.py", line 54, in __call__
    return super().__call__(pred, true_fn, false_fn, operands)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/_ops.py", line 539, in __call__
    return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/_ops.py", line 422, in dispatch
    result = handler(mode, *args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/_ops.py", line 193, in functionalize_dispatch_mode_fn
    return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/_higher_order_ops/cond.py", line 721, in cond_func
    hop_instance = HopInstance.create(cond_op, pred, true_fn, false_fn, inputs)
  File "/.../lib/python3.14/site-packages/torch/_higher_order_ops/utils.py", line 1198, in create
    return HopInstance(hop, hop.gen_schema(*args, **kwargs))
                            ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/_higher_order_ops/cond.py", line 61, in gen_schema
    then_gm: torch.fx.GraphModule = materialize_as_graph(true_fn, operands)
                                    ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/_higher_order_ops/utils.py", line 1283, in materialize_as_graph
    gm = _materialize_as_graph_inner()
  File "/.../lib/python3.14/site-packages/torch/_dynamo/eval_frame.py", line 1280, in _fn
    return fn(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/_higher_order_ops/utils.py", line 1279, in _materialize_as_graph_inner
    return _maybe_reenter_make_fx(
           ~~~~~~~~~~~~~~~~~~~~~~~
        fn, subgraph_decomp_table=subgraph_decomp_table
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    )(*unfunc_t)
    ~^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/_higher_order_ops/utils.py", line 152, in wrapped
    return make_fx(fn, decomposition_table=subgraph_decomp_table)(*args)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/experimental/proxy_tensor.py", line 2965, in wrapped
    return make_fx_tracer.trace(f, *args)
           ~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/experimental/proxy_tensor.py", line 2867, in trace
    return self._trace_inner(f, *args)
           ~~~~~~~~~~~~~~~~~^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/experimental/proxy_tensor.py", line 2828, in _trace_inner
    t = dispatch_trace(
        wrap_key(func, args, self.fx_tracer, self.pre_dispatch),
        tracer=self.fx_tracer,
        concrete_args=tuple(phs),
    )
  File "/.../lib/python3.14/site-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/_dynamo/eval_frame.py", line 1280, in _fn
    return fn(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/fx/experimental/proxy_tensor.py", line 1673, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "/.../lib/python3.14/site-packages/torch/_dynamo/eval_frame.py", line 1280, in _fn
    return fn(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/fx/_symbolic_trace.py", line 912, in trace
    (self.create_arg(fn(*args)),),
                     ~~^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/experimental/proxy_tensor.py", line 1743, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "<string>", line 1, in <lambda>
  File "/.../lib/python3.14/site-packages/torch/fx/graph_module.py", line 949, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/graph_module.py", line 461, in __call__
    raise e
  File "/.../lib/python3.14/site-packages/torch/fx/graph_module.py", line 447, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/_symbolic_trace.py", line 886, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/experimental/proxy_tensor.py", line 1427, in call_module
    return forward(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/fx/_symbolic_trace.py", line 879, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.54 from /.../lib/python3.14/site-packages/torch/fx/experimental/proxy_tensor.py:1720 in wrapped", line 6, in forward
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
  File "/.../lib/python3.14/site-packages/torch/_ops.py", line 871, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/fx/experimental/proxy_tensor.py", line 1798, in __torch_function__
    return func(*args, **kwargs)
  File "/.../lib/python3.14/site-packages/torch/_ops.py", line 871, in __call__
    return self._op(*args, **kwargs)
           ~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/.../lib/python3.14/site-packages/torch/_subclasses/functional_tensor.py", line 280, in __torch_dispatch__
    raise RuntimeError(
        "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()"
    )
RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()

PR fix notes

PR #180570: Fix torch.cond crash with FunctionalTensor constants (#180354)

Description (problem / solution / changelog)

Fixes #180354

Summary

Fixed a crash in torch.cond when branch functions contain tensor constants that become FunctionalTensor attributes in the GraphModule during export and decomposition.

Problem

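When ExportedProgram.run_decompositions() was called on a program containing torch.cond with tensor constants in branches (e.g., torch.tensor(0)), it would crash with:

RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()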

Root Cause

Tensor constants in cond branches are stored as attributes (e.g., _tensor_constant0) in the GraphModule and become FunctionalTensors. During make_fx() tracing in _maybe_fake_tracing(), these FunctionalTensors were accessed without FunctionalTensorMode being active, causing the crash.

Solution

This PR enables FunctionalTensorMode during make_fx() tracing in _maybe_fake_tracing() to safely handle any FunctionalTensors present in the GraphModule. A guard check was added to prevent double-functionalization when the mode is already active.

Changes:

  • Modified torch/_higher_order_ops/utils.py::_maybe_fake_tracing() to conditionally enable FunctionalTensorMode
  • Added guard to detect if FunctionalTensorMode is already active using torch._C._get_dispatch_mode()
  • Used conditional context manager pattern for clean code

Testing

  • Added regression test test_cond_with_tensor_constants_in_branches in test/dynamo/test_higher_order_ops.py
  • Verified the crash is fixed with the reproduction case from #180354
  • Test covers both true and false branches with torch.compile

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @azahed98

Changed files

  • test/dynamo/test_higher_order_ops.py (modified, +33/-0)
  • torch/_higher_order_ops/utils.py (modified, +16/-1)

Code Example

import torch

def branch(x):
    torch.tensor(0)  # works without this line
    return x.clone()

class Module(torch.nn.Module):
    def forward(self, x):
        return torch.cond(x.any(), branch, branch, (x,))

ep = torch.export.export(Module(), (torch.empty(()),))
ep.run_decompositions()
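
To check whether a given PyTorch build still shows this failure, the call can be wrapped in a small helper that looks for the specific error text. This is a minimal sketch: hits_functional_tensor_error and FUNCTIONAL_TENSOR_ERROR are hypothetical names introduced here, not part of PyTorch, and matching on the message text assumes the wording shown in the traceback stays the same.

```python
# Substring taken from the RuntimeError message in the traceback.
FUNCTIONAL_TENSOR_ERROR = "Attempting to use FunctionalTensor on its own"


def hits_functional_tensor_error(run) -> bool:
    """Call run() and report whether it raises the FunctionalTensor error.

    `run` is any zero-argument callable (hypothetical helper, not a torch API).
    Unrelated RuntimeErrors are re-raised so they are not hidden.
    """
    try:
        run()
    except RuntimeError as err:
        if FUNCTIONAL_TENSOR_ERROR in str(err):
            return True
        raise
    return False
```

With the reproduction above, the check would be hits_functional_tensor_error(ep.run_decompositions), which should return False on a build that includes the fix.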

    )
RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()

While executing %cond : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%any_1, %true_graph_0, %false_graph_0, (%x,)), kwargs = {})
Original traceback:
File "bug.py", line 9, in forward
    return torch.cond(x.any(), branch, branch, (x,))
  File "<eval_with_key>.5", line 9, in forward
    cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_,));  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

---

Collecting environment information...
PyTorch version: 2.12.0.dev20260414
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.3.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: version 4.2.3
Libc version: N/A

Python version: 3.14.2 (v3.14.2:df793163d58, Dec  5 2025, 12:18:06) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-26.3.1-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Pro

Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect

🐛 Describe the bug

import torch

def branch(x):
    torch.tensor(0)  # works without this line
    return x.clone()

class Module(torch.nn.Module):
    def forward(self, x):
        return torch.cond(x.any(), branch, branch, (x,))

ep = torch.export.export(Module(), (torch.empty(()),))
ep.run_decompositions()

cc @chauhang @penguinwu @avikchaudhuri @zhxchen17 @tugsbayasgalan @angelayi @ydwu4

extent analysis

TL;DR

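The repro calls torch.cond as intended; the failure appears to be a PyTorch bug triggered when a branch function creates a tensor (here, torch.tensor(0)). During ep.run_decompositions(), re-tracing the branch reaches aten.lift_fresh_copy on the resulting tensor constant and raises the FunctionalTensor RuntimeError. Removing tensor creation from the branch works around it.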

Guidance

  • The RuntimeError ("Attempting to use FunctionalTensor on its own") is raised in FunctionalTensor.__torch_dispatch__: an op reached a FunctionalTensor while no FunctionalTensorMode was active.
  • The failing op is aten.lift_fresh_copy applied to _tensor_constant0, executed while run_decompositions re-traces a cond branch (cond.gen_schema, then materialize_as_graph, then make_fx).
  • The repro's own comment identifies the trigger: the script works once torch.tensor(0) is removed from branch.
  • As a workaround, do not create tensors inside the functions passed as true_fn and false_fn.

Example

def branch(x):
    # Workaround: the torch.tensor(0) call from the repro is removed,
    # so no tensor is created inside the branch.
    return x.clone()

Notes

  • The report is against PyTorch 2.12.0.dev20260414 (CPU-only build, macOS arm64, Python 3.14.2); it does not say whether other versions are affected.
  • _tensor_constant0 in the failing frame appears to be the constant produced by torch.tensor(0), which is consistent with the repro working once that line is removed.

Recommendation

Apply the workaround: remove tensor creation from the branch functions, as in the example above. The page lists this issue as solved with one linked pull request, so also check whether a newer build already includes the fix.
