pytorch - 💡(How to fix) Fix [Inductor] AssertionError: len(index) == len(stride) in ir.py when using torch.compile(mode="max-autotune") on CPU BMM/GEMM

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…

Error Message

import torch

def target_function(
    l_tok_embeddings_weight_, l_tokens_, l_attn_norm_0_weight_,
    l_wq_0_weight_, l_wk_0_weight_, l_wv_0_weight_, l_wo_0_weight_
):
    """Single-layer attention forward used to reproduce the Inductor CPU crash.

    Reconstructed from a paste whose line breaks and indentation were lost
    (the original text was not valid Python). The op sequence is unchanged
    on purpose: the compiler failure depends on the exact captured graph.

    Pipeline: rotary-style frequency table built with ``torch.polar`` ->
    embedding lookup -> RMS-style normalization scaled by
    ``l_attn_norm_0_weight_`` -> bias-free q/k/v ``F.linear`` projections ->
    complex rotation of q and k -> causal-masked scaled dot-product
    attention -> output projection with ``l_wo_0_weight_``.

    The hard-coded views (``view(2, 4, 8, 8)``, ``reshape(1, 1, 4, -1)``,
    ``view(2, 4, 64)``) assume 2 x 4 token positions and a 64-wide feature
    dimension split as (8, 8) -- presumably 8 heads of size 8.

    Returns:
        A 54-element tuple holding every intermediate, ending with the
        output projection ``linear_3``.
    """
    # Frequency table: 1 / 10000 ** (arange(0, 8, 2) / 8) gives 4 frequencies;
    # the outer product with positions 0..3 gives (4, 4), and polar() turns it
    # into unit-magnitude complex values.
    arange = torch.arange(0, 8, 2, device='cpu', dtype=torch.float32)
    angles = torch.true_divide(arange, 8)
    pow_1 = torch.pow(10000.0, angles)
    freqs = torch.true_divide(1.0, pow_1)
    t = torch.arange(0, 4, device='cpu', dtype=torch.float32)
    outer = torch.outer(t, freqs)
    freqs_1 = outer.float()
    ones_like = torch.ones_like(freqs_1)
    freqs_cis = torch.polar(ones_like, freqs_1)

    # Embedding lookup, then x * rsqrt(mean(x ** 2, -1) + 1e-5) and a
    # per-feature weight (RMS-style normalization).
    h = torch.nn.functional.embedding(l_tokens_, l_tok_embeddings_weight_)
    x_normed = h.float()
    pow_2 = x_normed.pow(2)
    mean = pow_2.mean(-1, keepdim=True)
    add = torch.add(mean, 1e-05)
    rsqrt = torch.rsqrt(add)
    x_normed_1 = torch.mul(x_normed, rsqrt)
    type_as = x_normed_1.type_as(h)
    normed_x = torch.mul(type_as, l_attn_norm_0_weight_)

    # Projections without bias (no bias argument is passed).
    q = torch.nn.functional.linear(normed_x, l_wq_0_weight_)
    k = torch.nn.functional.linear(normed_x, l_wk_0_weight_)
    v = torch.nn.functional.linear(normed_x, l_wv_0_weight_)

    # (2, 4, 8, 8) -> transpose(1, 2) -> (2, 8, 4, 8); non-contiguous views.
    view = q.view(2, 4, 8, 8)
    q_1 = view.transpose(1, 2)
    view_1 = k.view(2, 4, 8, 8)
    k_1 = view_1.transpose(1, 2)
    view_2 = v.view(2, 4, 8, 8)
    v_1 = view_2.transpose(1, 2)

    # Pair adjacent features and reinterpret them as complex: (2, 8, 4, 4).
    float_3 = q_1.float()
    reshape = float_3.reshape(2, 8, 4, -1, 2)
    xq_ = torch.view_as_complex(reshape)

    float_4 = k_1.float()
    reshape_1 = float_4.reshape(2, 8, 4, -1, 2)
    xk_ = torch.view_as_complex(reshape_1)

    # Broadcast the (4, 4) table over dims 0-1, rotate, and go back to real
    # values flattened to (2, 8, 4, 8).
    freqs_cis_1 = freqs_cis.reshape(1, 1, 4, -1)
    mul_2 = torch.mul(xq_, freqs_cis_1)
    view_as_real = torch.view_as_real(mul_2)
    xq_out = view_as_real.flatten(3)

    mul_3 = torch.mul(xk_, freqs_cis_1)
    view_as_real_1 = torch.view_as_real(mul_3)
    xk_out = view_as_real_1.flatten(3)

    q_2 = xq_out.type_as(q_1)
    k_2 = xk_out.type_as(k_1)

    # -inf strictly above the diagonal, 0 elsewhere (causal mask).
    mask = torch.full([1, 1, 4, 4], float('-inf'), device='cpu', dtype=torch.float32)
    mask_1 = torch.triu(mask, diagonal=1)
    transpose_3 = k_2.transpose(-2, -1)
    # Batched 4-D matmul; the reported crash surfaces while rendering the CPU
    # BMM template, though which of the two matmuls triggers it is not shown.
    matmul = torch.matmul(q_2, transpose_3)

    # 0.35355339059327373 == 1 / sqrt(8).
    scores = torch.mul(matmul, 0.35355339059327373)
    scores_1 = torch.add(scores, mask_1)
    attn_weights = torch.nn.functional.softmax(scores_1, dim=-1)
    attn_output = torch.matmul(attn_weights, v_1)

    # Back to (2, 4, 8, 8), materialize, merge to (2, 4, 64), project out.
    transpose_4 = attn_output.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    attn_output_1 = contiguous.view(2, 4, 64)
    linear_3 = torch.nn.functional.linear(attn_output_1, l_wo_0_weight_)

    # Every intermediate is returned so it becomes a graph output.
    # NOTE(review): this presumably affects what Inductor fuses into the GEMM
    # epilogue -- confirm before trimming the tuple.
    return (arange, t, mask, h, angles, mask_1, x_normed, pow_1, pow_2, freqs, mean, outer, add, freqs_1, rsqrt, ones_like, x_normed_1, freqs_cis, type_as, freqs_cis_1, normed_x, q, k, v, view, view_1, view_2, q_1, k_1, v_1, float_3, float_4, reshape, reshape_1, xq_, xk_, mul_2, mul_3, view_as_real, view_as_real_1, xq_out, xk_out, q_2, k_2, transpose_3, matmul, scores, scores_1, attn_weights, attn_output, transpose_4, contiguous, attn_output_1, linear_3)

def get_inputs():
    """Return deterministic example arguments for ``target_function``.

    Reconstructed from a paste whose line breaks were lost (the original text
    was not valid Python). The RNG is re-seeded on every call, so repeated
    calls return identical tensors.

    Returns:
        A 7-tuple: a (512, 64) embedding table, all-zero int64 token ids of
        shape (2, 4), a (64,) norm weight, and four (64, 64) weights (in the
        order of ``target_function``'s q, k, v and output parameters). Every
        float tensor is standard-normal scaled by 0.1.
    """
    torch.manual_seed(2077)
    return (
        torch.randn([512, 64], dtype=torch.float32) * 0.1,
        torch.zeros([2, 4], dtype=torch.int64),
        torch.randn([64], dtype=torch.float32) * 0.1,
        torch.randn([64, 64], dtype=torch.float32) * 0.1,
        torch.randn([64, 64], dtype=torch.float32) * 0.1,
        torch.randn([64, 64], dtype=torch.float32) * 0.1,
        torch.randn([64, 64], dtype=torch.float32) * 0.1,
    )

if __name__ == "__main__":
    # Restored guard: the paste stripped the dunders (``name == "main"`` would
    # raise NameError) and the indentation, which would have run the body
    # unguarded.
    inputs = get_inputs()

    print("Running torch.compile(mode='max-autotune')...")
    opt_fn = torch.compile(target_function, mode="max-autotune")

    try:
        opt_fn(*inputs)
        print("Compilation Passed (Bug not triggered).")
    except Exception as e:
        # Top-level boundary: report the full traceback plus a one-line summary.
        import traceback
        traceback.print_exc()
        print(f"\n[BINGO!] Crashed with: {type(e).__name__}: {e}")

Root Cause

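The traceback shows where the assertion fires. It happens while Inductor renders its CPU BMM template (`cpp_bmm_template.py` → `GEMM_TEMPLATE` → `store_output` → `store_pointwise_nodes`). At that point, the loader for a fused epilogue input (`cpp_gemm_template.py` `copy_inner` → `ir.py` `loader` → `indexer`) fails `assert stride is not None and len(index) == len(stride)`. In other words, either the loaded buffer has no stride, or the index built for it has a different rank than its stride. The compile also warns that Inductor does not generate code for complex operators, so the complex-valued rotation step (`view_as_complex` / `view_as_real`) may be involved; this has not been confirmed. A similar CUDA variant of this reproducer fails differently, with `KeyError: 'complex64'` in Triton codegen. I filed that separately because the failing backend path is different.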

Fix Action

Fix / Workaround

Traceback

(torch-nightly) xyt19@Oasis:/tmp$ TORCHDYNAMO_VERBOSE=1 python bug.py
Running torch.compile(mode='max-autotune')...
/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/lowering.py:2504: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
Traceback (most recent call last):
  File "/tmp/bug.py", line 95, in <module>
    opt_fn(*inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1158, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1141, in compile_wrapper
    result = fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 2619, in __call__
    result = self._torchdynamo_orig_backend(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 2310, in __call__
    result = self._inner_convert(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 777, in __call__
    result = _compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 2094, in _compile
    guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1679, in compile_inner
    result = _compile_inner(code, one_graph, hooks)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1739, in _compile_inner
    dynamo_output = compile_frame(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1584, in compile_frame
    bytecode, tracer_output = transform_code_object(code, transform)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1766, in transform_code_object
    tracer_output = transformations(instructions, code_options)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1555, in transform
    tracer_output = trace_frame(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 368, in _fn
    return fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 954, in trace_frame
    run_tracer()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 935, in run_tracer
    tracer.run()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1883, in run
    while self.step():
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1536, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 5449, in RETURN_VALUE
    self._return(inst)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 5422, in _return
    all_stack_locals_metadata = self.output.compile_subgraph(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 2171, in compile_subgraph
    instructions, subgraph_pycode = self.compile_and_call_fx_graph(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 2817, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm, self.example_inputs())
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 2987, in call_user_compiler
    return self._call_user_compiler(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 3049, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 159, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/__init__.py", line 2482, in __call__
    return compile_fx(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2705, in compile_fx
    return compile_fx(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2764, in compile_fx
    return _maybe_wrap_and_compile_fx_main(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2845, in _maybe_wrap_and_compile_fx_main
    return _compile_fx_main(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 3058, in _compile_fx_main
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 3043, in _compile_fx_main
    return dynamo_common.aot_autograd(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 123, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1238, in aot_module_simplified
    compiled_fn, _ = aot_stage2_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 357, in aot_stage2_compile
    return aot_stage2_inference(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 481, in aot_stage2_inference
    compiled_fw = _aot_stage2b_inference_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 408, in _aot_stage2b_inference_compile
    return _aot_stage2b_compile_forward_or_inference(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 2779, in _aot_stage2b_compile_forward_or_inference
    compiled_fw_func = compiler(fw_module, adjusted_flat_args)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/schemas.py", line 1460, in __call__
    output_code = self.compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2906, in fw_compiler_base
    return compile_fx_forward(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2534, in compile_fx_forward
    result = inner_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 836, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 317, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1078, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1058, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1845, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1606, in codegen_and_compile
    compiled_module = graph.compile_to_module()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2661, in compile_to_module
    return self._compile_to_module()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2667, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2603, in codegen
    self.scheduler.codegen()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 8941, in codegen
    self._codegen_partitions()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 9081, in _codegen_partitions
    self._codegen(partition)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 9206, in _codegen
    self.get_backend(device).codegen_template(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp.py", line 5488, in codegen_template
    src_code = render()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 52, in render
    template.render(kernel=self, **kwargs), self.render_hooks
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_bmm_template.py", line 237, in render
    result = self._template_from_string(BMM_TEMPLATE).render(**options)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 1295, in render
    self.environment.handle_exception()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 942, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<template>", line 3, in top-level template code
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_bmm_template.py", line 255, in codegen_single_thread_gemm
    return stub + self._template_from_string(GEMM_TEMPLATE).render(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 1295, in render
    self.environment.handle_exception()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 942, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<template>", line 97, in top-level template code
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 415, in store_output
    return self.store_pointwise_nodes(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 287, in store_pointwise_nodes
    body = LoopBody(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/loop_body.py", line 136, in __init__
    self._init_with_tracing(fn, args)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/loop_body.py", line 166, in _init_with_tracing
    self.root_block = LoopBodyBlock(self, fn, args)  # traces
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/loop_body.py", line 599, in __init__
    ops.output(fn(*args))
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 284, in fn
    node.make_loader()(new_args).value,
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_utils.py", line 417, in inner
    return fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_gemm_template.py", line 1484, in copy_inner
    input = input_loader(index)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/ir.py", line 4744, in loader
    return ops.load(self.name or "unnamed", indexer(index))
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/ir.py", line 2135, in indexer
    assert stride is not None and len(index) == len(stride)
torch._inductor.exc.InductorError: AssertionError:

Code Example

import torch

def target_function(
    l_tok_embeddings_weight_, l_tokens_, l_attn_norm_0_weight_, 
    l_wq_0_weight_, l_wk_0_weight_, l_wv_0_weight_, l_wo_0_weight_
):
    """Single-layer attention forward used as the Inductor crash reproducer.

    The exact statement sequence is what reproduces the compiler failure, so
    it is intentionally left as written.

    Pipeline: rotary-style frequency table built with ``torch.polar`` ->
    embedding lookup -> RMS-style normalization scaled by
    ``l_attn_norm_0_weight_`` -> bias-free q/k/v ``F.linear`` projections ->
    complex rotation of q and k -> causal-masked scaled dot-product
    attention -> output projection with ``l_wo_0_weight_``.

    The hard-coded views (``view(2, 4, 8, 8)``, ``reshape(1, 1, 4, -1)``,
    ``view(2, 4, 64)``) assume 2 x 4 token positions and a 64-wide feature
    dimension split as (8, 8) -- presumably 8 heads of size 8.

    Returns:
        A 54-element tuple holding every intermediate, ending with the
        output projection ``linear_3``.
    """
    # Frequency table: 1 / 10000 ** (arange(0, 8, 2) / 8) gives 4 frequencies;
    # the outer product with positions 0..3 gives (4, 4), and polar() turns it
    # into unit-magnitude complex values.
    arange = torch.arange(0, 8, 2, device='cpu', dtype=torch.float32)
    angles = torch.true_divide(arange, 8)
    pow_1 = torch.pow(10000.0, angles)
    freqs = torch.true_divide(1.0, pow_1)
    t = torch.arange(0, 4, device='cpu', dtype=torch.float32)
    outer = torch.outer(t, freqs)
    freqs_1 = outer.float()
    ones_like = torch.ones_like(freqs_1)
    freqs_cis = torch.polar(ones_like, freqs_1)
    
    # Embedding lookup, then x * rsqrt(mean(x ** 2, -1) + 1e-5) and a
    # per-feature weight (RMS-style normalization).
    h = torch.nn.functional.embedding(l_tokens_, l_tok_embeddings_weight_)
    x_normed = h.float()
    pow_2 = x_normed.pow(2)
    mean = pow_2.mean(-1, keepdim=True)
    add = torch.add(mean, 1e-05)
    rsqrt = torch.rsqrt(add)
    x_normed_1 = torch.mul(x_normed, rsqrt)
    type_as = x_normed_1.type_as(h)
    normed_x = torch.mul(type_as, l_attn_norm_0_weight_)
    
    # Projections without bias (no bias argument is passed).
    q = torch.nn.functional.linear(normed_x, l_wq_0_weight_)
    k = torch.nn.functional.linear(normed_x, l_wk_0_weight_)
    v = torch.nn.functional.linear(normed_x, l_wv_0_weight_)
    
    # (2, 4, 8, 8) -> transpose(1, 2) -> (2, 8, 4, 8); non-contiguous views.
    view = q.view(2, 4, 8, 8)
    q_1 = view.transpose(1, 2)
    view_1 = k.view(2, 4, 8, 8)
    k_1 = view_1.transpose(1, 2)
    view_2 = v.view(2, 4, 8, 8)
    v_1 = view_2.transpose(1, 2)
    
    # Pair adjacent features and reinterpret them as complex: (2, 8, 4, 4).
    float_3 = q_1.float()
    reshape = float_3.reshape(2, 8, 4, -1, 2)
    xq_ = torch.view_as_complex(reshape)
    
    float_4 = k_1.float()
    reshape_1 = float_4.reshape(2, 8, 4, -1, 2)
    xk_ = torch.view_as_complex(reshape_1)
    
    # Broadcast the (4, 4) table over dims 0-1, rotate, and go back to real
    # values flattened to (2, 8, 4, 8).
    freqs_cis_1 = freqs_cis.reshape(1, 1, 4, -1)
    mul_2 = torch.mul(xq_, freqs_cis_1)
    view_as_real = torch.view_as_real(mul_2)
    xq_out = view_as_real.flatten(3)
    
    mul_3 = torch.mul(xk_, freqs_cis_1)
    view_as_real_1 = torch.view_as_real(mul_3)
    xk_out = view_as_real_1.flatten(3)
    
    q_2 = xq_out.type_as(q_1)
    k_2 = xk_out.type_as(k_1)
    
    # -inf strictly above the diagonal, 0 elsewhere (causal mask).
    mask = torch.full([1, 1, 4, 4], float('-inf'), device='cpu', dtype=torch.float32)
    mask_1 = torch.triu(mask, diagonal=1)
    transpose_3 = k_2.transpose(-2, -1)
    # Batched 4-D matmul; the reported crash surfaces while rendering the CPU
    # BMM template, though which of the two matmuls triggers it is not shown.
    matmul = torch.matmul(q_2, transpose_3)
    
    # 0.35355339059327373 == 1 / sqrt(8).
    scores = torch.mul(matmul, 0.35355339059327373)
    scores_1 = torch.add(scores, mask_1)
    attn_weights = torch.nn.functional.softmax(scores_1, dim=-1)
    attn_output = torch.matmul(attn_weights, v_1)
    
    # Back to (2, 4, 8, 8), materialize, merge to (2, 4, 64), project out.
    transpose_4 = attn_output.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    attn_output_1 = contiguous.view(2, 4, 64)
    linear_3 = torch.nn.functional.linear(attn_output_1, l_wo_0_weight_)
    
    # Every intermediate is returned so it becomes a graph output.
    # NOTE(review): this presumably affects what Inductor fuses into the GEMM
    # epilogue -- confirm before trimming the tuple.
    return (arange, t, mask, h, angles, mask_1, x_normed, pow_1, pow_2, freqs, mean, outer, add, freqs_1, rsqrt, ones_like, x_normed_1, freqs_cis, type_as, freqs_cis_1, normed_x, q, k, v, view, view_1, view_2, q_1, k_1, v_1, float_3, float_4, reshape, reshape_1, xq_, xk_, mul_2, mul_3, view_as_real, view_as_real_1, xq_out, xk_out, q_2, k_2, transpose_3, matmul, scores, scores_1, attn_weights, attn_output, transpose_4, contiguous, attn_output_1, linear_3)


def get_inputs():
    """Build the seeded example arguments passed to ``target_function``.

    Draws happen in a fixed order after seeding with 2077, so each call
    yields the same seven tensors: a (512, 64) embedding table, (2, 4) int64
    zero token ids, a (64,) norm weight, and four (64, 64) weight matrices.
    All float tensors are standard-normal samples multiplied by 0.1.
    """
    torch.manual_seed(2077)

    def _scaled_normal(shape):
        # One standard-normal draw of the given shape, scaled by 0.1.
        return torch.randn(shape, dtype=torch.float32) * 0.1

    embedding_table = _scaled_normal([512, 64])
    token_ids = torch.zeros([2, 4], dtype=torch.int64)
    norm_weight = _scaled_normal([64])
    # Drawn in order: wq, wk, wv, wo.
    projections = [_scaled_normal([64, 64]) for _ in range(4)]
    return (embedding_table, token_ids, norm_weight, *projections)

if __name__ == "__main__":
    # Compile the reproducer with max-autotune and run it once; the reported
    # crash surfaces during Inductor codegen (CPU BMM template).
    example_args = get_inputs()

    print("Running torch.compile(mode='max-autotune')...")
    compiled_fn = torch.compile(target_function, mode="max-autotune")

    try:
        compiled_fn(*example_args)
    except Exception as err:
        # Top-level boundary: show the full traceback, then a one-line summary.
        import traceback

        traceback.print_exc()
        print(f"\n[BINGO!] Crashed with: {type(err).__name__}: {err}")
    else:
        print("Compilation Passed (Bug not triggered).")

---

(torch-nightly) xyt19@Oasis:/tmp$ TORCHDYNAMO_VERBOSE=1 python bug.py
Running torch.compile(mode='max-autotune')...
/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/lowering.py:2504: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
Traceback (most recent call last):
  File "/tmp/bug.py", line 95, in <module>
    opt_fn(*inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1158, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1141, in compile_wrapper
    result = fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 2619, in __call__
    result = self._torchdynamo_orig_backend(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 2310, in __call__
    result = self._inner_convert(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 777, in __call__
    result = _compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 2094, in _compile
    guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1679, in compile_inner
    result = _compile_inner(code, one_graph, hooks)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1739, in _compile_inner
    dynamo_output = compile_frame(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1584, in compile_frame
    bytecode, tracer_output = transform_code_object(code, transform)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1766, in transform_code_object
    tracer_output = transformations(instructions, code_options)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1555, in transform
    tracer_output = trace_frame(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 368, in _fn
    return fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 954, in trace_frame
    run_tracer()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 935, in run_tracer
    tracer.run()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1883, in run
    while self.step():
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1536, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 5449, in RETURN_VALUE
    self._return(inst)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 5422, in _return
    all_stack_locals_metadata = self.output.compile_subgraph(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 2171, in compile_subgraph
    instructions, subgraph_pycode = self.compile_and_call_fx_graph(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 2817, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm, self.example_inputs())
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 2987, in call_user_compiler
    return self._call_user_compiler(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 3049, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 159, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/__init__.py", line 2482, in __call__
    return compile_fx(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2705, in compile_fx
    return compile_fx(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2764, in compile_fx
    return _maybe_wrap_and_compile_fx_main(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2845, in _maybe_wrap_and_compile_fx_main
    return _compile_fx_main(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 3058, in _compile_fx_main
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 3043, in _compile_fx_main
    return dynamo_common.aot_autograd(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 123, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1238, in aot_module_simplified
    compiled_fn, _ = aot_stage2_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 357, in aot_stage2_compile
    return aot_stage2_inference(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 481, in aot_stage2_inference
    compiled_fw = _aot_stage2b_inference_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 408, in _aot_stage2b_inference_compile
    return _aot_stage2b_compile_forward_or_inference(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 2779, in _aot_stage2b_compile_forward_or_inference
    compiled_fw_func = compiler(fw_module, adjusted_flat_args)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/schemas.py", line 1460, in __call__
    output_code = self.compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2906, in fw_compiler_base
    return compile_fx_forward(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2534, in compile_fx_forward
    result = inner_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 836, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 317, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1078, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1058, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1845, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1606, in codegen_and_compile
    compiled_module = graph.compile_to_module()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2661, in compile_to_module
    return self._compile_to_module()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2667, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2603, in codegen
    self.scheduler.codegen()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 8941, in codegen
    self._codegen_partitions()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 9081, in _codegen_partitions
    self._codegen(partition)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 9206, in _codegen
    self.get_backend(device).codegen_template(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp.py", line 5488, in codegen_template
    src_code = render()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 52, in render
    template.render(kernel=self, **kwargs), self.render_hooks
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_bmm_template.py", line 237, in render
    result = self._template_from_string(BMM_TEMPLATE).render(**options)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 1295, in render
    self.environment.handle_exception()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 942, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<template>", line 3, in top-level template code
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_bmm_template.py", line 255, in codegen_single_thread_gemm
    return stub + self._template_from_string(GEMM_TEMPLATE).render(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 1295, in render
    self.environment.handle_exception()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 942, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<template>", line 97, in top-level template code
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 415, in store_output
    return self.store_pointwise_nodes(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 287, in store_pointwise_nodes
    body = LoopBody(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/loop_body.py", line 136, in __init__
    self._init_with_tracing(fn, args)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/loop_body.py", line 166, in _init_with_tracing
    self.root_block = LoopBodyBlock(self, fn, args)  # traces
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/loop_body.py", line 599, in __init__
    ops.output(fn(*args))
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 284, in fn
    node.make_loader()(new_args).value,
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_utils.py", line 417, in inner
    return fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_gemm_template.py", line 1484, in copy_inner
    input = input_loader(index)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/ir.py", line 4744, in loader
    return ops.load(self.name or "unnamed", indexer(index))
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/ir.py", line 2135, in indexer
    assert stride is not None and len(index) == len(stride)
torch._inductor.exc.InductorError: AssertionError:


[BINGO!] Crashed with: InductorError: AssertionError:
RAW_BUFFERClick to expand / collapse

🐛 Describe the bug

When compiling an attention-like function containing complex numbers (torch.polar, torch.view_as_complex) and batched matrix multiplication (torch.matmul) on CPU with torch.compile(mode="max-autotune"), the compiler crashes with an AssertionError.

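The crash occurs during the C++ backend's BMM/GEMM template code generation (cpp_bmm_template.py). While the GEMM template renders its output epilogue (store_output → store_pointwise_nodes in cpp_template_kernel.py), an input loader invoked from copy_inner in cpp_gemm_template.py calls the indexer function in torch/_inductor/ir.py. That function asserts `stride is not None and len(index) == len(stride)`, so either the buffer has no stride or the index passed in has a different number of dimensions than the buffer's stride.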

A similar CUDA variant of this reproducer fails differently with KeyError: 'complex64' in Triton codegen. I filed that separately because the failing backend path is different.

Reproduction Script

import torch

def target_function(
    l_tok_embeddings_weight_, l_tokens_, l_attn_norm_0_weight_, 
    l_wq_0_weight_, l_wk_0_weight_, l_wv_0_weight_, l_wo_0_weight_
):
    """Run one attention layer (rotary-style q/k, causal softmax) and return every intermediate.

    This is the reproducer body for the reported Inductor crash. When compiled
    with ``torch.compile(mode="max-autotune")`` on CPU, it is reported to fail
    inside the C++ BMM template's ``store_output`` epilogue with
    ``AssertionError: len(index) == len(stride)`` raised from ``ir.py``.

    The body is a flat op-by-op sequence (the ``l_..._`` parameter names look
    like Dynamo-captured graph inputs). If you edit it, keep the exact ops and
    their order: changing them may stop the bug from reproducing.

    Args:
        l_tok_embeddings_weight_: Embedding table passed to
            ``torch.nn.functional.embedding``; ``get_inputs`` passes a
            float32 (512, 64) tensor.
        l_tokens_: Integer token ids looked up in the embedding table. The
            hard-coded ``view(2, 4, 8, 8)`` / ``view(2, 4, 64)`` calls assume
            a (2, 4) batch of ids and 64-feature projections (8 heads x 8
            dims); ``get_inputs`` passes a (2, 4) int64 tensor.
        l_attn_norm_0_weight_: Per-feature scale applied after normalization.
        l_wq_0_weight_, l_wk_0_weight_, l_wv_0_weight_: Weights of the
            query/key/value ``linear`` projections.
        l_wo_0_weight_: Weight of the output ``linear`` projection.

    Returns:
        A tuple of every intermediate tensor, ending with the output
        projection ``linear_3`` of shape (2, 4, out_features).
        NOTE(review): returning all intermediates makes each one a compiled
        output; confirm whether this is required to trigger the crash.
    """
    # Rotary-style frequency table: arange -> [0, 2, 4, 6], so
    # freqs = 1 / 10000 ** ([0, 2, 4, 6] / 8). The outer product with
    # positions t = [0, 1, 2, 3] gives a (4, 4) angle grid, and
    # torch.polar(1, angle) turns it into unit complex numbers.
    arange = torch.arange(0, 8, 2, device='cpu', dtype=torch.float32)
    angles = torch.true_divide(arange, 8)
    pow_1 = torch.pow(10000.0, angles)
    freqs = torch.true_divide(1.0, pow_1)
    t = torch.arange(0, 4, device='cpu', dtype=torch.float32)
    outer = torch.outer(t, freqs)
    freqs_1 = outer.float()
    ones_like = torch.ones_like(freqs_1)
    freqs_cis = torch.polar(ones_like, freqs_1)
    
    # Token embedding lookup followed by RMSNorm-style normalization:
    # x * rsqrt(mean(x**2, dim=-1) + 1e-5), cast back to h's dtype, then
    # scaled elementwise by the norm weight.
    h = torch.nn.functional.embedding(l_tokens_, l_tok_embeddings_weight_)
    x_normed = h.float()
    pow_2 = x_normed.pow(2)
    mean = pow_2.mean(-1, keepdim=True)
    add = torch.add(mean, 1e-05)
    rsqrt = torch.rsqrt(add)
    x_normed_1 = torch.mul(x_normed, rsqrt)
    type_as = x_normed_1.type_as(h)
    normed_x = torch.mul(type_as, l_attn_norm_0_weight_)
    
    # Query / key / value projections (no bias).
    q = torch.nn.functional.linear(normed_x, l_wq_0_weight_)
    k = torch.nn.functional.linear(normed_x, l_wk_0_weight_)
    v = torch.nn.functional.linear(normed_x, l_wv_0_weight_)
    
    # Split features into 8 heads of size 8 and move heads before the
    # sequence axis: (2, 4, 8, 8) -> (2, 8, 4, 8). The transposes produce
    # non-contiguous views.
    view = q.view(2, 4, 8, 8)
    q_1 = view.transpose(1, 2)
    view_1 = k.view(2, 4, 8, 8)
    k_1 = view_1.transpose(1, 2)
    view_2 = v.view(2, 4, 8, 8)
    v_1 = view_2.transpose(1, 2)
    
    # Pair adjacent features as (real, imag) and reinterpret as complex:
    # (2, 8, 4, 4, 2) float -> (2, 8, 4, 4) complex.
    float_3 = q_1.float()
    reshape = float_3.reshape(2, 8, 4, -1, 2)
    xq_ = torch.view_as_complex(reshape)
    
    float_4 = k_1.float()
    reshape_1 = float_4.reshape(2, 8, 4, -1, 2)
    xk_ = torch.view_as_complex(reshape_1)
    
    # Rotate q and k by the frequency table, broadcast over batch and heads
    # via the (1, 1, 4, 4) reshape. Then go back to real pairs and flatten
    # them to (2, 8, 4, 8).
    freqs_cis_1 = freqs_cis.reshape(1, 1, 4, -1)
    mul_2 = torch.mul(xq_, freqs_cis_1)
    view_as_real = torch.view_as_real(mul_2)
    xq_out = view_as_real.flatten(3)
    
    mul_3 = torch.mul(xk_, freqs_cis_1)
    view_as_real_1 = torch.view_as_real(mul_3)
    xk_out = view_as_real_1.flatten(3)
    
    q_2 = xq_out.type_as(q_1)
    k_2 = xk_out.type_as(k_1)
    
    # Causal mask: triu(diagonal=1) keeps -inf strictly above the main
    # diagonal and zeroes everything else.
    mask = torch.full([1, 1, 4, 4], float('-inf'), device='cpu', dtype=torch.float32)
    mask_1 = torch.triu(mask, diagonal=1)
    # Attention scores q @ k^T -> (2, 8, 4, 4). NOTE(review): this matmul or
    # the one with v_1 below is presumably the op lowered to the CPU BMM
    # template under max-autotune; the traceback does not say which one.
    transpose_3 = k_2.transpose(-2, -1)
    matmul = torch.matmul(q_2, transpose_3)
    
    # 0.35355339059327373 == 1 / sqrt(8), i.e. scaling by the head dim of 8.
    scores = torch.mul(matmul, 0.35355339059327373)
    scores_1 = torch.add(scores, mask_1)
    attn_weights = torch.nn.functional.softmax(scores_1, dim=-1)
    attn_output = torch.matmul(attn_weights, v_1)
    
    # Merge heads back: (2, 8, 4, 8) -> (2, 4, 8, 8) -> (2, 4, 64), then
    # apply the output projection.
    transpose_4 = attn_output.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    attn_output_1 = contiguous.view(2, 4, 64)
    linear_3 = torch.nn.functional.linear(attn_output_1, l_wo_0_weight_)
    
    return (arange, t, mask, h, angles, mask_1, x_normed, pow_1, pow_2, freqs, mean, outer, add, freqs_1, rsqrt, ones_like, x_normed_1, freqs_cis, type_as, freqs_cis_1, normed_x, q, k, v, view, view_1, view_2, q_1, k_1, v_1, float_3, float_4, reshape, reshape_1, xq_, xk_, mul_2, mul_3, view_as_real, view_as_real_1, xq_out, xk_out, q_2, k_2, transpose_3, matmul, scores, scores_1, attn_weights, attn_output, transpose_4, contiguous, attn_output_1, linear_3)


def get_inputs():
    """Return the seeded example inputs for ``target_function``.

    The global RNG is reseeded with 2077 on every call, so repeated calls
    return identical tensors. The tuple order matches ``target_function``'s
    parameters: token-embedding table, token ids, attention-norm weight,
    then the q, k, v and output projection weights.
    """
    torch.manual_seed(2077)
    scale = 0.1

    embedding_table = torch.randn([512, 64], dtype=torch.float32) * scale
    # All-zero ids: every token looks up row 0. torch.zeros does not consume
    # RNG state, so the randn draws below happen in the same order as before.
    token_ids = torch.zeros([2, 4], dtype=torch.int64)
    norm_weight = torch.randn([64], dtype=torch.float32) * scale
    # The generator is consumed left to right, so wq, wk, wv, wo are drawn
    # in that order.
    wq, wk, wv, wo = (
        torch.randn([64, 64], dtype=torch.float32) * scale for _ in range(4)
    )

    return (embedding_table, token_ids, norm_weight, wq, wk, wv, wo)

if __name__ == "__main__":
    import traceback

    example_args = get_inputs()

    print("Running torch.compile(mode='max-autotune')...")
    compiled_fn = torch.compile(target_function, mode="max-autotune")

    # Deliberately broad catch: this harness only reports whether compiling
    # and running the function crashes, printing the full traceback if so.
    try:
        compiled_fn(*example_args)
    except Exception as err:
        traceback.print_exc()
        print(f"\n[BINGO!] Crashed with: {type(err).__name__}: {err}")
    else:
        print("Compilation Passed (Bug not triggered).")

Traceback

(torch-nightly) xyt19@Oasis:/tmp$ TORCHDYNAMO_VERBOSE=1 python bug.py
Running torch.compile(mode='max-autotune')...
/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/lowering.py:2504: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
Traceback (most recent call last):
  File "/tmp/bug.py", line 95, in <module>
    opt_fn(*inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1158, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1141, in compile_wrapper
    result = fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 2619, in __call__
    result = self._torchdynamo_orig_backend(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 2310, in __call__
    result = self._inner_convert(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 777, in __call__
    result = _compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 2094, in _compile
    guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1679, in compile_inner
    result = _compile_inner(code, one_graph, hooks)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1739, in _compile_inner
    dynamo_output = compile_frame(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1584, in compile_frame
    bytecode, tracer_output = transform_code_object(code, transform)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1766, in transform_code_object
    tracer_output = transformations(instructions, code_options)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1555, in transform
    tracer_output = trace_frame(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 368, in _fn
    return fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 954, in trace_frame
    run_tracer()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 935, in run_tracer
    tracer.run()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1883, in run
    while self.step():
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1536, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 5449, in RETURN_VALUE
    self._return(inst)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 5422, in _return
    all_stack_locals_metadata = self.output.compile_subgraph(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 2171, in compile_subgraph
    instructions, subgraph_pycode = self.compile_and_call_fx_graph(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 2817, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm, self.example_inputs())
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 2987, in call_user_compiler
    return self._call_user_compiler(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 3049, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 159, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/__init__.py", line 2482, in __call__
    return compile_fx(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2705, in compile_fx
    return compile_fx(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2764, in compile_fx
    return _maybe_wrap_and_compile_fx_main(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2845, in _maybe_wrap_and_compile_fx_main
    return _compile_fx_main(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 3058, in _compile_fx_main
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 3043, in _compile_fx_main
    return dynamo_common.aot_autograd(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 123, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1238, in aot_module_simplified
    compiled_fn, _ = aot_stage2_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 357, in aot_stage2_compile
    return aot_stage2_inference(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 481, in aot_stage2_inference
    compiled_fw = _aot_stage2b_inference_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 408, in _aot_stage2b_inference_compile
    return _aot_stage2b_compile_forward_or_inference(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 2779, in _aot_stage2b_compile_forward_or_inference
    compiled_fw_func = compiler(fw_module, adjusted_flat_args)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/schemas.py", line 1460, in __call__
    output_code = self.compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2906, in fw_compiler_base
    return compile_fx_forward(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 2534, in compile_fx_forward
    result = inner_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 836, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 317, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1078, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1058, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1845, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1606, in codegen_and_compile
    compiled_module = graph.compile_to_module()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2661, in compile_to_module
    return self._compile_to_module()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2667, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2603, in codegen
    self.scheduler.codegen()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 8941, in codegen
    self._codegen_partitions()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 9081, in _codegen_partitions
    self._codegen(partition)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/scheduler.py", line 9206, in _codegen
    self.get_backend(device).codegen_template(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp.py", line 5488, in codegen_template
    src_code = render()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 52, in render
    template.render(kernel=self, **kwargs), self.render_hooks
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_bmm_template.py", line 237, in render
    result = self._template_from_string(BMM_TEMPLATE).render(**options)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 1295, in render
    self.environment.handle_exception()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 942, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<template>", line 3, in top-level template code
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_bmm_template.py", line 255, in codegen_single_thread_gemm
    return stub + self._template_from_string(GEMM_TEMPLATE).render(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 1295, in render
    self.environment.handle_exception()
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/jinja2/environment.py", line 942, in handle_exception
    raise rewrite_traceback_stack(source=source)
  File "<template>", line 97, in top-level template code
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 415, in store_output
    return self.store_pointwise_nodes(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 287, in store_pointwise_nodes
    body = LoopBody(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/loop_body.py", line 136, in __init__
    self._init_with_tracing(fn, args)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/loop_body.py", line 166, in _init_with_tracing
    self.root_block = LoopBodyBlock(self, fn, args)  # traces
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/loop_body.py", line 599, in __init__
    ops.output(fn(*args))
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_template_kernel.py", line 284, in fn
    node.make_loader()(new_args).value,
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_utils.py", line 417, in inner
    return fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/codegen/cpp_gemm_template.py", line 1484, in copy_inner
    input = input_loader(index)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/ir.py", line 4744, in loader
    return ops.load(self.name or "unnamed", indexer(index))
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_inductor/ir.py", line 2135, in indexer
    assert stride is not None and len(index) == len(stride)
torch._inductor.exc.InductorError: AssertionError:


[BINGO!] Crashed with: InductorError: AssertionError:

Versions

PyTorch version: 2.13.0.dev20260521+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.10.20 (main, Mar 11 2026, 17:46:40) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.0.140
Nvidia driver version: 596.49
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_tensor_ir.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.21.1
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] nvidia-cublas==13.1.1.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.20.0.48
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.1
[pip3] nvidia-nccl-cu13==2.29.7
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] torch==2.13.0.dev20260521+cu130
[pip3] torchaudio==2.11.0.dev20260525+cu130
[pip3] torchvision==0.28.0.dev20260525+cu130
[pip3] triton==3.7.0+git88b227e2
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas 13.1.1.3 pypi_0 pypi
[conda] nvidia-cuda-cupti 13.0.85 pypi_0 pypi
[conda] nvidia-cuda-nvrtc 13.0.88 pypi_0 pypi
[conda] nvidia-cuda-runtime 13.0.96 pypi_0 pypi
[conda] nvidia-cudnn-cu13 9.20.0.48 pypi_0 pypi
[conda] nvidia-cufft 12.0.0.61 pypi_0 pypi
[conda] nvidia-curand 10.4.0.35 pypi_0 pypi
[conda] nvidia-cusolver 12.0.4.66 pypi_0 pypi
[conda] nvidia-cusparse 12.6.3.3 pypi_0 pypi
[conda] nvidia-cusparselt-cu13 0.8.1 pypi_0 pypi
[conda] nvidia-nccl-cu13 2.29.7 pypi_0 pypi
[conda] nvidia-nvjitlink 13.0.88 pypi_0 pypi
[conda] nvidia-nvtx 13.0.85 pypi_0 pypi
[conda] torch 2.13.0.dev20260521+cu130 pypi_0 pypi
[conda] torchaudio 2.11.0.dev20260525+cu130 pypi_0 pypi
[conda] torchvision 0.28.0.dev20260525+cu130 pypi_0 pypi
[conda] triton 3.7.0+git88b227e2 pypi_0 pypi

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @aditew01 @ezyang @anjali411 @dylanbespalko @mruberry @nikitaved @amjames @chauhang @penguinwu @voznesenskym @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @aakhundov @coconutruben @jataylo

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING

pytorch - 💡(How to fix) Fix [Inductor] AssertionError: len(index) == len(stride) in ir.py when using torch.compile(mode="max-autotune") on CPU BMM/GEMM