pytorch - 💡(How to fix) Fix `torch.compile(dynamic=True)` backward fails with `InductorError: CantSplit` on chained residual blocks scattering into a wider buffer

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…

Error Message

Torch version: 2.12.0+cu130 Testing eager mode... Eager mode OK Testing compiled mode... W0514 09:49:45.276000 2955 torch/_inductor/utils.py:1717] [0/0] Not enough SMs to use max_autotune_gemm mode

InductorError Traceback (most recent call last) /tmp/ipykernel_2955/4129170816.py in <cell line: 0>() 28 print("Eager mode OK") 29 print("Testing compiled mode...") ---> 30 torch.compile(M().to(dev), fullgraph=True, dynamic=True)(a, b).backward() 31 print("Compiled mode OK")

61 frames /usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py in __call__(self, *args, **kwargs) 471 ) 472 with _set_in_optimized_module(): --> 473 return super().__call__(*args, **kwargs) 474 475 def _aot_compile(self, inputs: list[torch._dynamo.aot_compile.ModelInput]) -> None:

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1776 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1777 else: -> 1778 return self._call_impl(*args, **kwargs) 1779 1780 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1787 or _global_backward_pre_hooks or _global_backward_hooks 1788 or _global_forward_hooks or _global_forward_pre_hooks): -> 1789 return forward_call(*args, **kwargs) 1790 1791 result = None

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py in compile_wrapper(*args, **kwargs) 1060 # Failures in the backend likely don't have useful 1061 # data in the TorchDynamo frames, so we strip them out. -> 1062 raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1 1063 finally: 1064 # Restore the dynamic layer stack depth if necessary.

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py in compile_wrapper(*args, **kwargs) 1045 call_succeeded = False 1046 try: -> 1047 result = fn(*args, **kwargs) 1048 call_succeeded = True 1049 except (Unsupported, UncapturedHigherOrderOpError, UserError) as e:

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs) 1776 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1777 else: -> 1778 return self._call_impl(*args, **kwargs) 1779 1780 # torchrec tests the code consistency with the following code

/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs) 1787 or _global_backward_pre_hooks or _global_backward_hooks 1788 or _global_forward_hooks or _global_forward_pre_hooks): -> 1789 return forward_call(*args, **kwargs) 1790 1791 result = None

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in __call__(self, frame, cache_entry, frame_state) 2472 with compile_lock, _disable_current_modes(): 2473 # skip=1: skip this frame -> 2474 result = self._torchdynamo_orig_backend( 2475 frame, cache_entry, self.hooks, frame_state, skip=1 2476 )

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in __call__(self, frame, cache_entry, hooks, frame_state, skip) 734 ) 735 with compile_ctx, recompile_ctx: --> 736 result = _compile( 737 frame.f_code, 738 frame.f_globals,

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in _compile(code, globals, locals, builtins, closure, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip, package, convert_frame_box) 1959 1960 try: -> 1961 guarded_code, tracer_output = compile_inner(code, one_graph, hooks) 1962 1963 # NB: We only put_code_state in success case. Success case here

/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py in wrapper_function(*args, **kwargs) 94 # in stack traces when profiling is not enabled. 95 if not StrobelightCompileTimeProfiler.enabled: ---> 96 return function(*args, **kwargs) 97 98 return StrobelightCompileTimeProfiler.profile_compile_time(

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in compile_inner(code, one_graph, hooks) 1569 stack.enter_context(CompileTimeInstructionCounter.record()) 1570 stack.enter_context(torch_function_mode_stack_state_mgr) -> 1571 result = _compile_inner(code, one_graph, hooks) 1572 assert torch._C._len_torch_function_stack() == 0, ( 1573 "Torch function mode stack state changed while dynamo tracing, please report a bug"

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in _compile_inner(code, one_graph, hooks) 1628 else contextlib.nullcontext() 1629 ): -> 1630 dynamo_output = compile_frame( 1631 code, 1632 globals,

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in compile_frame(code, globals, locals, builtins, closure, compiler_fn, one_graph, restart_reasons, export, export_constraints, frame_state, distributed_state, package) 1476 try: 1477 with dynamo_timed(f"compile_attempt_{attempt}", log_pt2_compile_event=True): -> 1478 bytecode, tracer_output = transform_code_object(code, transform) 1479 assert tracer_output is not None 1480 return DynamoOutput(

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py in transform_code_object(code, transformations, safe) 1624 propagate_line_nums(instructions) 1625 -> 1626 tracer_output = transformations(instructions, code_options) 1627 _, bytecode = clean_and_assemble_instructions(instructions, keys, code_options) 1628 return bytecode, tracer_output

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in transform(instructions, code_options) 1448 torch_function_mode_stack_state_mgr.stack 1449 ) -> 1450 tracer_output = trace_frame( 1451 code, 1452 globals,

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in _fn(*args, **kwargs) 341 reset_user_object_tracking() 342 try: --> 343 return fn(*args, **kwargs) 344 finally: 345 cleanup.close()

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in trace_frame(code, globals, locals, builtins, closure, compiler_fn, tf_mode_stack, one_graph, speculation_log, instructions, code_options, export, export_constraints, frame_state, distributed_state, package) 909 910 try: --> 911 run_tracer() 912 tracer_output = DynamoTracerOutput(tracer) 913 output = tracer_output.output_graph

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py in run_tracer() 890 tracer.output.mark_bytecode_tracing_start() 891 with tracing(tracer.output.tracing_context), tracer.set_current_tx(): --> 892 tracer.run() 893 except exc.UnspecializeRestartAnalysis: 894 speculation_log.clear() # type: ignore[has-type]

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py in run(self) 1811 self.start_point = self.instruction_pointer 1812 try: -> 1813 while self.step(): 1814 pass 1815 except Exception as e:

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py in step(self) 1478 1479 try: -> 1480 self.dispatch_table[inst.opcode](self, inst) 1481 return not self.output.should_exit 1482 except TensorifyScalarRestartAnalysis:

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py in RETURN_VALUE(self, inst) 5141 5142 def RETURN_VALUE(self, inst: Instruction) -> None: -> 5143 self._return(inst) 5144 5145 def RETURN_CONST(self, inst: Instruction) -> None:

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py in _return(self, inst) 5123 ) 5124 log.debug("return triggered compile") -> 5125 all_stack_locals_metadata = self.output.compile_subgraph( 5126 self, 5127 reason=GraphCompileReason(

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py in compile_subgraph(self, tx, reason, stack_pops) 2111 if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: 2112 output.extend( -> 2113 self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) 2114 ) 2115

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py in compile_and_call_fx_graph(self, tx, rv, root) 2726 2727 with self.restore_global_state(): -> 2728 compiled_fn = self.call_user_compiler(gm, self.example_inputs()) 2729 2730 from torch.fx._lazy_graph_module import _LazyGraphModule

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py in call_user_compiler(self, gm, example_inputs) 2893 dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us", 2894 ): -> 2895 return self._call_user_compiler(gm, example_inputs) 2896 2897 def _call_user_compiler(

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py in _call_user_compiler(self, gm, example_inputs) 2951 if config.verify_correctness: 2952 compiler_fn = WrapperBackend(compiler_fn) -> 2953 compiled_fn = compiler_fn(gm, example_inputs) 2954 _step_logger()(logging.INFO, f"done compiler function {name}") 2955 assert callable(compiled_fn), "compiler_fn did not return callable"

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_dynamo.py in __call__(self, gm, example_inputs, **kwargs) 154 raise 155 else: --> 156 compiled_gm = compiler_fn(gm, example_inputs) 157 158 return compiled_gm # type: ignore[return-value]

/usr/local/lib/python3.12/dist-packages/torch/__init__.py in __call__(self, model_, inputs_, config_patches) 2480 2481 all_patches = {**self.config, **(config_patches or {})} -> 2482 return compile_fx( 2483 model_, 2484 inputs_,

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in compile_fx(model_, example_inputs_, inner_compile, config_patches, decompositions, ignore_shape_env, compile_region_name) 2732 ) 2733 -> 2734 return maybe_wrap_and_compile_fx_main( 2735 model, 2736 example_inputs_,

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in maybe_wrap_and_compile_fx_main(model, example_inputs_, inner_compile, ignore_shape_env, get_decomp_fn, compile_region_name) 2813 2814 # Finally do the actual work! -> 2815 return compile_fx_main( 2816 model, 2817 example_inputs_,

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in compile_fx_main(model, example_inputs_, inner_compile, ignore_shape_env, get_decomp_fn, compile_region_name) 3026 # We will also shorten the traceback inside dynamo. 3027 # This is only useful if inductor is called directly with an FX graph. -> 3028 raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1 3029 3030

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in compile_fx_main(model, example_inputs_, inner_compile, ignore_shape_env, get_decomp_fn, compile_region_name) 3011 ): 3012 try: -> 3013 return dynamo_common.aot_autograd( 3014 fw_compiler=fw_compiler, 3015 bw_compiler=bw_compiler,

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/backends/common.py in __call__(self, gm, example_inputs, **kwargs) 121 # NB: NOT cloned! 122 with enable_aot_logging(), patch_config: --> 123 cg = aot_module_simplified(gm, example_inputs, **self.kwargs) 124 counters["aot_autograd"]["ok"] += 1 125 return disable(cg, reason="do not trace AOT-compiled graph")

/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, keep_inference_input_mutations, inference_compiler, compiler_config_extra, ignore_shape_env, disable_functionalization, pre_grad_passes, compile_region_name) 1232 aot_state.fw_metadata.act_input_indices = act_input_indices 1233 aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call) -> 1234 compiled_fn, _ = aot_stage2_compile( 1235 aot_state, 1236 aot_graph_capture,

/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py in aot_stage2_compile(aot_state, aot_graph_capture, partition_fn, fw_compiler, bw_compiler, inference_compiler) 376 377 if aot_state.needs_autograd and not aot_state.aot_config.pre_dispatch: --> 378 return aot_stage2_autograd(aot_state, aot_graph_capture) 379 else: 380 return aot_stage2_inference(aot_state, aot_graph_capture)

/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py in aot_stage2_autograd(aot_state, aot_graph_capture) 2284 ) 2285 -> 2286 fwd_output_strides, compiled_fw_func = _aot_stage2b_fw_compile( 2287 fw_module, 2288 adjusted_flat_args,

/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py in _aot_stage2b_fw_compile(fw_module, adjusted_flat_args, maybe_subclass_meta, fw_metadata, num_fw_outs_saved_for_bw, aot_config) 2073 # pyrefly: ignore [implicit-any] 2074 ) -> tuple[list[tuple[int, ...] | None] | None, Callable]: -> 2075 return _aot_stage2b_compile_forward_or_inference( 2076 fw_module, 2077 adjusted_flat_args,

/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py in _aot_stage2b_compile_forward_or_inference(fw_module, adjusted_flat_args, maybe_subclass_meta, fw_metadata, aot_config, is_inference, num_fw_outs_saved_for_bw) 2601 with TracingContext.report_output_strides() as fwd_output_strides: 2602 # pyrefly: ignore[not-callable] -> 2603 compiled_fw_func = compiler(fw_module, adjusted_flat_args) 2604 2605 # Make boxed if needed

/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/schemas.py in __call__(self, gm, example_inputs) 1419 example_inputs: Sequence[InputType], 1420 ) -> OutputCode: -> 1421 output_code = self.compiler_fn(gm, example_inputs) 1422 return output_code 1423

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in fw_compiler_base(gm, example_inputs, is_inference) 2874 else: 2875 num_orig_model_outputs = get_num_model_outputs(gm) -> 2876 return compile_fx_forward( 2877 gm, 2878 example_inputs,

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in compile_fx_forward(gm, example_inputs, num_orig_model_outputs, num_example_inputs, compiler_config_extra, inner_compile, is_inference) 2502 2503 with cudagraph_annotation_context(compiler_config_extra.cudagraphs): -> 2504 result = inner_compile( 2505 gm, 2506 example_inputs,

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in compile_fx_inner(gm, example_inputs, compile_region_name, **kwargs) 825 is_backward=kwargs["is_backward"], 826 ) --> 827 return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( 828 gm, 829 example_inputs,

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_aot.py in debug_wrapper(gm, example_inputs, compile_region_name, **kwargs) 312 # Call the compiler_fn - which is either aot_autograd or inductor 313 # with fake inputs --> 314 inner_compiled_fn = compiler_fn(gm, example_inputs) 315 except Exception: 316 # TODO: Failures here are troublesome because no real inputs,

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in _compile_fx_inner(gm, example_inputs, compile_region_name, **graph_kwargs) 1067 raise 1068 except Exception as e: -> 1069 raise InductorError(e, currentframe()).with_traceback( 1070 e.traceback 1071 ) from None

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in _compile_fx_inner(gm, example_inputs, compile_region_name, **graph_kwargs) 1047 TritonBundler.begin_compile() 1048 try: -> 1049 mb_compiled_graph = fx_codegen_and_compile( 1050 gm, 1051 example_inputs,

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in fx_codegen_and_compile(gm, example_inputs, inputs_to_check, compile_region_name, **graph_kwargs) 1834 1835 # pyrefly: ignore [unbound-name] -> 1836 return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) 1837 1838

/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py in codegen_and_compile(self, gm, example_inputs, inputs_to_check, graph_kwargs) 1595 ) 1596 else: -> 1597 compiled_module = graph.compile_to_module() 1598 compiled_fn = compiled_module.call 1599 compiled_fn_runner = getattr(

/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py in compile_to_module(self) 2611 dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us", 2612 ): -> 2613 return self._compile_to_module() 2614 2615 def _compile_to_module(self) -> CompiledModule:

/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py in _compile_to_module(self) 2617 # returned separately in AOTInductor mode. 2618 wrapper_code, _ = ( -> 2619 self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() 2620 ) 2621

/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py in codegen(self) 2553 2554 self.wrapper_code.push_codegened_graph(self) -> 2555 self.scheduler.codegen() 2556 2557 log.debug(

/usr/local/lib/python3.12/dist-packages/torch/_inductor/scheduler.py in codegen(self) 7348 with dynamo_timed("Scheduler.codegen"): 7349 return ( -> 7350 self._codegen_partitions() 7351 if torch._inductor.config.graph_partition 7352 else self._codegen(self.nodes)

/usr/local/lib/python3.12/dist-packages/torch/_inductor/scheduler.py in _codegen_partitions(self) 7488 7489 if signature.skip_cudagraph: -> 7490 self._codegen(partition) 7491 else: 7492 self._codegen_partition_wrapper(partition, signature)

/usr/local/lib/python3.12/dist-packages/torch/_inductor/scheduler.py in _codegen(self, nodes) 7639 elif isinstance(node, (FusedSchedulerNode, SchedulerNode)): 7640 # pyrefly: ignore [unbound-name] -> 7641 self.get_backend(device).codegen_node(node) 7642 else: 7643 assert isinstance(node, NopKernelSchedulerNode)

/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/cuda_combined_scheduling.py in codegen_node(self, node) 150 151 def codegen_node(self, node: FusedSchedulerNode | SchedulerNode) -> None: --> 152 return self._triton_scheduling.codegen_node(node) 153 154 def codegen_sync(self) -> None:

/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py in codegen_node(self, node) 1862 coalesce_analysis = None 1863 -> 1864 return self._codegen_nodes(nodes, coalesce_analysis) # type: ignore[arg-type] 1865 1866 @staticmethod

/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py in _codegen_nodes(self, nodes, coalesce_analysis) 1835 schedule_log.debug("Schedule:\n %s", node_schedule) 1836 -> 1837 return self.codegen_node_schedule( 1838 SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis) 1839 )

/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py in codegen_node_schedule(self, kernel_features) 1950 ) 1951 for kernel in kernels: -> 1952 self.codegen_node_schedule_with_kernel(node_schedule, kernel) 1953 MultiKernel.merge_workspaces_inplace(kernels) 1954

/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py in codegen_node_schedule_with_kernel(self, node_schedule, kernel) 2043 else: 2044 node.decide_inplace_update() -> 2045 index_vars = kernel.split_and_set_ranges(node.get_ranges()) 2046 all_indexing.update( 2047 dict.fromkeys(

/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py in split_and_set_ranges(self, lengths) 941 # Map the kernel's group structure to the node's sizes and set the ranges 942 # using the set_ranges method, returning the resulting iteration variables --> 943 return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges) 944 945 @classmethod

/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py in map_kernel_groups_to_node_sizes(cls, groups, lengths, set_ranges) 968 return set_ranges(*lengths) 969 --> 970 new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths) 971 itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))] 972 return [[fn(itervars) for fn in fns] for fns in return_getters_groups]

/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py in _split_iteration_ranges(groups, lengths) 845 size, remaining[current_group] 846 ): --> 847 raise CantSplit(size, remaining[current_group]) 848 849 size1 = remaining[current_group]

InductorError: CantSplit: 128*s52*((s97//s52)) + 128*((s97//s52)) not divisible by s52*((s97//s52)) + ((s97//s52))

Root Cause

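The frame below is only the `after_aot` debug wrapper that re-raises the compiler error. The exception itself is raised by `_split_iteration_ranges` in `torch/_inductor/codegen/simd.py` while Inductor generates a kernel:

/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_aot.py in debug_wrapper(gm, example_inputs, compile_region_name, **kwargs) 312 # Call the compiler_fn - which is either aot_autograd or inductor 313 # with fake inputs --> 314 inner_compiled_fn = compiler_fn(gm, example_inputs) 315 except Exception: 316 # TODO: Failures here are troublesome because no real inputs,

Although the repro ends in `.backward()`, the traceback shows the failure happens earlier. It occurs while compiling the forward graph on the autograd path (`aot_stage2_autograd` -> `_aot_stage2b_fw_compile` -> `compile_fx_forward`).

In the repro, `s97` and `s52` appear to be the symbolic sizes `a.shape[0]` and `b.shape[0]`, so `q = s97//s52` is `m` in `forward`. The zero buffer therefore has `q*(s52 + 1)` rows and `128*q*(s52 + 1)` elements. Expanded, the element count is exactly the left-hand expression in the error, `128*s52*q + 128*q`. The right-hand expression, `s52*q + q`, is the row count. The quotient is therefore the constant 128 (the feature width `d`), so the division is exact. The divisibility check in `_split_iteration_ranges` apparently does not recognize this in the expanded form and raises `CantSplit`.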

Fix Action

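Fix / Workaround (none recorded on this page: the snippet below only repeats the opening lines of the repro cell followed by two traceback frames)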

"""
Previous cell: !pip install --upgrade torch torchvision torchaudio
"""

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in step(self)
   1478 
   1479         try:
-> 1480             self.dispatch_table[inst.opcode](self, inst)
   1481             return not self.output.should_exit
   1482         except TensorifyScalarRestartAnalysis:

[/usr/local/lib/python3.12/dist-packages/torch/__init__.py](https://localhost:8080/#) in __call__(self, model_, inputs_, config_patches)
   2480 
   2481         all_patches = {**self.config, **(config_patches or {})}
-> 2482         return compile_fx(
   2483             model_,
   2484             inputs_,

Code Example

"""
Minimal repro: ``torch.compile(..., fullgraph=True, dynamic=True)`` raises
``InductorError: CantSplit`` for two chained residual blocks applied to a
zero buffer that has more rows than the input copied into it. Eager mode
runs the same forward/backward without error.

Previous cell: !pip install --upgrade torch torchvision torchaudio
"""

import os
# Dynamo's error path strips its own frames and refers to
# TORCHDYNAMO_VERBOSE=1 for the unstripped traceback. This is set before
# ``import torch``; presumably the variable is read when torch is imported,
# so confirm before reordering these lines.
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
import torch

# Log the installed version; the failure was reported against the version
# shown in the trailing comment.
print("Torch version:", torch.__version__) # Torch version: 2.12.0+cu130

class M(torch.nn.Module):
    def __init__(self, d=128):
        super().__init__()
        self.n1 = torch.nn.RMSNorm(d)
        self.l1 = torch.nn.Linear(d, d, bias=False)
        self.n2 = torch.nn.RMSNorm(d)
        self.l2 = torch.nn.Linear(d, d, bias=False)

    def forward(self, a, b):
        m = a.shape[0] // b.shape[0]
        out = torch.zeros(m * (b.shape[0] + 1), a.shape[-1], device=a.device)
        out[: a.shape[0]] = a
        out = out + self.l1(self.n1(out))
        out = out + self.l2(self.n2(out))
        return out.sum()

dev = "cuda"
a = torch.randn(21, 128, device=dev)  # N = 21 rows of width d = 128
b = torch.randn(7, device=dev)  # M.forward reads only b.shape[0] = 7

# Eager baseline: forward and backward both succeed.
print("Testing eager mode...")
M().to(dev)(a, b).backward()
print("Eager mode OK")

# Compiled with dynamic shapes: this call raises InductorError (CantSplit)
# while the graph is being compiled, so the final print is never reached.
print("Testing compiled mode...")
torch.compile(
    M().to(dev), fullgraph=True, dynamic=True
)(a, b).backward()  # Crashes
print("Compiled mode OK")

---

Torch version: 2.12.0+cu130
Testing eager mode...
Eager mode OK
Testing compiled mode...
W0514 09:49:45.276000 2955 torch/_inductor/utils.py:1717] [0/0] Not enough SMs to use max_autotune_gemm mode
---------------------------------------------------------------------------
InductorError                             Traceback (most recent call last)
[/tmp/ipykernel_2955/4129170816.py](https://localhost:8080/#) in <cell line: 0>()
     28     print("Eager mode OK")
     29     print("Testing compiled mode...")
---> 30     torch.compile(M().to(dev), fullgraph=True, dynamic=True)(a, b).backward()
     31     print("Compiled mode OK")

61 frames
[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in __call__(self, *args, **kwargs)
    471             )
    472         with _set_in_optimized_module():
--> 473             return super().__call__(*args, **kwargs)
    474 
    475     def _aot_compile(self, inputs: list[torch._dynamo.aot_compile.ModelInput]) -> None:

[/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1776             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1777         else:
-> 1778             return self._call_impl(*args, **kwargs)
   1779 
   1780     # torchrec tests the code consistency with the following code

[/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1787                 or _global_backward_pre_hooks or _global_backward_hooks
   1788                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1789             return forward_call(*args, **kwargs)
   1790 
   1791         result = None

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in compile_wrapper(*args, **kwargs)
   1060                     # Failures in the backend likely don't have useful
   1061                     # data in the TorchDynamo frames, so we strip them out.
-> 1062                     raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
   1063                 finally:
   1064                     # Restore the dynamic layer stack depth if necessary.

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in compile_wrapper(*args, **kwargs)
   1045                 call_succeeded = False
   1046                 try:
-> 1047                     result = fn(*args, **kwargs)
   1048                     call_succeeded = True
   1049                 except (Unsupported, UncapturedHigherOrderOpError, UserError) as e:

[/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1776             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1777         else:
-> 1778             return self._call_impl(*args, **kwargs)
   1779 
   1780     # torchrec tests the code consistency with the following code

[/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1787                 or _global_backward_pre_hooks or _global_backward_hooks
   1788                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1789             return forward_call(*args, **kwargs)
   1790 
   1791         result = None

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in __call__(self, frame, cache_entry, frame_state)
   2472         with compile_lock, _disable_current_modes():
   2473             # skip=1: skip this frame
-> 2474             result = self._torchdynamo_orig_backend(
   2475                 frame, cache_entry, self.hooks, frame_state, skip=1
   2476             )

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in __call__(self, frame, cache_entry, hooks, frame_state, skip)
    734             )
    735             with compile_ctx, recompile_ctx:
--> 736                 result = _compile(
    737                     frame.f_code,
    738                     frame.f_globals,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _compile(code, globals, locals, builtins, closure, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip, package, convert_frame_box)
   1959 
   1960         try:
-> 1961             guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
   1962 
   1963             # NB: We only put_code_state in success case.  Success case here

[/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py](https://localhost:8080/#) in wrapper_function(*args, **kwargs)
     94             # in stack traces when profiling is not enabled.
     95             if not StrobelightCompileTimeProfiler.enabled:
---> 96                 return function(*args, **kwargs)
     97 
     98             return StrobelightCompileTimeProfiler.profile_compile_time(

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in compile_inner(code, one_graph, hooks)
   1569             stack.enter_context(CompileTimeInstructionCounter.record())
   1570             stack.enter_context(torch_function_mode_stack_state_mgr)
-> 1571             result = _compile_inner(code, one_graph, hooks)
   1572             assert torch._C._len_torch_function_stack() == 0, (
   1573                 "Torch function mode stack state changed while dynamo tracing, please report a bug"

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _compile_inner(code, one_graph, hooks)
   1628                 else contextlib.nullcontext()
   1629             ):
-> 1630                 dynamo_output = compile_frame(
   1631                     code,
   1632                     globals,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in compile_frame(code, globals, locals, builtins, closure, compiler_fn, one_graph, restart_reasons, export, export_constraints, frame_state, distributed_state, package)
   1476         try:
   1477             with dynamo_timed(f"compile_attempt_{attempt}", log_pt2_compile_event=True):
-> 1478                 bytecode, tracer_output = transform_code_object(code, transform)
   1479                 assert tracer_output is not None
   1480                 return DynamoOutput(

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py](https://localhost:8080/#) in transform_code_object(code, transformations, safe)
   1624     propagate_line_nums(instructions)
   1625 
-> 1626     tracer_output = transformations(instructions, code_options)
   1627     _, bytecode = clean_and_assemble_instructions(instructions, keys, code_options)
   1628     return bytecode, tracer_output

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in transform(instructions, code_options)
   1448             torch_function_mode_stack_state_mgr.stack
   1449         )
-> 1450         tracer_output = trace_frame(
   1451             code,
   1452             globals,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _fn(*args, **kwargs)
    341             reset_user_object_tracking()
    342             try:
--> 343                 return fn(*args, **kwargs)
    344             finally:
    345                 cleanup.close()

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in trace_frame(code, globals, locals, builtins, closure, compiler_fn, tf_mode_stack, one_graph, speculation_log, instructions, code_options, export, export_constraints, frame_state, distributed_state, package)
    909 
    910     try:
--> 911         run_tracer()
    912         tracer_output = DynamoTracerOutput(tracer)
    913         output = tracer_output.output_graph

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in run_tracer()
    890             tracer.output.mark_bytecode_tracing_start()
    891             with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 892                 tracer.run()
    893         except exc.UnspecializeRestartAnalysis:
    894             speculation_log.clear()  # type: ignore[has-type]

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in run(self)
   1811                 self.start_point = self.instruction_pointer
   1812                 try:
-> 1813                     while self.step():
   1814                         pass
   1815                 except Exception as e:

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in step(self)
   1478 
   1479         try:
-> 1480             self.dispatch_table[inst.opcode](self, inst)
   1481             return not self.output.should_exit
   1482         except TensorifyScalarRestartAnalysis:

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in RETURN_VALUE(self, inst)
   5141 
   5142     def RETURN_VALUE(self, inst: Instruction) -> None:
-> 5143         self._return(inst)
   5144 
   5145     def RETURN_CONST(self, inst: Instruction) -> None:

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in _return(self, inst)
   5123         )
   5124         log.debug("return triggered compile")
-> 5125         all_stack_locals_metadata = self.output.compile_subgraph(
   5126             self,
   5127             reason=GraphCompileReason(

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in compile_subgraph(self, tx, reason, stack_pops)
   2111             if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
   2112                 output.extend(
-> 2113                     self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
   2114                 )
   2115 

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in compile_and_call_fx_graph(self, tx, rv, root)
   2726 
   2727             with self.restore_global_state():
-> 2728                 compiled_fn = self.call_user_compiler(gm, self.example_inputs())
   2729 
   2730             from torch.fx._lazy_graph_module import _LazyGraphModule

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in call_user_compiler(self, gm, example_inputs)
   2893             dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
   2894         ):
-> 2895             return self._call_user_compiler(gm, example_inputs)
   2896 
   2897     def _call_user_compiler(

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in _call_user_compiler(self, gm, example_inputs)
   2951             if config.verify_correctness:
   2952                 compiler_fn = WrapperBackend(compiler_fn)
-> 2953             compiled_fn = compiler_fn(gm, example_inputs)
   2954             _step_logger()(logging.INFO, f"done compiler function {name}")
   2955             assert callable(compiled_fn), "compiler_fn did not return callable"

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_dynamo.py](https://localhost:8080/#) in __call__(self, gm, example_inputs, **kwargs)
    154                     raise
    155         else:
--> 156             compiled_gm = compiler_fn(gm, example_inputs)
    157 
    158         return compiled_gm  # type: ignore[return-value]

[/usr/local/lib/python3.12/dist-packages/torch/__init__.py](https://localhost:8080/#) in __call__(self, model_, inputs_, config_patches)
   2480 
   2481         all_patches = {**self.config, **(config_patches or {})}
-> 2482         return compile_fx(
   2483             model_,
   2484             inputs_,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in compile_fx(model_, example_inputs_, inner_compile, config_patches, decompositions, ignore_shape_env, compile_region_name)
   2732                 )
   2733 
-> 2734     return _maybe_wrap_and_compile_fx_main(
   2735         model_,
   2736         example_inputs_,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _maybe_wrap_and_compile_fx_main(model_, example_inputs_, inner_compile, ignore_shape_env, get_decomp_fn, compile_region_name)
   2813 
   2814     # Finally do the actual work!
-> 2815     return _compile_fx_main(
   2816         model_,
   2817         example_inputs_,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _compile_fx_main(model_, example_inputs_, inner_compile, ignore_shape_env, get_decomp_fn, compile_region_name)
   3026                 # We will also shorten the traceback inside dynamo.
   3027                 # This is only useful if inductor is called directly with an FX graph.
-> 3028                 raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
   3029 
   3030 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _compile_fx_main(model_, example_inputs_, inner_compile, ignore_shape_env, get_decomp_fn, compile_region_name)
   3011         ):
   3012             try:
-> 3013                 return dynamo_common.aot_autograd(
   3014                     fw_compiler=fw_compiler,
   3015                     bw_compiler=bw_compiler,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/backends/common.py](https://localhost:8080/#) in __call__(self, gm, example_inputs, **kwargs)
    121             # NB: NOT cloned!
    122             with enable_aot_logging(), patch_config:
--> 123                 cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
    124                 counters["aot_autograd"]["ok"] += 1
    125                 return disable(cg, reason="do not trace AOT-compiled graph")

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py](https://localhost:8080/#) in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, keep_inference_input_mutations, inference_compiler, compiler_config_extra, ignore_shape_env, disable_functionalization, pre_grad_passes, compile_region_name)
   1232             aot_state.fw_metadata.act_input_indices = act_input_indices
   1233             aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call)
-> 1234             compiled_fn, _ = aot_stage2_compile(
   1235                 aot_state,
   1236                 aot_graph_capture,

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py](https://localhost:8080/#) in aot_stage2_compile(aot_state, aot_graph_capture, partition_fn, fw_compiler, bw_compiler, inference_compiler)
    376 
    377     if aot_state.needs_autograd and not aot_state.aot_config.pre_dispatch:
--> 378         return aot_stage2_autograd(aot_state, aot_graph_capture)
    379     else:
    380         return aot_stage2_inference(aot_state, aot_graph_capture)

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py](https://localhost:8080/#) in aot_stage2_autograd(aot_state, aot_graph_capture)
   2284     )
   2285 
-> 2286     fwd_output_strides, compiled_fw_func = _aot_stage2b_fw_compile(
   2287         fw_module,
   2288         adjusted_flat_args,

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py](https://localhost:8080/#) in _aot_stage2b_fw_compile(fw_module, adjusted_flat_args, maybe_subclass_meta, fw_metadata, num_fw_outs_saved_for_bw, aot_config)
   2073     # pyrefly: ignore [implicit-any]
   2074 ) -> tuple[list[tuple[int, ...] | None] | None, Callable]:
-> 2075     return _aot_stage2b_compile_forward_or_inference(
   2076         fw_module,
   2077         adjusted_flat_args,

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py](https://localhost:8080/#) in _aot_stage2b_compile_forward_or_inference(fw_module, adjusted_flat_args, maybe_subclass_meta, fw_metadata, aot_config, is_inference, num_fw_outs_saved_for_bw)
   2601         with TracingContext.report_output_strides() as fwd_output_strides:
   2602             # pyrefly: ignore[not-callable]
-> 2603             compiled_fw_func = compiler(fw_module, adjusted_flat_args)
   2604 
   2605         # Make boxed if needed

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/schemas.py](https://localhost:8080/#) in __call__(self, gm, example_inputs)
   1419         example_inputs: Sequence[InputType],
   1420     ) -> OutputCode:
-> 1421         output_code = self.compiler_fn(gm, example_inputs)
   1422         return output_code
   1423 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in fw_compiler_base(gm, example_inputs, is_inference)
   2874                 else:
   2875                     num_orig_model_outputs = get_num_model_outputs(gm)
-> 2876                 return compile_fx_forward(
   2877                     gm,
   2878                     example_inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in compile_fx_forward(gm, example_inputs, num_orig_model_outputs, num_example_inputs, compiler_config_extra, inner_compile, is_inference)
   2502 
   2503     with cudagraph_annotation_context(compiler_config_extra.cudagraphs):
-> 2504         result = inner_compile(
   2505             gm,
   2506             example_inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in compile_fx_inner(gm, example_inputs, compile_region_name, **kwargs)
    825             is_backward=kwargs["is_backward"],
    826         )
--> 827         return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
    828             gm,
    829             example_inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_aot.py](https://localhost:8080/#) in debug_wrapper(gm, example_inputs, compile_region_name, **kwargs)
    312             # Call the compiler_fn - which is either aot_autograd or inductor
    313             # with fake inputs
--> 314             inner_compiled_fn = compiler_fn(gm, example_inputs)
    315         except Exception:
    316             # TODO: Failures here are troublesome because no real inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _compile_fx_inner(gm, example_inputs, compile_region_name, **graph_kwargs)
   1067                 raise
   1068             except Exception as e:
-> 1069                 raise InductorError(e, currentframe()).with_traceback(
   1070                     e.__traceback__
   1071                 ) from None

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _compile_fx_inner(gm, example_inputs, compile_region_name, **graph_kwargs)
   1047             TritonBundler.begin_compile()
   1048             try:
-> 1049                 mb_compiled_graph = fx_codegen_and_compile(
   1050                     gm,
   1051                     example_inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in fx_codegen_and_compile(gm, example_inputs, inputs_to_check, compile_region_name, **graph_kwargs)
   1834 
   1835     # pyrefly: ignore [unbound-name]
-> 1836     return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
   1837 
   1838 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in codegen_and_compile(self, gm, example_inputs, inputs_to_check, graph_kwargs)
   1595                                 )
   1596                         else:
-> 1597                             compiled_module = graph.compile_to_module()
   1598                             compiled_fn = compiled_module.call
   1599                             compiled_fn_runner = getattr(

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py](https://localhost:8080/#) in compile_to_module(self)
   2611             dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
   2612         ):
-> 2613             return self._compile_to_module()
   2614 
   2615     def _compile_to_module(self) -> CompiledModule:

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py](https://localhost:8080/#) in _compile_to_module(self)
   2617         # returned separately in AOTInductor mode.
   2618         wrapper_code, _ = (
-> 2619             self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
   2620         )
   2621 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py](https://localhost:8080/#) in codegen(self)
   2553 
   2554             self.wrapper_code.push_codegened_graph(self)
-> 2555             self.scheduler.codegen()
   2556 
   2557             log.debug(

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/scheduler.py](https://localhost:8080/#) in codegen(self)
   7348         with dynamo_timed("Scheduler.codegen"):
   7349             return (
-> 7350                 self._codegen_partitions()
   7351                 if torch._inductor.config.graph_partition
   7352                 else self._codegen(self.nodes)

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/scheduler.py](https://localhost:8080/#) in _codegen_partitions(self)
   7488 
   7489                 if signature.skip_cudagraph:
-> 7490                     self._codegen(partition)
   7491                 else:
   7492                     self._codegen_partition_wrapper(partition, signature)

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/scheduler.py](https://localhost:8080/#) in _codegen(self, nodes)
   7639             elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
   7640                 # pyrefly: ignore [unbound-name]
-> 7641                 self.get_backend(device).codegen_node(node)
   7642             else:
   7643                 assert isinstance(node, NopKernelSchedulerNode)

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/cuda_combined_scheduling.py](https://localhost:8080/#) in codegen_node(self, node)
    150 
    151     def codegen_node(self, node: FusedSchedulerNode | SchedulerNode) -> None:
--> 152         return self._triton_scheduling.codegen_node(node)
    153 
    154     def codegen_sync(self) -> None:

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in codegen_node(self, node)
   1862             coalesce_analysis = None
   1863 
-> 1864         return self._codegen_nodes(nodes, coalesce_analysis)  # type: ignore[arg-type]
   1865 
   1866     @staticmethod

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in _codegen_nodes(self, nodes, coalesce_analysis)
   1835         schedule_log.debug("Schedule:\n %s", node_schedule)
   1836 
-> 1837         return self.codegen_node_schedule(
   1838             SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis)
   1839         )

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in codegen_node_schedule(self, kernel_features)
   1950         )
   1951         for kernel in kernels:
-> 1952             self.codegen_node_schedule_with_kernel(node_schedule, kernel)
   1953         MultiKernel.merge_workspaces_inplace(kernels)
   1954 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in codegen_node_schedule_with_kernel(self, node_schedule, kernel)
   2043                 else:
   2044                     node.decide_inplace_update()
-> 2045                     index_vars = kernel.split_and_set_ranges(node.get_ranges())
   2046                     all_indexing.update(
   2047                         dict.fromkeys(

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in split_and_set_ranges(self, lengths)
    941         # Map the kernel's group structure to the node's sizes and set the ranges
    942         # using the set_ranges method, returning the resulting iteration variables
--> 943         return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges)
    944 
    945     @classmethod

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in map_kernel_groups_to_node_sizes(cls, groups, lengths, set_ranges)
    968             return set_ranges(*lengths)
    969 
--> 970         new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths)
    971         itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))]
    972         return [[fn(itervars) for fn in fns] for fns in return_getters_groups]

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in _split_iteration_ranges(groups, lengths)
    845                         size, remaining[current_group]
    846                     ):
--> 847                         raise CantSplit(size, remaining[current_group])
    848 
    849                     size1 = remaining[current_group]

InductorError: CantSplit: 128*s52*((s97//s52)) + 128*((s97//s52)) not divisible by s52*((s97//s52)) + ((s97//s52))

---

Collecting environment information...
PyTorch version: 2.12.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.3) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.10
Libc version: glibc-2.35

Python version: 3.12.13 (main, Mar  4 2026, 09:23:07) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.6.122+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: 
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 580.82.07
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.8.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           46 bits physical, 48 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  2
On-line CPU(s) list:                     0,1
Vendor ID:                               GenuineIntel
Model name:                              Intel(R) Xeon(R) CPU @ 2.00GHz
CPU family:                              6
Model:                                   85
Thread(s) per core:                      2
Core(s) per socket:                      1
Socket(s):                               1
Stepping:                                3
BogoMIPS:                                4000.44
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
Hypervisor vendor:                       KVM
Virtualization type:                     full
L1d cache:                               32 KiB (1 instance)
L1i cache:                               32 KiB (1 instance)
L2 cache:                                1 MiB (1 instance)
L3 cache:                                38.5 MiB (1 instance)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0,1
Vulnerability Gather data sampling:      Not affected
Vulnerability Indirect target selection: Vulnerable
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Mitigation; PTE Inversion
Vulnerability Mds:                       Vulnerable; SMT Host state unknown
Vulnerability Meltdown:                  Vulnerable
Vulnerability Mmio stale data:           Vulnerable
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Vulnerable
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Vulnerable
Vulnerability Spectre v1:                Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:                Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Not affected; BHI: Vulnerable
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Vulnerable
Vulnerability Vmscape:                   Not affected

Versions of relevant libraries:
[pip3] intel-cmplr-lib-ur==2025.3.3
[pip3] intel-openmp==2025.3.3
[pip3] mkl==2025.3.1
[pip3] numpy==2.0.2
[pip3] nvidia-cublas==13.1.1.3
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-cu13==9.20.0.48
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-cusparselt-cu13==0.8.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nccl-cu13==2.29.7
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx==13.0.85
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] nvtx==0.2.15
[pip3] onemkl-license==2025.3.1
[pip3] optree==0.19.0
[pip3] tbb==2022.3.1
[pip3] tcmlib==1.4.1
[pip3] torch==2.12.0
[pip3] torchao==0.10.0
[pip3] torchaudio==2.11.0
[pip3] torchcodec==0.10.0+cu128
[pip3] torchdata==0.11.0
[pip3] torchsummary==1.5.1
[pip3] torchtune==0.6.1
[pip3] torchvision==0.27.0
[pip3] triton==3.7.0
[pip3] umf==1.0.3
[conda] Could not collect
RAW_BUFFERClick to expand / collapse

🐛 Describe the bug

A small model with two chained x + Linear(RMSNorm(x)) residual blocks, where the residual buffer is built by scattering into a wider zero tensor, should compile and train fine under torch.compile(..., dynamic=True, fullgraph=True).

Expected: compiled forward + backward run the same as eager.

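Actual: eager works, and (per the reporter) the compiled forward also works on its own. Only the training path, forward + backward with autograd, fails with InductorError: CantSplit raised from _split_iteration_ranges.

The traceback shows where the error is raised: during the compiled call itself, inside aot_stage2_autograd → _aot_stage2b_fw_compile → compile_fx_forward. That is, it fails while Inductor compiles the forward graph that AOTAutograd partitioned for training, before .backward() is reached.

The failing divisibility check is of the form 128 · X · (N + 1) over X · (N + 1), with X = s97//s52 and N = s52:
- the dividend 128*s52*X + 128*X factors as 128 · X · (N + 1);
- the divisor s52*X + X factors as X · (N + 1).

The quotient is therefore the constant 128, but Inductor reports it as not divisible.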

Minimal repro (Google Colab, T4 GPU runtime):

"""
Previous cell: !pip install --upgrade torch torchvision torchaudio
"""

import os
# Presumably makes Dynamo keep its own frames in error tracebacks: the
# traceback's `remove_dynamo_frames()  # see TORCHDYNAMO_VERBOSE=1` hint points
# at this variable. It is set before `import torch`, presumably so it is
# already present when torch reads its configuration (unverified).
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
import torch

# Record the exact torch build in the output, since the bug report is version-specific.
print("Torch version:", torch.__version__) # Torch version: 2.12.0+cu130

class M(torch.nn.Module):
    """Minimal repro model: two chained ``x + Linear(RMSNorm(x))`` residual blocks.

    The residual stream is a zero buffer with more rows than the input, and
    ``a`` is written into its leading rows by slice assignment. The issue
    reports that this exact structure triggers the Inductor ``CantSplit``
    error under ``torch.compile(dynamic=True)``:

    - slice scatter into a wider buffer,
    - a shape-derived floor division,
    - two chained blocks.

    Do not "simplify" these statements (e.g. replace the scatter with
    ``torch.cat``, which the issue says makes the bug vanish) without
    re-checking that the bug still reproduces.
    """

    def __init__(self, d=128):
        """Build both residual blocks with feature width ``d`` (default 128)."""
        super().__init__()
        # Block 1: RMSNorm over a trailing dim of size d, then a bias-free d -> d Linear.
        self.n1 = torch.nn.RMSNorm(d)
        self.l1 = torch.nn.Linear(d, d, bias=False)
        # Block 2: same structure, with its own independent parameters.
        self.n2 = torch.nn.RMSNorm(d)
        self.l2 = torch.nn.Linear(d, d, bias=False)

    def forward(self, a, b):
        """Scatter ``a`` into a zero buffer, apply both residual blocks, return the sum.

        Args:
            a: presumably a 2-D ``(rows, d)`` tensor (TODO confirm). Its last
                dim must equal ``d``, because the buffer width is
                ``a.shape[-1]`` and the RMSNorm/Linear layers are built with
                width ``d``. The issue calls its row count ``T*N``.
            b: only ``b.shape[0]`` (``N`` in the issue) is used; the values
                of ``b`` are never read.

        Returns:
            A 0-dim tensor: the sum of the final residual buffer.

        NOTE(review): the buffer has ``(a.shape[0] // b.shape[0]) * (b.shape[0] + 1)``
        rows. Write ``a.shape[0] = m*N + r``. If the remainder ``r`` exceeds
        ``m``, the buffer has fewer rows than ``a`` and the slice assignment
        below fails. The repro's inputs (21 rows, N = 7) divide evenly, so
        this presumably assumes ``a.shape[0]`` is a multiple of ``b.shape[0]``.
        """
        # T = (T*N) // N. With dynamic=True both sizes are symbolic, so this is
        # a FloorDiv in the symbolic shape environment. The CantSplit message
        # contains `s97//s52`, presumably this expression.
        m = a.shape[0] // b.shape[0]
        # Zero buffer with m * (N + 1) rows: m rows more than `a` when
        # a.shape[0] == m * N.
        out = torch.zeros(m * (b.shape[0] + 1), a.shape[-1], device=a.device)
        # Scatter `a` into the leading rows; the trailing rows stay zero. This
        # slice assignment is one of the reported trigger conditions, so keep it.
        out[: a.shape[0]] = a
        # Two chained residual blocks of the form x + Linear(RMSNorm(x)).
        # Per the issue, a single block compiles fine.
        out = out + self.l1(self.n1(out))
        out = out + self.l2(self.n2(out))
        # Reduce to a scalar so the caller can call .backward() directly.
        return out.sum()

# Run on CUDA. In the traceback, the failing codegen path goes through
# cuda_combined_scheduling -> codegen/simd.py, so a CPU device may not exercise
# the same path (unverified).
dev = "cuda"
# a: 21 x 128 (T*N = 21 rows, width 128 = M's default d).
# b: only its length N = 7 is used, so T = 21 // 7 = 3.
a, b = torch.randn(21, 128, device=dev), torch.randn(7, device=dev)
print("Testing eager mode...")
# Eager reference: forward + backward on a freshly constructed model succeeds.
M().to(dev)(a, b).backward()
print("Eager mode OK")
print("Testing compiled mode...")
# Compile a separate, fresh model instance with symbolic shapes (dynamic=True).
# Per the traceback, the InductorError is raised inside this compiled call,
# while AOTAutograd's training path compiles the forward graph
# (_aot_stage2b_fw_compile -> compile_fx_forward), before .backward() runs.
torch.compile(M().to(dev), fullgraph=True, dynamic=True)(a, b).backward() # Crashes
print("Compiled mode OK")

Trigger conditions (empirically minimized)

All of the following conditions are needed. Each was tested in isolation, and removing or changing any single one makes the bug disappear:

  1. Scatter (index_put) into a wider buffer using shape-derived indices (out[: a.shape[0]] = a). Replacing with torch.cat makes the bug vanish.
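  2. torch.compile(..., dynamic=True) with two free shape symbols, a.shape[0] = T*N and b.shape[0] = N. As a result, T = a.shape[0] // b.shape[0] becomes a FloorDiv in the symbolic shape environment (presumably the s97//s52 in the error message).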
  3. At least two x + Linear(RMSNorm(x)) residual blocks chained. One block, zero Linears, or zero RMSNorms compiles fine.

Error logs

Torch version: 2.12.0+cu130
Testing eager mode...
Eager mode OK
Testing compiled mode...
W0514 09:49:45.276000 2955 torch/_inductor/utils.py:1717] [0/0] Not enough SMs to use max_autotune_gemm mode
---------------------------------------------------------------------------
InductorError                             Traceback (most recent call last)
[/tmp/ipykernel_2955/4129170816.py](https://localhost:8080/#) in <cell line: 0>()
     28     print("Eager mode OK")
     29     print("Testing compiled mode...")
---> 30     torch.compile(M().to(dev), fullgraph=True, dynamic=True)(a, b).backward()
     31     print("Compiled mode OK")

61 frames
[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in __call__(self, *args, **kwargs)
    471             )
    472         with _set_in_optimized_module():
--> 473             return super().__call__(*args, **kwargs)
    474 
    475     def _aot_compile(self, inputs: list[torch._dynamo.aot_compile.ModelInput]) -> None:

[/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1776             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1777         else:
-> 1778             return self._call_impl(*args, **kwargs)
   1779 
   1780     # torchrec tests the code consistency with the following code

[/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1787                 or _global_backward_pre_hooks or _global_backward_hooks
   1788                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1789             return forward_call(*args, **kwargs)
   1790 
   1791         result = None

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in compile_wrapper(*args, **kwargs)
   1060                     # Failures in the backend likely don't have useful
   1061                     # data in the TorchDynamo frames, so we strip them out.
-> 1062                     raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
   1063                 finally:
   1064                     # Restore the dynamic layer stack depth if necessary.

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py](https://localhost:8080/#) in compile_wrapper(*args, **kwargs)
   1045                 call_succeeded = False
   1046                 try:
-> 1047                     result = fn(*args, **kwargs)
   1048                     call_succeeded = True
   1049                 except (Unsupported, UncapturedHigherOrderOpError, UserError) as e:

[/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1776             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1777         else:
-> 1778             return self._call_impl(*args, **kwargs)
   1779 
   1780     # torchrec tests the code consistency with the following code

[/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1787                 or _global_backward_pre_hooks or _global_backward_hooks
   1788                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1789             return forward_call(*args, **kwargs)
   1790 
   1791         result = None

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in __call__(self, frame, cache_entry, frame_state)
   2472         with compile_lock, _disable_current_modes():
   2473             # skip=1: skip this frame
-> 2474             result = self._torchdynamo_orig_backend(
   2475                 frame, cache_entry, self.hooks, frame_state, skip=1
   2476             )

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in __call__(self, frame, cache_entry, hooks, frame_state, skip)
    734             )
    735             with compile_ctx, recompile_ctx:
--> 736                 result = _compile(
    737                     frame.f_code,
    738                     frame.f_globals,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _compile(code, globals, locals, builtins, closure, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip, package, convert_frame_box)
   1959 
   1960         try:
-> 1961             guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
   1962 
   1963             # NB: We only put_code_state in success case.  Success case here

[/usr/local/lib/python3.12/dist-packages/torch/_utils_internal.py](https://localhost:8080/#) in wrapper_function(*args, **kwargs)
     94             # in stack traces when profiling is not enabled.
     95             if not StrobelightCompileTimeProfiler.enabled:
---> 96                 return function(*args, **kwargs)
     97 
     98             return StrobelightCompileTimeProfiler.profile_compile_time(

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in compile_inner(code, one_graph, hooks)
   1569             stack.enter_context(CompileTimeInstructionCounter.record())
   1570             stack.enter_context(torch_function_mode_stack_state_mgr)
-> 1571             result = _compile_inner(code, one_graph, hooks)
   1572             assert torch._C._len_torch_function_stack() == 0, (
   1573                 "Torch function mode stack state changed while dynamo tracing, please report a bug"

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _compile_inner(code, one_graph, hooks)
   1628                 else contextlib.nullcontext()
   1629             ):
-> 1630                 dynamo_output = compile_frame(
   1631                     code,
   1632                     globals,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in compile_frame(code, globals, locals, builtins, closure, compiler_fn, one_graph, restart_reasons, export, export_constraints, frame_state, distributed_state, package)
   1476         try:
   1477             with dynamo_timed(f"compile_attempt_{attempt}", log_pt2_compile_event=True):
-> 1478                 bytecode, tracer_output = transform_code_object(code, transform)
   1479                 assert tracer_output is not None
   1480                 return DynamoOutput(

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py](https://localhost:8080/#) in transform_code_object(code, transformations, safe)
   1624     propagate_line_nums(instructions)
   1625 
-> 1626     tracer_output = transformations(instructions, code_options)
   1627     _, bytecode = clean_and_assemble_instructions(instructions, keys, code_options)
   1628     return bytecode, tracer_output

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in transform(instructions, code_options)
   1448             torch_function_mode_stack_state_mgr.stack
   1449         )
-> 1450         tracer_output = trace_frame(
   1451             code,
   1452             globals,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in _fn(*args, **kwargs)
    341             reset_user_object_tracking()
    342             try:
--> 343                 return fn(*args, **kwargs)
    344             finally:
    345                 cleanup.close()

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in trace_frame(code, globals, locals, builtins, closure, compiler_fn, tf_mode_stack, one_graph, speculation_log, instructions, code_options, export, export_constraints, frame_state, distributed_state, package)
    909 
    910     try:
--> 911         run_tracer()
    912         tracer_output = DynamoTracerOutput(tracer)
    913         output = tracer_output.output_graph

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/convert_frame.py](https://localhost:8080/#) in run_tracer()
    890             tracer.output.mark_bytecode_tracing_start()
    891             with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 892                 tracer.run()
    893         except exc.UnspecializeRestartAnalysis:
    894             speculation_log.clear()  # type: ignore[has-type]

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in run(self)
   1811                 self.start_point = self.instruction_pointer
   1812                 try:
-> 1813                     while self.step():
   1814                         pass
   1815                 except Exception as e:

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in step(self)
   1478 
   1479         try:
-> 1480             self.dispatch_table[inst.opcode](self, inst)
   1481             return not self.output.should_exit
   1482         except TensorifyScalarRestartAnalysis:

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in RETURN_VALUE(self, inst)
   5141 
   5142     def RETURN_VALUE(self, inst: Instruction) -> None:
-> 5143         self._return(inst)
   5144 
   5145     def RETURN_CONST(self, inst: Instruction) -> None:

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/symbolic_convert.py](https://localhost:8080/#) in _return(self, inst)
   5123         )
   5124         log.debug("return triggered compile")
-> 5125         all_stack_locals_metadata = self.output.compile_subgraph(
   5126             self,
   5127             reason=GraphCompileReason(

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in compile_subgraph(self, tx, reason, stack_pops)
   2111             if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
   2112                 output.extend(
-> 2113                     self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
   2114                 )
   2115 

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in compile_and_call_fx_graph(self, tx, rv, root)
   2726 
   2727             with self.restore_global_state():
-> 2728                 compiled_fn = self.call_user_compiler(gm, self.example_inputs())
   2729 
   2730             from torch.fx._lazy_graph_module import _LazyGraphModule

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in call_user_compiler(self, gm, example_inputs)
   2893             dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
   2894         ):
-> 2895             return self._call_user_compiler(gm, example_inputs)
   2896 
   2897     def _call_user_compiler(

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py](https://localhost:8080/#) in _call_user_compiler(self, gm, example_inputs)
   2951             if config.verify_correctness:
   2952                 compiler_fn = WrapperBackend(compiler_fn)
-> 2953             compiled_fn = compiler_fn(gm, example_inputs)
   2954             _step_logger()(logging.INFO, f"done compiler function {name}")
   2955             assert callable(compiled_fn), "compiler_fn did not return callable"

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_dynamo.py](https://localhost:8080/#) in __call__(self, gm, example_inputs, **kwargs)
    154                     raise
    155         else:
--> 156             compiled_gm = compiler_fn(gm, example_inputs)
    157 
    158         return compiled_gm  # type: ignore[return-value]

[/usr/local/lib/python3.12/dist-packages/torch/__init__.py](https://localhost:8080/#) in __call__(self, model_, inputs_, config_patches)
   2480 
   2481         all_patches = {**self.config, **(config_patches or {})}
-> 2482         return compile_fx(
   2483             model_,
   2484             inputs_,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in compile_fx(model_, example_inputs_, inner_compile, config_patches, decompositions, ignore_shape_env, compile_region_name)
   2732                 )
   2733 
-> 2734     return _maybe_wrap_and_compile_fx_main(
   2735         model_,
   2736         example_inputs_,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _maybe_wrap_and_compile_fx_main(model_, example_inputs_, inner_compile, ignore_shape_env, get_decomp_fn, compile_region_name)
   2813 
   2814     # Finally do the actual work!
-> 2815     return _compile_fx_main(
   2816         model_,
   2817         example_inputs_,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _compile_fx_main(model_, example_inputs_, inner_compile, ignore_shape_env, get_decomp_fn, compile_region_name)
   3026                 # We will also shorten the traceback inside dynamo.
   3027                 # This is only useful if inductor is called directly with an FX graph.
-> 3028                 raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
   3029 
   3030 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _compile_fx_main(model_, example_inputs_, inner_compile, ignore_shape_env, get_decomp_fn, compile_region_name)
   3011         ):
   3012             try:
-> 3013                 return dynamo_common.aot_autograd(
   3014                     fw_compiler=fw_compiler,
   3015                     bw_compiler=bw_compiler,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/backends/common.py](https://localhost:8080/#) in __call__(self, gm, example_inputs, **kwargs)
    121             # NB: NOT cloned!
    122             with enable_aot_logging(), patch_config:
--> 123                 cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
    124                 counters["aot_autograd"]["ok"] += 1
    125                 return disable(cg, reason="do not trace AOT-compiled graph")

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py](https://localhost:8080/#) in aot_module_simplified(mod, args, fw_compiler, bw_compiler, partition_fn, decompositions, keep_inference_input_mutations, inference_compiler, compiler_config_extra, ignore_shape_env, disable_functionalization, pre_grad_passes, compile_region_name)
   1232             aot_state.fw_metadata.act_input_indices = act_input_indices
   1233             aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call)
-> 1234             compiled_fn, _ = aot_stage2_compile(
   1235                 aot_state,
   1236                 aot_graph_capture,

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py](https://localhost:8080/#) in aot_stage2_compile(aot_state, aot_graph_capture, partition_fn, fw_compiler, bw_compiler, inference_compiler)
    376 
    377     if aot_state.needs_autograd and not aot_state.aot_config.pre_dispatch:
--> 378         return aot_stage2_autograd(aot_state, aot_graph_capture)
    379     else:
    380         return aot_stage2_inference(aot_state, aot_graph_capture)

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py](https://localhost:8080/#) in aot_stage2_autograd(aot_state, aot_graph_capture)
   2284     )
   2285 
-> 2286     fwd_output_strides, compiled_fw_func = _aot_stage2b_fw_compile(
   2287         fw_module,
   2288         adjusted_flat_args,

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py](https://localhost:8080/#) in _aot_stage2b_fw_compile(fw_module, adjusted_flat_args, maybe_subclass_meta, fw_metadata, num_fw_outs_saved_for_bw, aot_config)
   2073     # pyrefly: ignore [implicit-any]
   2074 ) -> tuple[list[tuple[int, ...] | None] | None, Callable]:
-> 2075     return _aot_stage2b_compile_forward_or_inference(
   2076         fw_module,
   2077         adjusted_flat_args,

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/graph_compile.py](https://localhost:8080/#) in _aot_stage2b_compile_forward_or_inference(fw_module, adjusted_flat_args, maybe_subclass_meta, fw_metadata, aot_config, is_inference, num_fw_outs_saved_for_bw)
   2601         with TracingContext.report_output_strides() as fwd_output_strides:
   2602             # pyrefly: ignore[not-callable]
-> 2603             compiled_fw_func = compiler(fw_module, adjusted_flat_args)
   2604 
   2605         # Make boxed if needed

[/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/schemas.py](https://localhost:8080/#) in __call__(self, gm, example_inputs)
   1419         example_inputs: Sequence[InputType],
   1420     ) -> OutputCode:
-> 1421         output_code = self.compiler_fn(gm, example_inputs)
   1422         return output_code
   1423 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in fw_compiler_base(gm, example_inputs, is_inference)
   2874                 else:
   2875                     num_orig_model_outputs = get_num_model_outputs(gm)
-> 2876                 return compile_fx_forward(
   2877                     gm,
   2878                     example_inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in compile_fx_forward(gm, example_inputs, num_orig_model_outputs, num_example_inputs, compiler_config_extra, inner_compile, is_inference)
   2502 
   2503     with cudagraph_annotation_context(compiler_config_extra.cudagraphs):
-> 2504         result = inner_compile(
   2505             gm,
   2506             example_inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in compile_fx_inner(gm, example_inputs, compile_region_name, **kwargs)
    825             is_backward=kwargs["is_backward"],
    826         )
--> 827         return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
    828             gm,
    829             example_inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_aot.py](https://localhost:8080/#) in debug_wrapper(gm, example_inputs, compile_region_name, **kwargs)
    312             # Call the compiler_fn - which is either aot_autograd or inductor
    313             # with fake inputs
--> 314             inner_compiled_fn = compiler_fn(gm, example_inputs)
    315         except Exception:
    316             # TODO: Failures here are troublesome because no real inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _compile_fx_inner(gm, example_inputs, compile_region_name, **graph_kwargs)
   1067                 raise
   1068             except Exception as e:
-> 1069                 raise InductorError(e, currentframe()).with_traceback(
   1070                     e.__traceback__
   1071                 ) from None

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in _compile_fx_inner(gm, example_inputs, compile_region_name, **graph_kwargs)
   1047             TritonBundler.begin_compile()
   1048             try:
-> 1049                 mb_compiled_graph = fx_codegen_and_compile(
   1050                     gm,
   1051                     example_inputs,

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in fx_codegen_and_compile(gm, example_inputs, inputs_to_check, compile_region_name, **graph_kwargs)
   1834 
   1835     # pyrefly: ignore [unbound-name]
-> 1836     return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
   1837 
   1838 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py](https://localhost:8080/#) in codegen_and_compile(self, gm, example_inputs, inputs_to_check, graph_kwargs)
   1595                                 )
   1596                         else:
-> 1597                             compiled_module = graph.compile_to_module()
   1598                             compiled_fn = compiled_module.call
   1599                             compiled_fn_runner = getattr(

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py](https://localhost:8080/#) in compile_to_module(self)
   2611             dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
   2612         ):
-> 2613             return self._compile_to_module()
   2614 
   2615     def _compile_to_module(self) -> CompiledModule:

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py](https://localhost:8080/#) in _compile_to_module(self)
   2617         # returned separately in AOTInductor mode.
   2618         wrapper_code, _ = (
-> 2619             self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
   2620         )
   2621 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/graph.py](https://localhost:8080/#) in codegen(self)
   2553 
   2554             self.wrapper_code.push_codegened_graph(self)
-> 2555             self.scheduler.codegen()
   2556 
   2557             log.debug(

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/scheduler.py](https://localhost:8080/#) in codegen(self)
   7348         with dynamo_timed("Scheduler.codegen"):
   7349             return (
-> 7350                 self._codegen_partitions()
   7351                 if torch._inductor.config.graph_partition
   7352                 else self._codegen(self.nodes)

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/scheduler.py](https://localhost:8080/#) in _codegen_partitions(self)
   7488 
   7489                 if signature.skip_cudagraph:
-> 7490                     self._codegen(partition)
   7491                 else:
   7492                     self._codegen_partition_wrapper(partition, signature)

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/scheduler.py](https://localhost:8080/#) in _codegen(self, nodes)
   7639             elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
   7640                 # pyrefly: ignore [unbound-name]
-> 7641                 self.get_backend(device).codegen_node(node)
   7642             else:
   7643                 assert isinstance(node, NopKernelSchedulerNode)

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/cuda_combined_scheduling.py](https://localhost:8080/#) in codegen_node(self, node)
    150 
    151     def codegen_node(self, node: FusedSchedulerNode | SchedulerNode) -> None:
--> 152         return self._triton_scheduling.codegen_node(node)
    153 
    154     def codegen_sync(self) -> None:

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in codegen_node(self, node)
   1862             coalesce_analysis = None
   1863 
-> 1864         return self._codegen_nodes(nodes, coalesce_analysis)  # type: ignore[arg-type]
   1865 
   1866     @staticmethod

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in _codegen_nodes(self, nodes, coalesce_analysis)
   1835         schedule_log.debug("Schedule:\n %s", node_schedule)
   1836 
-> 1837         return self.codegen_node_schedule(
   1838             SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis)
   1839         )

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in codegen_node_schedule(self, kernel_features)
   1950         )
   1951         for kernel in kernels:
-> 1952             self.codegen_node_schedule_with_kernel(node_schedule, kernel)
   1953         MultiKernel.merge_workspaces_inplace(kernels)
   1954 

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in codegen_node_schedule_with_kernel(self, node_schedule, kernel)
   2043                 else:
   2044                     node.decide_inplace_update()
-> 2045                     index_vars = kernel.split_and_set_ranges(node.get_ranges())
   2046                     all_indexing.update(
   2047                         dict.fromkeys(

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in split_and_set_ranges(self, lengths)
    941         # Map the kernel's group structure to the node's sizes and set the ranges
    942         # using the set_ranges method, returning the resulting iteration variables
--> 943         return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges)
    944 
    945     @classmethod

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in map_kernel_groups_to_node_sizes(cls, groups, lengths, set_ranges)
    968             return set_ranges(*lengths)
    969 
--> 970         new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths)
    971         itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))]
    972         return [[fn(itervars) for fn in fns] for fns in return_getters_groups]

[/usr/local/lib/python3.12/dist-packages/torch/_inductor/codegen/simd.py](https://localhost:8080/#) in _split_iteration_ranges(groups, lengths)
    845                         size, remaining[current_group]
    846                     ):
--> 847                         raise CantSplit(size, remaining[current_group])
    848 
    849                     size1 = remaining[current_group]

InductorError: CantSplit: 128*s52*((s97//s52)) + 128*((s97//s52)) not divisible by s52*((s97//s52)) + ((s97//s52))

Versions

Collecting environment information...
PyTorch version: 2.12.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.3) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.10
Libc version: glibc-2.35

Python version: 3.12.13 (main, Mar  4 2026, 09:23:07) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.6.122+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: 
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 580.82.07
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.8.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           46 bits physical, 48 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  2
On-line CPU(s) list:                     0,1
Vendor ID:                               GenuineIntel
Model name:                              Intel(R) Xeon(R) CPU @ 2.00GHz
CPU family:                              6
Model:                                   85
Thread(s) per core:                      2
Core(s) per socket:                      1
Socket(s):                               1
Stepping:                                3
BogoMIPS:                                4000.44
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
Hypervisor vendor:                       KVM
Virtualization type:                     full
L1d cache:                               32 KiB (1 instance)
L1i cache:                               32 KiB (1 instance)
L2 cache:                                1 MiB (1 instance)
L3 cache:                                38.5 MiB (1 instance)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0,1
Vulnerability Gather data sampling:      Not affected
Vulnerability Indirect target selection: Vulnerable
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Mitigation; PTE Inversion
Vulnerability Mds:                       Vulnerable; SMT Host state unknown
Vulnerability Meltdown:                  Vulnerable
Vulnerability Mmio stale data:           Vulnerable
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Vulnerable
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Vulnerable
Vulnerability Spectre v1:                Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:                Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Not affected; BHI: Vulnerable
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Vulnerable
Vulnerability Vmscape:                   Not affected

Versions of relevant libraries:
[pip3] intel-cmplr-lib-ur==2025.3.3
[pip3] intel-openmp==2025.3.3
[pip3] mkl==2025.3.1
[pip3] numpy==2.0.2
[pip3] nvidia-cublas==13.1.1.3
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-cu13==9.20.0.48
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-cusparselt-cu13==0.8.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nccl-cu13==2.29.7
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx==13.0.85
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] nvtx==0.2.15
[pip3] onemkl-license==2025.3.1
[pip3] optree==0.19.0
[pip3] tbb==2022.3.1
[pip3] tcmlib==1.4.1
[pip3] torch==2.12.0
[pip3] torchao==0.10.0
[pip3] torchaudio==2.11.0
[pip3] torchcodec==0.10.0+cu128
[pip3] torchdata==0.11.0
[pip3] torchsummary==1.5.1
[pip3] torchtune==0.6.1
[pip3] torchvision==0.27.0
[pip3] triton==3.7.0
[pip3] umf==1.0.3
[conda] Could not collect

cc @chauhang @penguinwu @ezyang @bobrenjc93 @aditvenk @laithsakka @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING

pytorch - 💡(How to fix) Fix `torch.compile(dynamic=True)` backward fails with `InductorError: CantSplit` on chained residual blocks scattering into a wider buffer