pytorch - torch.while_loop()'s body_fn is not allowed to return nothing

GitHub stats
pytorch/pytorch#182036 (fetched 2026-05-01 05:32:42)
Comments: 0
Participants: 1
Timeline: 27
Reactions: 0
Timeline (top): mentioned ×11, subscribed ×11, labeled ×5

🐛 Describe the bug

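Similar to #181891, I tested torch.while_loop() and confirmed it has the same kind of limitation: when body_fn has nothing to return (it returns an empty tuple and carried_inputs is empty), the case is not supported, and compiling with the inductor backend fails with the AssertionError shown below.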
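To make the failing shape of the loop concrete, here is a minimal pure-Python model of while-loop semantics with carried values. It is only an illustration, not PyTorch's implementation: `while_loop_reference` and the `state` dict are hypothetical names introduced for this sketch, and plain integers stand in for tensors.

```python
def while_loop_reference(cond_fn, body_fn, carried_inputs):
    """Hypothetical pure-Python model of a while loop with carried values."""
    carried = tuple(carried_inputs)
    while cond_fn(*carried):
        out = body_fn(*carried)
        # The body must hand back one value per carried input.
        if not isinstance(out, tuple) or len(out) != len(carried):
            raise TypeError("body_fn must return a tuple matching carried_inputs")
        carried = out
    return carried


# Case 1: state is threaded through carried values, so the loop has outputs.
threaded = while_loop_reference(
    lambda i, x: i < 3,
    lambda i, x: (i + 1, x - 1),
    (0, 1),
)

# Case 2: same shape as the repro. All state lives in side effects,
# carried_inputs is empty, and the loop itself returns nothing.
state = {"counter": 0, "buf": 1}


def cond_fn():
    state["counter"] += 1
    return state["counter"] < 3


def body_fn():
    state["buf"] -= 1
    return ()


side_effect_only = while_loop_reference(cond_fn, body_fn, ())
```

In the second case the loop's result is the empty tuple, which corresponds to the `-> ()` at the end of the schema printed in the assertion message.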

Repro:

import torch

from torch._higher_order_ops.while_loop import while_loop


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("counter", torch.tensor(0, device="cuda"))
        self.register_buffer(
            "buf", torch.ones(8, requires_grad=False, device="cuda")
        )

    def forward(self):
        def cond_fn():
            self.counter.add_(1)  # in-place mutation of a captured buffer
            return self.counter < 3

        def body_fn():
            self.buf.add_(-1)  # side effect only, nothing is carried
            return ()  # empty tuple: the loop has no outputs

        return while_loop(cond_fn, body_fn, ())  # empty carried_inputs


mod = M().to("cuda")
with torch.no_grad():
    out = torch.compile(mod, backend="inductor", fullgraph=True)()

error:

Traceback (most recent call last):
  File "/opt/pytorch/pytorch/agent_space/while_loop_noreturn1.py", line 28, in <module>
    out = torch.compile(mod, backend="inductor", fullgraph=True)()
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_dynamo/eval_frame.py", line 509, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_dynamo/eval_frame.py", line 1112, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_dynamo/output_graph.py", line 2983, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/opt/pytorch/pytorch/torch/_dynamo/output_graph.py", line 2958, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/__init__.py", line 2482, in __call__
    return compile_fx(
           ^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_inductor/compile_fx.py", line 2743, in compile_fx
    return _maybe_wrap_and_compile_fx_main(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_inductor/compile_fx.py", line 2824, in _maybe_wrap_and_compile_fx_main
    return _compile_fx_main(
           ^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_inductor/compile_fx.py", line 3022, in _compile_fx_main
    return dynamo_common.aot_autograd(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_dynamo/backends/common.py", line 123, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_functorch/aot_autograd.py", line 1223, in aot_module_simplified
    aot_state = create_aot_state(
                ^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_functorch/aot_autograd.py", line 582, in create_aot_state
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 220, in inner
    flat_f_outs = f(*flat_f_args)
                  ^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1534, in functional_call
    out = PropagateUnbackedSymInts(mod).run(*args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/fx/interpreter.py", line 197, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/fx/experimental/symbolic_shapes.py", line 8548, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/fx/interpreter.py", line 294, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/fx/interpreter.py", line 377, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_higher_order_ops/while_loop.py", line 64, in __call__
    return super().__call__(
           ^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_ops.py", line 539, in __call__
    return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_ops.py", line 386, in dispatch
    return kernel(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_ops.py", line 336, in maybe_run_autograd
    return self(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_higher_order_ops/while_loop.py", line 64, in __call__
    return super().__call__(
           ^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_ops.py", line 539, in __call__
    return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_ops.py", line 422, in dispatch
    result = handler(mode, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_ops.py", line 193, in functionalize_dispatch_mode_fn
    return fn(PythonFunctionalizeAPI(mode), *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_higher_order_ops/while_loop.py", line 631, in while_loop_func
    return do_auto_functionalize_v2(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_higher_order_ops/auto_functionalize.py", line 694, in do_auto_functionalize_v2
    return _do_auto_functionalize_v2_for_generic_mutable_operator(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/pytorch/torch/_higher_order_ops/auto_functionalize.py", line 892, in _do_auto_functionalize_v2_for_generic_mutable_operator
    raise AssertionError(
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: hop is expected to return at least one output while_loop(Any cond_fn, Any body_fn, Tensor(a2!) additional_input0, Tensor(a3!) additional_input1) -> ().

While executing %while_loop : [num_users=0] = call_function[target=torch.ops.higher_order.while_loop](args = (%cond_fn_0, %body_fn_0, (), (%l_self_buffers_counter_, %l_self_buffers_buf_)), kwargs = {mutated_arg_indices: 0,1})
Original traceback:
  File "/opt/pytorch/pytorch/agent_space/while_loop_noreturn1.py", line 23, in forward
    return while_loop(cond_fn, body_fn, ())

Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Hi @ydwu4 @galv @kshitij12345, could you help confirm whether this is a case we want to support now that mutation has been introduced? The codebase currently makes some basic assumptions that conflict with it. For example, auto-functionalization asserts that a HOP returns at least one output: https://github.com/pytorch/pytorch/blob/718351d7f94076d8db38936f40855a6f69251e39/torch/_higher_order_ops/auto_functionalize.py#L890-L894. In addition, body_fn is expected to return a tuple, so a body with nothing to return yields an empty tuple: https://github.com/pytorch/pytorch/blob/718351d7f94076d8db38936f40855a6f69251e39/torch/_higher_order_ops/while_loop.py#L161

Versions

Collecting environment information...
PyTorch version: 2.13.0a0+git7a9a0d3
Is debug build: False
CUDA used to build PyTorch: 13.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.39

Python version: 3.12.3 (main, Mar 23 2026, 19:04:32) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-106-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 13.2.78
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX 6000 Ada Generation
GPU 1: NVIDIA RTX 6000 Ada Generation

Nvidia driver version: 535.288.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.22.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.22.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.22.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.22.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.22.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_tensor_ir.so.9.22.0
/usr/lib/x86_64-linux-gnu/libcudnn_ext.so.9.22.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.22.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.22.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.22.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: AuthenticAMD
BIOS Vendor ID: Advanced Micro Devices, Inc.
Model name: AMD Ryzen 9 7950X 16-Core Processor
BIOS Model name: AMD Ryzen 9 7950X 16-Core Processor Unknown CPU @ 4.5GHz
BIOS CPU family: 107
CPU family: 25
Model: 97
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
Stepping: 2
CPU(s) scaling MHz: 38%
CPU max MHz: 5881.0000
CPU min MHz: 400.0000
BogoMIPS: 8999.67
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd cppc amd_ibpb_ret arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid overflow_recov succor smca fsrm flush_l1d ibpb_exit_to_user
Virtualization: AMD-V
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 16 MiB (16 instances)
L3 cache: 64 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Vulnerable: Clear CPU buffers attempted, no microcode
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Mitigation; IBPB before exit to userspace

Versions of relevant libraries:
[pip3] intel-openmp==2021.4.0
[pip3] mkl==2021.1.1
[pip3] mkl-devel==2021.1.1
[pip3] mkl-include==2021.1.1
[pip3] mypy==1.16.0
[pip3] mypy_extensions==1.1.0
[pip3] numpy==1.26.4
[pip3] nvidia-cudnn-frontend==1.23.0
[pip3] onnx==1.21.0
[pip3] onnx-ir==0.1.16
[pip3] onnxscript==0.6.2
[pip3] optree==0.13.0
[pip3] tbb==2021.13.1
[pip3] torch==2.13.0a0+git7a9a0d3
[pip3] torch_c_dlpack_ext==0.1.5
[pip3] triton==3.7.0+git88b227e2
[conda] Could not collect

cc @chauhang @penguinwu @ydwu4 @bdhirsh @bobrenjc93 @aorenste

extent analysis

TL;DR

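A likely workaround is to carry at least one tensor through the loop, so that body_fn returns a non-empty tuple matching carried_inputs. The compile path currently appears to assume that body_fn has at least one output. Whether a mutation-only body_fn with no return value should be supported is the open question raised in the issue.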

Guidance

  • Rewrite body_fn so that it returns a tuple with the same structure as carried_inputs and at least one element. Returning nothing, which is treated as an empty tuple, is the case that fails here.
  • Verify what body_fn actually returns by calling it once with example inputs and printing the result, or by stepping through it in a debugger.
  • Check the PyTorch documentation for while_loop to see whether its requirements on body_fn have changed.
  • Follow the upstream issue (pytorch/pytorch#182036) for a decision on whether a mutation-only body_fn is intended to be supported.
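
The verification step in the list above can be made mechanical. The sketch below is plain Python and does not call PyTorch; `check_body_output` is a hypothetical debugging helper introduced here for illustration, and it assumes the carried inputs are passed as a tuple. It calls body_fn once and reports whether the output lines up with the carried inputs.

```python
def check_body_output(body_fn, carried_inputs):
    """Call body_fn once and describe how its output compares to carried_inputs.

    Hypothetical helper for debugging; it is not part of PyTorch.
    """
    out = body_fn(*carried_inputs)
    if out is None:
        return "body_fn returned None"
    if not isinstance(out, (tuple, list)):
        return f"body_fn returned {type(out).__name__}, expected a tuple"
    if len(out) != len(carried_inputs):
        return (
            f"body_fn returned {len(out)} value(s) "
            f"for {len(carried_inputs)} carried input(s)"
        )
    return "ok"


def missing_return(x):
    x - 1  # the result is discarded, so the function returns None


def returns_tuple(x):
    return (x - 1,)


print(check_body_output(missing_return, (3,)))
print(check_body_output(returns_tuple, (3,)))
```

Because the helper really calls body_fn, any in-place mutation inside it will run once; use throwaway inputs if that matters.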

Example

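def body_fn(buf):
    # Sketch: carry buf as loop state instead of mutating self.buf in place.
    # cond_fn must take the same argument, and (self.buf,) is passed as carried_inputs.
    return (buf - 1,)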

Notes

The failure arises because body_fn returns nothing, which while_loop treats as an empty tuple, while parts of the compile path assume that body_fn produces at least one output. Returning a non-empty tuple that mirrors carried_inputs avoids that case. Note that a tuple containing None is not a substitute, since the returned values are expected to correspond to the carried inputs.
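
To make the difference between the two loop styles concrete, here is a pure-Python model of a functional while loop. It is only an illustration of the contract discussed in these notes, not PyTorch's implementation: `while_loop_model`, its error message, and the dictionary standing in for a module buffer are all assumptions made for this sketch.

```python
def while_loop_model(cond_fn, body_fn, carried_inputs):
    """Pure-Python model of a functional while loop (illustrative, not PyTorch code)."""
    carried = tuple(carried_inputs)
    while cond_fn(*carried):
        out = body_fn(*carried)
        # Mirror the behavior described in the issue: no return -> empty tuple.
        out = () if out is None else tuple(out)
        if len(out) != len(carried):
            raise ValueError(
                f"body_fn returned {len(out)} value(s), expected {len(carried)}"
            )
        carried = out
    return carried


# Mutation-only style: state lives outside the loop and body_fn returns nothing.
state = {"buf": 3}


def cond_mut():
    return state["buf"] > 0


def body_mut():
    state["buf"] -= 1  # side effect only, no return value


mutation_result = while_loop_model(cond_mut, body_mut, ())


# Carried-state style: the same countdown, but the value flows through the loop.
def cond_carried(buf):
    return buf > 0


def body_carried(buf):
    return (buf - 1,)


carried_result = while_loop_model(cond_carried, body_carried, (3,))

print(mutation_result, state["buf"])  # () 0
print(carried_result)                 # (0,)
```

In the model both styles count down to zero, but only the carried-state version returns the final value from the loop. The mutation-only version has nothing to return, which is the shape of loop the issue reports as failing under torch.compile.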

Recommendation

Apply the workaround: have body_fn return a non-empty tuple that mirrors carried_inputs. This should avoid the error, but it replaces in-place mutation with carried loop state, so the final value comes back from while_loop rather than being written to the buffer, and that may not be the intended behavior. Whether a body_fn that only mutates and returns nothing should be supported is still being decided upstream in pytorch/pytorch#182036.
