pytorch: [Inductor] Assertion error when FallbackKernel.create lowers functional custom op with dynamic shapes to .out variant

🐛 Describe the bug

Custom ops whose out variant is registered with torch.Tag.out_variant hit an AssertionError during Inductor lowering when they are compiled with torch.compile(backend="inductor") and dynamic shapes.

FallbackKernel.create in torch/_inductor/ir.py lowers a functional custom op to its .out variant and passes the raw example_output.shape into FixedLayout. With mark_dynamic, the shape entries are SymInt objects rather than sympy Expr or int values, so they fail the check in the layout constructor: assert all(isinstance(s, (Expr, int)) for s in size).
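The type mismatch described above can be sketched without PyTorch. In the block below, `Expr`, `SymNode`, and `SymInt` are hypothetical stand-ins (not the real sympy or torch classes), and `to_layout_sizes` is an illustrative helper, not a function from Inductor. It only shows why a wrapper object fails the quoted isinstance check and why unwrapping it to the underlying expression passes.

```python
from dataclasses import dataclass


class Expr:
    """Stand-in for a symbolic expression type that the size check accepts."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


@dataclass
class SymNode:
    """Stand-in for the node that holds the symbolic expression."""
    expr: Expr


@dataclass
class SymInt:
    """Stand-in for a symbolic integer: a wrapper, not an Expr and not an int."""
    node: SymNode


def check_sizes(size):
    # Same predicate as the assertion quoted in the report.
    return all(isinstance(s, (Expr, int)) for s in size)


def to_layout_sizes(shape):
    # Hypothetical normalization: unwrap SymInt-like entries, keep ints as-is.
    return [s.node.expr if isinstance(s, SymInt) else s for s in shape]


# Output shape of custom_bmm in the reproducer: [batch, a.size(1), b.size(2)].
raw_shape = [1200, SymInt(SymNode(Expr("s97"))), SymInt(SymNode(Expr("s16")))]

print(check_sizes(raw_shape))                   # wrapper entries fail the check
print(check_sizes(to_layout_sizes(raw_shape)))  # unwrapped entries pass
```

Under these assumptions, the failure is purely about which Python type reaches the layout constructor; the symbolic values themselves are fine once they are converted before the layout is built.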

Trace log: https://gist.github.com/Rakul-Chauhan/86d49995391deba9e64a0970e10c9f6d

Minimal reproducer:

"""Run: python minimal_custom_op_inductor_assert.py"""

import os
import shutil

import torch
from torch._inductor.lowering import make_fallback

_LIB = "_inductor_custom_bmm_repro"

lib_def = torch.library.Library(_LIB, "DEF")
lib_def.define("custom_bmm(Tensor self, Tensor mat2) -> Tensor")
lib_def.define(
    "custom_bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> ()",
    tags=torch.Tag.out_variant,
)
cpu = torch.library.Library(_LIB, "IMPL", "CPU")
cpu.impl("custom_bmm", lambda s, m: torch.bmm(s, m))


def _custom_bmm_out(s, m, out):
    torch.bmm(s, m, out=out)


cpu.impl("custom_bmm.out", _custom_bmm_out)
meta = torch.library.Library(_LIB, "IMPL", "Meta")
meta.impl("custom_bmm", lambda s, m: s.new_empty(s.size(0), s.size(1), m.size(-1)))
meta.impl("custom_bmm.out", lambda s, m, out: None)
make_fallback(torch.ops._inductor_custom_bmm_repro.custom_bmm)
make_fallback(torch.ops._inductor_custom_bmm_repro.custom_bmm.out)


class CustomBmm(torch.nn.Module):
    def forward(self, a, b):
        return torch.ops._inductor_custom_bmm_repro.custom_bmm(a, b)


if __name__ == "__main__":
    shutil.rmtree(f"/tmp/torchinductor_{os.environ.get('USER', 'user')}", ignore_errors=True)
    torch._dynamo.reset()

    a = torch.randn(1200, 51, 64, dtype=torch.bfloat16)
    b = torch.randn(1200, 64, 51, dtype=torch.bfloat16)
    torch._dynamo.mark_dynamic(a, 1)
    torch._dynamo.mark_dynamic(b, 2)

    torch.compile(CustomBmm(), backend="inductor")(a, b)

Error logs

Traceback (most recent call last):
  File "/proj/xcohdstaff8/rakchauh/minimal_custom_op_inductor_assert.py", line 55, in <module>
    torch.compile(CustomBmm(), backend="inductor")(a, b)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 473, in __call__
    return super().__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 1062, in compile_wrapper
    raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/compile_fx.py", line 1069, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
        e.__traceback__
    ) from None
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/compile_fx.py", line 1049, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
        gm,
    ...<3 lines>...
        **graph_kwargs,
    )
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/compile_fx.py", line 1836, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/compile_fx.py", line 1512, in codegen_and_compile
    graph.run(*example_inputs)
    ~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/graph.py", line 1051, in run
    return super().run(*args)
           ~~~~~~~~~~~^^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/fx/interpreter.py", line 197, in run
    self.env[node] = self.run_node(node)
                     ~~~~~~~~~~~~~^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/graph.py", line 1902, in run_node
    result = super().run_node(n)
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/fx/interpreter.py", line 294, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/graph.py", line 1474, in call_function
    raise LoweringException(
        e, target, args, kwargs, stack_trace=stack_trace
    ).with_traceback(e.__traceback__) from None
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/graph.py", line 1451, in call_function
    out = lowerings[target](*args, **kwargs)
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/lowering.py", line 517, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/lowering.py", line 2401, in handler
    wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs)
                  ~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/ir.py", line 8870, in create
    layout = FixedLayout(
        device=example_output.device,
    ...<2 lines>...
        stride=[example_output.stride()],
    )
  File "/proj/rdi/staff/rakchauh/miniforge3/envs/bmm-bug-cloned_env/lib/python3.13/site-packages/torch/_inductor/ir.py", line 3895, in __init__
    assert all(isinstance(s, (Expr, int)) for s in size)
           ~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._inductor.exc.InductorError: LoweringException: AssertionError:
  target: _inductor_custom_bmm_repro.custom_bmm.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1200, s97, 64], stride=[64*s97, 64, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1200, 64, s16], stride=[64*s16, s16, 1]))
  ))
AssertionError:
  target: _inductor_custom_bmm_repro.custom_bmm.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='arg1_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1200, s97, 64], stride=[64*s97, 64, 1]))
  ))
  args[1]: TensorBox(StorageBox(
    InputBuffer(name='arg3_1', layout=FixedLayout('cpu', torch.bfloat16, size=[1200, 64, s16], stride=[64*s16, s16, 1]))
  ))
  Found from :
   File "/proj/xcohdstaff8/rakchauh/minimal_custom_op_inductor_assert.py", line 43, in forward
    return torch.ops._inductor_custom_bmm_repro.custom_bmm(a, b)

Versions

Collecting environment information...
PyTorch version: 2.12.0+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 4.3.2
Libc version: glibc-2.35

Python version: 3.13.13 | packaged by conda-forge | (main, Apr 8 2026, 02:00:33) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-131-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 192
On-line CPU(s) list: 0-191
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9454 48-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 48
Socket(s): 2
Stepping: 1
BogoMIPS: 5499.84
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization: AMD-V
L1d cache: 3 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 96 MiB (96 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-47,96-143
NUMA node1 CPU(s): 48-95,144-191
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] flake8==7.3.0
[pip3] flake8-bugbear==25.11.29
[pip3] flake8-comprehensions==3.17.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==2024.24.12
[pip3] flake8-noqa==1.5.0
[pip3] flake8-pyi==26.5.0
[pip3] flake8_simplify==0.30.0
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.4.4
[pip3] onnx==1.21.0
[pip3] torch==2.12.0+cpu
[pip3] torchao==0.17.0+cpu
[pip3] torchaudio==2.11.0+cpu
[pip3] torchinfo==1.8.0
[pip3] torchvision==0.27.0+cpu
[conda] numpy 2.4.4 pypi_0 pypi
[conda] torch 2.12.0+cpu pypi_0 pypi
[conda] torchao 0.17.0+cpu pypi_0 pypi
[conda] torchaudio 2.11.0+cpu pypi_0 pypi
[conda] torchfix 0.7.0 pypi_0 pypi
[conda] torchinfo 1.8.0 pypi_0 pypi
[conda] torchvision 0.27.0+cpu pypi_0 pypi

cc @chauhang @penguinwu
