pytorch - 💡 (How to fix) [inductor] `torch.cond` branch with a sliced view and `scatter_add` crashes with "non-tensor input"

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…

Error Message

import traceback import torch

def fn(): x = torch.arange(24, dtype=torch.float32).reshape(4, 6) pred = torch.tensor(True)

def true_fn(t):
    v = t[:, ::2]
    index = torch.tensor(
        [[0, 0, 1], [1, 1, 2], [2, 2, 0], [0, 0, 2]],
        dtype=torch.long,
    )
    return v.scatter_add(1, index, torch.ones(4, 3))

def false_fn(t):
    return t[:, 1::2] - 3.0

y = torch.cond(pred, true_fn, false_fn, (x,))
return y, y.sum(1), y.argmax(1)

eager = fn() print("eager") for t in eager: print(t)

print("compiled") try: compiled = torch.compile(fn, backend="inductor", dynamic=False) got = compiled() for t in got: print(t) print("same", all(torch.equal(a, b) for a, b in zip(eager, got))) except Exception: traceback.print_exc(limit=20)

Code Example

import traceback
import torch

def fn():  # Minimal reproducer: torch.cond + strided view + scatter_add; returns (y, y.sum(1), y.argmax(1)).
    x = torch.arange(24, dtype=torch.float32).reshape(4, 6)
    pred = torch.tensor(True)  # always True, so true_fn is the branch that runs
    # True branch: scatter_add into a strided column view; the compiled crash comes from this branch's kernel.
    def true_fn(t):
        v = t[:, ::2]  # non-contiguous view of every other column, shape (4, 3)
        index = torch.tensor(  # constant created inside the branch body
            [[0, 0, 1], [1, 1, 2], [2, 2, 0], [0, 0, 2]],
            dtype=torch.long,
        )
        return v.scatter_add(1, index, torch.ones(4, 3))  # out-of-place: adds 1.0 to v[i, index[i, j]]
    # False branch: odd columns minus 3.0, also a (4, 3) float32 tensor.
    def false_fn(t):
        return t[:, 1::2] - 3.0
    # x is passed as the single operand to whichever branch pred selects.
    y = torch.cond(pred, true_fn, false_fn, (x,))
    return y, y.sum(1), y.argmax(1)

eager = fn()  # reference result: run the reproducer eagerly
print("eager")
for t in eager:
    print(t)
# Compile the same function with Inductor (static shapes) and compare with the eager result.
print("compiled")
try:
    compiled = torch.compile(fn, backend="inductor", dynamic=False)
    got = compiled()  # reported to raise RuntimeError: "_torchinductor_pyobject_tensor_data_ptr: non-tensor input"
    for t in got:
        print(t)
    print("same", all(torch.equal(a, b) for a, b in zip(eager, got)))  # torch.equal: same shape, exactly equal elements
except Exception:
    traceback.print_exc(limit=20)  # print the failure (at most 20 traceback entries) instead of aborting

---

eager
tensor([[ 2.,  3.,  4.],
        [ 6., 10., 11.],
        [13., 14., 18.],
        [20., 20., 23.]])
tensor([ 9., 27., 45., 63.])
tensor([2, 2, 2, 2])
compiled
Traceback (most recent call last):
  File "/tmp/ipykernel_4136/3733048998.py", line 30, in <cell line: 0>
    got = compiled()
          ^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1143, in compile_wrapper
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_4136/3733048998.py", line 4, in fn
    def fn():
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1421, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1277, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1142, in runtime_wrapper
    result = _codegen_runtime_wrapper(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/subclass_codegen.py:codegen(runtime_wrapper_orchestration)", line 8, in _runtime_wrapper
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 763, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_root/yb/cybwynungq6yraqiy4slahqe33rqwextjnunllb5rr747yjwqdoi.py", line 278, in call
    buf2 = true_graph_0(true_graph_0_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_root/yb/cybwynungq6yraqiy4slahqe33rqwextjnunllb5rr747yjwqdoi.py", line 212, in true_graph_0
    cpp_fused_arange_lift_fresh_ones_scatter_add_slice_view_1(true_graph_0_arg0_1, true_graph_0__tensor_constant0, true_graph_0_buf0)
RuntimeError: _torchinductor_pyobject_tensor_data_ptr: non-tensor input
Exception raised from _torchinductor_pyobject_tensor_data_ptr at /__w/pytorch/pytorch/torch/csrc/dynamo/guards.cpp:7266 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9d (0x7d2f710e576d in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x69 (0x7d2f71072423 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x8b9f17 (0x7d2f6335ef17 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #3: <unknown function> + 0x2cc7 (0x7d2f446d1cc7 in /tmp/torchinductor_root/l2/cl2epb7i24yjouz5knd5mawk6fjzu4s64gheyocpzcr24n56saqn.main.so)
frame #4: /usr/bin/python3() [0x56cc4f]
frame #5: _PyObject_MakeTpCall + 0x2fb (0x53f2ab in /usr/bin/python3)
frame #6: _PyEval_EvalFrameDefault + 0x700 (0x549fe0 in /usr/bin/python3)
frame #7: _PyObject_FastCallDictTstate + 0x1d8 (0x541b48 in /usr/bin/python3)
frame #8: _PyObject_Call_Prepend + 0x59 (0x57e2e9 in /usr/bin/python3)
frame #9: /usr/bin/python3() [0x6690fd]
frame #10: _PyObject_MakeTpCall + 0x2fb (0x53f2ab in /usr/bin/python3)
frame #11: _PyEval_EvalFrameDefault + 0x700 (0x549fe0 in /usr/bin/python3)
frame #12: <unknown function> + 0x89d511 (0x7d2f63342511 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #13: <unknown function> + 0x89feac (0x7d2f63344eac in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #14: PyObject_Vectorcall + 0x36 (0x562386 in /usr/bin/python3)
frame #15: _PyEval_EvalFrameDefault + 0x700 (0x549fe0 in /usr/bin/python3)
frame #16: dynamo_eval_custom_code + 0x1ce (0x7d2f6334234e in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #17: <unknown function> + 0x89eff0 (0x7d2f63343ff0 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #18: <unknown function> + 0x8a02d3 (0x7d2f633452d3 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #19: _PyEval_EvalFrameDefault + 0x4fd7 (0x54e8b7 in /usr/bin/python3)
frame #20: PyEval_EvalCode + 0x99 (0x61f469 in /usr/bin/python3)
frame #21: /usr/bin/python3() [0x63dbec]
frame #22: _PyEval_EvalFrameDefault + 0x3968 (0x54d248 in /usr/bin/python3)
frame #23: /usr/bin/python3() [0x635608]
frame #24: /usr/bin/python3() [0x63696c]
frame #25: _PyEval_EvalFrameDefault + 0x4361 (0x54dc41 in /usr/bin/python3)
frame #26: /usr/bin/python3() [0x59961d]
frame #27: /usr/bin/python3() [0x5991ae]
frame #28: _PyObject_Call + 0xed (0x580f8d in /usr/bin/python3)
frame #29: _PyEval_EvalFrameDefault + 0x4fd7 (0x54e8b7 in /usr/bin/python3)
frame #30: /usr/bin/python3() [0x635608]
frame #31: <unknown function> + 0x831f (0x7d2f8b41b31f in /usr/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
frame #32: <unknown function> + 0x9014 (0x7d2f8b41c014 in /usr/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
frame #33: /usr/bin/python3() [0x56c6bd]
frame #34: /usr/bin/python3() [0x64aa05]
frame #35: /usr/bin/python3() [0x4f73af]
frame #36: /usr/bin/python3() [0x56248c]
frame #37: _PyEval_EvalFrameDefault + 0x4fd7 (0x54e8b7 in /usr/bin/python3)
frame #38: PyEval_EvalCode + 0x99 (0x61f469 in /usr/bin/python3)
frame #39: /usr/bin/python3() [0x63dbec]
frame #40: /usr/bin/python3() [0x56248c]
frame #41: PyObject_Vectorcall + 0x36 (0x562386 in /usr/bin/python3)
frame #42: _PyEval_EvalFrameDefault + 0x700 (0x549fe0 in /usr/bin/python3)
frame #43: /usr/bin/python3() [0x64f860]
frame #44: Py_RunMain + 0x1f9 (0x64f109 in /usr/bin/python3)
frame #45: Py_BytesMain + 0x2d (0x6082cd in /usr/bin/python3)
frame #46: <unknown function> + 0x29d90 (0x7d2f8b9cdd90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #47: __libc_start_main + 0x80 (0x7d2f8b9cde40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #48: _start + 0x25 (0x608145 in /usr/bin/python3)
Click to expand / collapse

🐛 Describe the bug

`torch.cond` works in eager mode when the true branch takes a strided slice view (`t[:, ::2]`) and applies `scatter_add`.

However, compiling the same function with torch.compile(..., backend="inductor", dynamic=False) crashes at runtime inside the generated Inductor code.

The failure happens in the generated true-branch fused kernel and raises:

RuntimeError: _torchinductor_pyobject_tensor_data_ptr: non-tensor input

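This looks like an Inductor codegen/runtime argument-binding issue for a `torch.cond` branch involving a sliced view and `scatter_add`. The failing call in the generated true-branch code passes three arguments: `true_graph_0_arg0_1`, `true_graph_0__tensor_constant0`, and `true_graph_0_buf0`. At least one of them is evidently not a tensor at runtime. A plausible but unconfirmed suspect is `true_graph_0__tensor_constant0`, the constant lifted from the `torch.tensor(...)` index built inside `true_fn`.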

Expected behavior: compiled execution should return the same tensors as eager.

Actual behavior: Inductor crashes at runtime.

import traceback
import torch

def fn():  # Minimal reproducer: torch.cond + strided view + scatter_add; returns (y, y.sum(1), y.argmax(1)).
    x = torch.arange(24, dtype=torch.float32).reshape(4, 6)
    pred = torch.tensor(True)  # always True, so true_fn is the branch that runs
    # True branch: scatter_add into a strided column view; the compiled crash comes from this branch's kernel.
    def true_fn(t):
        v = t[:, ::2]  # non-contiguous view of every other column, shape (4, 3)
        index = torch.tensor(  # constant created inside the branch body
            [[0, 0, 1], [1, 1, 2], [2, 2, 0], [0, 0, 2]],
            dtype=torch.long,
        )
        return v.scatter_add(1, index, torch.ones(4, 3))  # out-of-place: adds 1.0 to v[i, index[i, j]]
    # False branch: odd columns minus 3.0, also a (4, 3) float32 tensor.
    def false_fn(t):
        return t[:, 1::2] - 3.0
    # x is passed as the single operand to whichever branch pred selects.
    y = torch.cond(pred, true_fn, false_fn, (x,))
    return y, y.sum(1), y.argmax(1)

eager = fn()  # reference result: run the reproducer eagerly
print("eager")
for t in eager:
    print(t)
# Compile the same function with Inductor (static shapes) and compare with the eager result.
print("compiled")
try:
    compiled = torch.compile(fn, backend="inductor", dynamic=False)
    got = compiled()  # reported to raise RuntimeError: "_torchinductor_pyobject_tensor_data_ptr: non-tensor input"
    for t in got:
        print(t)
    print("same", all(torch.equal(a, b) for a, b in zip(eager, got)))  # torch.equal: same shape, exactly equal elements
except Exception:
    traceback.print_exc(limit=20)  # print the failure (at most 20 traceback entries) instead of aborting

Output:

eager
tensor([[ 2.,  3.,  4.],
        [ 6., 10., 11.],
        [13., 14., 18.],
        [20., 20., 23.]])
tensor([ 9., 27., 45., 63.])
tensor([2, 2, 2, 2])
compiled
Traceback (most recent call last):
  File "/tmp/ipykernel_4136/3733048998.py", line 30, in <cell line: 0>
    got = compiled()
          ^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1143, in compile_wrapper
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_4136/3733048998.py", line 4, in fn
    def fn():
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 1421, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1277, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1142, in runtime_wrapper
    result = _codegen_runtime_wrapper(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/subclass_codegen.py:codegen(runtime_wrapper_orchestration)", line 8, in _runtime_wrapper
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 763, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_root/yb/cybwynungq6yraqiy4slahqe33rqwextjnunllb5rr747yjwqdoi.py", line 278, in call
    buf2 = true_graph_0(true_graph_0_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/torchinductor_root/yb/cybwynungq6yraqiy4slahqe33rqwextjnunllb5rr747yjwqdoi.py", line 212, in true_graph_0
    cpp_fused_arange_lift_fresh_ones_scatter_add_slice_view_1(true_graph_0_arg0_1, true_graph_0__tensor_constant0, true_graph_0_buf0)
RuntimeError: _torchinductor_pyobject_tensor_data_ptr: non-tensor input
Exception raised from _torchinductor_pyobject_tensor_data_ptr at /__w/pytorch/pytorch/torch/csrc/dynamo/guards.cpp:7266 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x9d (0x7d2f710e576d in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x69 (0x7d2f71072423 in /usr/local/lib/python3.12/dist-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x8b9f17 (0x7d2f6335ef17 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #3: <unknown function> + 0x2cc7 (0x7d2f446d1cc7 in /tmp/torchinductor_root/l2/cl2epb7i24yjouz5knd5mawk6fjzu4s64gheyocpzcr24n56saqn.main.so)
frame #4: /usr/bin/python3() [0x56cc4f]
frame #5: _PyObject_MakeTpCall + 0x2fb (0x53f2ab in /usr/bin/python3)
frame #6: _PyEval_EvalFrameDefault + 0x700 (0x549fe0 in /usr/bin/python3)
frame #7: _PyObject_FastCallDictTstate + 0x1d8 (0x541b48 in /usr/bin/python3)
frame #8: _PyObject_Call_Prepend + 0x59 (0x57e2e9 in /usr/bin/python3)
frame #9: /usr/bin/python3() [0x6690fd]
frame #10: _PyObject_MakeTpCall + 0x2fb (0x53f2ab in /usr/bin/python3)
frame #11: _PyEval_EvalFrameDefault + 0x700 (0x549fe0 in /usr/bin/python3)
frame #12: <unknown function> + 0x89d511 (0x7d2f63342511 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #13: <unknown function> + 0x89feac (0x7d2f63344eac in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #14: PyObject_Vectorcall + 0x36 (0x562386 in /usr/bin/python3)
frame #15: _PyEval_EvalFrameDefault + 0x700 (0x549fe0 in /usr/bin/python3)
frame #16: dynamo_eval_custom_code + 0x1ce (0x7d2f6334234e in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #17: <unknown function> + 0x89eff0 (0x7d2f63343ff0 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #18: <unknown function> + 0x8a02d3 (0x7d2f633452d3 in /usr/local/lib/python3.12/dist-packages/torch/lib/libtorch_python.so)
frame #19: _PyEval_EvalFrameDefault + 0x4fd7 (0x54e8b7 in /usr/bin/python3)
frame #20: PyEval_EvalCode + 0x99 (0x61f469 in /usr/bin/python3)
frame #21: /usr/bin/python3() [0x63dbec]
frame #22: _PyEval_EvalFrameDefault + 0x3968 (0x54d248 in /usr/bin/python3)
frame #23: /usr/bin/python3() [0x635608]
frame #24: /usr/bin/python3() [0x63696c]
frame #25: _PyEval_EvalFrameDefault + 0x4361 (0x54dc41 in /usr/bin/python3)
frame #26: /usr/bin/python3() [0x59961d]
frame #27: /usr/bin/python3() [0x5991ae]
frame #28: _PyObject_Call + 0xed (0x580f8d in /usr/bin/python3)
frame #29: _PyEval_EvalFrameDefault + 0x4fd7 (0x54e8b7 in /usr/bin/python3)
frame #30: /usr/bin/python3() [0x635608]
frame #31: <unknown function> + 0x831f (0x7d2f8b41b31f in /usr/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
frame #32: <unknown function> + 0x9014 (0x7d2f8b41c014 in /usr/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
frame #33: /usr/bin/python3() [0x56c6bd]
frame #34: /usr/bin/python3() [0x64aa05]
frame #35: /usr/bin/python3() [0x4f73af]
frame #36: /usr/bin/python3() [0x56248c]
frame #37: _PyEval_EvalFrameDefault + 0x4fd7 (0x54e8b7 in /usr/bin/python3)
frame #38: PyEval_EvalCode + 0x99 (0x61f469 in /usr/bin/python3)
frame #39: /usr/bin/python3() [0x63dbec]
frame #40: /usr/bin/python3() [0x56248c]
frame #41: PyObject_Vectorcall + 0x36 (0x562386 in /usr/bin/python3)
frame #42: _PyEval_EvalFrameDefault + 0x700 (0x549fe0 in /usr/bin/python3)
frame #43: /usr/bin/python3() [0x64f860]
frame #44: Py_RunMain + 0x1f9 (0x64f109 in /usr/bin/python3)
frame #45: Py_BytesMain + 0x2d (0x6082cd in /usr/bin/python3)
frame #46: <unknown function> + 0x29d90 (0x7d2f8b9cdd90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #47: __libc_start_main + 0x80 (0x7d2f8b9cde40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #48: _start + 0x25 (0x608145 in /usr/bin/python3)

Versions

PyTorch version: 2.10.0+cpu

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING