pytorch - [PT2] RuntimeError with garbage shape in apply_view_meta_sequence when using torch.compile(dynamic=True)

Error Message

Running python bug.py with PyTorch 2.13.0.dev20260521+cu130: eager execution passes, but compiled execution with dynamic=True fails while AOTAutograd replays output aliases (_replay_alias → gen_alias_from_base → _functionalization.apply_view_meta_sequence):

RuntimeError: shape '[93989549545813, 6741088537664025784, 24]' is invalid for input of size 288

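The message comes from the element-count check that a view performs: the product of the requested dimensions must equal the number of elements in the tensor. Here 288 = 2 × 6 × 24, so the replayed view was presumably meant to have shape [2, 6, 24] (one candidate is `context = contiguous.view(2, 6, 24)` in `model_forward`, although the traceback does not confirm which output it was), and only the two leading sizes are garbage. `view_shape_is_valid` below is a hypothetical, simplified pure-Python sketch of that check for illustration; it is not PyTorch's implementation.

```python
from math import prod


def view_shape_is_valid(shape, numel):
    """Simplified sketch of the view size check (hypothetical helper).

    The product of the requested sizes must equal numel; at most one -1
    entry is allowed, and it is inferred from the remaining sizes.
    """
    shape = list(shape)
    if shape.count(-1) > 1:
        return False
    known = prod(d for d in shape if d != -1)
    if -1 in shape:
        # The inferred dimension must divide the remaining elements evenly.
        return known != 0 and numel % known == 0
    return known == numel


print(2 * 6 * 24)                                                      # 288
print(view_shape_is_valid([2, 6, 24], 288))                            # True
print(view_shape_is_valid([93989549545813, 6741088537664025784, 24], 288))  # False
```

Python integers do not overflow, so the garbage product is simply a huge number that can never match 288.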
Code Example

(torch-nightly) xyt19@Oasis:/tmp$ python bug.py
PyTorch Version: 2.13.0.dev20260521+cu130

[1/2] Running Eager execution...
      [PASS] Eager mode completed successfully.

[2/2] Running Compiled execution (dynamic=True)...

================================================================================
[BUG REPRODUCED] Garbage Shape RuntimeError triggered successfully.
================================================================================
Traceback (most recent call last):
  File "/tmp/bug.py", line 385, in main
    _ = run_function(compiled_fn, raw_inputs)
  File "/tmp/bug.py", line 359, in run_function
    output = fn(*inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1141, in compile_wrapper
    result = fn(*args, **kwargs)
  File "/tmp/bug.py", line 209, in model_forward
    def model_forward(l_q_weight_, l_k_weight_, l_v_weight_, l_rel_bias_, l_mask_, l_out_weight_, l_ln_weight_, l_ln_bias_, g5_l_x_, g5_l_aux_tokens_, g5_l_q_weight_, g5_l_kv_weight_, g5_l_rel_bias_, g5_l_mask_, g5_l_merge_weight_, g5_l_out_weight_, g5_l_ln_weight_, g5_l_ln_bias_, g3_l_q_weight_, g3_l_k_weight_, g3_l_v_weight_, g3_l_rel_bias_, g3_l_mask_, g3_l_out_weight_, g3_l_ln_weight_, g3_l_ln_bias_, g6_l_q_weight_, g6_l_k_weight_, g6_l_v_weight_, g6_l_rel_bias_, g6_l_mask_, g6_l_out_weight_, g6_l_ln_weight_, g6_l_ln_bias_, g2_l_kernel_bank_, g2_l_selector_, g2_l_residual_weight_, g2_l_channel_scale_, g1_l_q_weight_, g1_l_k_weight_, g1_l_v_weight_, g1_l_rel_bias_, g1_l_mask_, g1_l_out_weight_, g1_l_ln_weight_, g1_l_ln_bias_, g4_l_aux_tokens_, g4_l_q_weight_, g4_l_kv_weight_, g4_l_rel_bias_, g4_l_mask_, g4_l_merge_weight_, g4_l_out_weight_, g4_l_ln_weight_, g4_l_ln_bias_):
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1415, in _fn
    return fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1277, in forward
    return compiled_fn(full_args)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2551, in __call__
    return self.compiled_fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1126, in runtime_wrapper
    result = _codegen_runtime_wrapper(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/subclass_codegen.py:codegen(runtime_wrapper_orchestration)", line 12, in _runtime_wrapper
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 968, in _replay_alias
    return _codegen_alias_fn(orig_inputs, fw_outs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/subclass_codegen.py:codegen(output_alias_wrapper)", line 16, in _alias_fn
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/functional_utils.py", line 333, in gen_alias_from_base
    out = _functionalization.apply_view_meta_sequence(
RuntimeError: shape '[93989549545813, 6741088537664025784, 24]' is invalid for input of size 288
================================================================================

---

import torch
import copy
import traceback
import warnings

warnings.filterwarnings("ignore")

def process_patches(l_image_: torch.Tensor, l_kernel_bank_: torch.Tensor, l_selector_: torch.Tensor, l_residual_weight_: torch.Tensor, l_channel_scale_: torch.Tensor) -> torch.Tensor:
    patches = torch.nn.functional.unfold(l_image_, kernel_size=3, padding=1, stride=1)
    transpose = patches.transpose(1, 2)
    patches_1 = transpose.contiguous()
    selected = torch.index_select(l_kernel_bank_, dim=0, index=l_selector_)
    selected_1 = selected.view(2, 36, 4)
    bmm = torch.bmm(patches_1, selected_1)
    transpose_1 = bmm.transpose(1, 2)
    filtered = transpose_1.contiguous()
    filtered_1 = filtered.view(2, 4, 6, 6)
    sigmoid = torch.nn.functional.sigmoid(filtered_1)
    tanh = torch.nn.functional.tanh(filtered_1)
    gated = torch.mul(sigmoid, tanh)
    pooled = torch.nn.functional.avg_pool2d(gated, kernel_size=3, stride=1, padding=1)
    view_2 = l_residual_weight_.view(1, 4, 1, 1)
    residual = torch.mul(l_image_, view_2)
    add = torch.add(pooled, residual)
    view_3 = l_channel_scale_.view(1, 4, 1, 1)
    mul_2 = torch.mul(add, view_3)
    return mul_2

def attention_layer_1(l_x_: torch.Tensor, l_q_weight_: torch.Tensor, l_k_weight_: torch.Tensor, l_v_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(l_x_, l_k_weight_)
    view_1 = linear_1.view(2, 6, 4, 6)
    k = view_1.transpose(1, 2)
    linear_2 = torch.nn.functional.linear(l_x_, l_v_weight_)
    view_2 = linear_2.view(2, 6, 4, 6)
    v = view_2.transpose(1, 2)
    transpose_3 = k.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_3)
    scores = torch.true_divide(matmul, 2.449489742783178)
    arange = torch.arange(6)
    view_3 = arange.view(1, -1)
    arange_1 = torch.arange(6)
    view_4 = arange_1.view(-1, 1)
    band = torch.sub(view_3, view_4)
    abs_1 = band.abs()
    band_mask = torch.gt(abs_1, 4)
    view_5 = l_rel_bias_.view(1, 4, 6, 6)
    scores_1 = torch.add(scores, view_5)
    view_6 = l_mask_.view(2, 1, 1, 6)
    scores_2 = scores_1.masked_fill(view_6, -10000.0)
    view_7 = band_mask.view(1, 1, 6, 6)
    scores_3 = scores_2.masked_fill(view_7, -10000.0)
    probs = torch.nn.functional.softmax(scores_3, dim=-1)
    matmul_1 = torch.matmul(probs, v)
    transpose_4 = matmul_1.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    context = contiguous.view(2, 6, 24)
    projected = torch.nn.functional.linear(context, l_out_weight_)
    sigmoid = torch.nn.functional.sigmoid(l_x_)
    gated = torch.mul(projected, sigmoid)
    add_1 = torch.add(l_x_, gated)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def attention_layer_2(l_x_: torch.Tensor, l_q_weight_: torch.Tensor, l_k_weight_: torch.Tensor, l_v_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(l_x_, l_k_weight_)
    view_1 = linear_1.view(2, 6, 4, 6)
    k = view_1.transpose(1, 2)
    linear_2 = torch.nn.functional.linear(l_x_, l_v_weight_)
    view_2 = linear_2.view(2, 6, 4, 6)
    v = view_2.transpose(1, 2)
    transpose_3 = k.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_3)
    scores = torch.true_divide(matmul, 2.449489742783178)
    arange = torch.arange(6)
    view_3 = arange.view(1, -1)
    arange_1 = torch.arange(6)
    view_4 = arange_1.view(-1, 1)
    band = torch.sub(view_3, view_4)
    abs_1 = band.abs()
    band_mask = torch.gt(abs_1, 4)
    view_5 = l_rel_bias_.view(1, 4, 6, 6)
    scores_1 = torch.add(scores, view_5)
    view_6 = l_mask_.view(2, 1, 1, 6)
    scores_2 = scores_1.masked_fill(view_6, -10000.0)
    view_7 = band_mask.view(1, 1, 6, 6)
    scores_3 = scores_2.masked_fill(view_7, -10000.0)
    probs = torch.nn.functional.softmax(scores_3, dim=-1)
    matmul_1 = torch.matmul(probs, v)
    transpose_4 = matmul_1.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    context = contiguous.view(2, 6, 24)
    projected = torch.nn.functional.linear(context, l_out_weight_)
    sigmoid = torch.nn.functional.sigmoid(l_x_)
    gated = torch.mul(projected, sigmoid)
    add_1 = torch.add(l_x_, gated)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def attention_layer_3(l_x_: torch.Tensor, l_q_weight_: torch.Tensor, l_k_weight_: torch.Tensor, l_v_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(l_x_, l_k_weight_)
    view_1 = linear_1.view(2, 6, 4, 6)
    k = view_1.transpose(1, 2)
    linear_2 = torch.nn.functional.linear(l_x_, l_v_weight_)
    view_2 = linear_2.view(2, 6, 4, 6)
    v = view_2.transpose(1, 2)
    transpose_3 = k.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_3)
    scores = torch.true_divide(matmul, 2.449489742783178)
    arange = torch.arange(6)
    view_3 = arange.view(1, -1)
    arange_1 = torch.arange(6)
    view_4 = arange_1.view(-1, 1)
    band = torch.sub(view_3, view_4)
    abs_1 = band.abs()
    band_mask = torch.gt(abs_1, 4)
    view_5 = l_rel_bias_.view(1, 4, 6, 6)
    scores_1 = torch.add(scores, view_5)
    view_6 = l_mask_.view(2, 1, 1, 6)
    scores_2 = scores_1.masked_fill(view_6, -10000.0)
    view_7 = band_mask.view(1, 1, 6, 6)
    scores_3 = scores_2.masked_fill(view_7, -10000.0)
    probs = torch.nn.functional.softmax(scores_3, dim=-1)
    matmul_1 = torch.matmul(probs, v)
    transpose_4 = matmul_1.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    context = contiguous.view(2, 6, 24)
    projected = torch.nn.functional.linear(context, l_out_weight_)
    sigmoid = torch.nn.functional.sigmoid(l_x_)
    gated = torch.mul(projected, sigmoid)
    add_1 = torch.add(l_x_, gated)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def fusion_layer_1(l_x_: torch.Tensor, l_aux_tokens_: torch.Tensor, l_q_weight_: torch.Tensor, l_kv_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_merge_weight_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    merged = torch.cat([l_aux_tokens_, l_x_], dim=1)
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(merged, l_kv_weight_)
    kv = linear_1.view(2, 11, 2, 4, 6)
    permute = kv.permute(2, 0, 3, 1, 4)
    kv_1 = permute.contiguous()
    getitem = kv_1[0]
    transpose_1 = getitem.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_1)
    scores = torch.true_divide(matmul, 2.449489742783178)
    view_2 = l_rel_bias_.view(1, 4, 6, 11)
    scores_1 = torch.add(scores, view_2)
    view_3 = l_mask_.view(2, 1, 1, 11)
    scores_2 = scores_1.masked_fill(view_3, -10000.0)
    probs = torch.nn.functional.softmax(scores_2, dim=-1)
    getitem_1 = kv_1[1]
    matmul_1 = torch.matmul(probs, getitem_1)
    transpose_2 = matmul_1.transpose(1, 2)
    contiguous_1 = transpose_2.contiguous()
    context = contiguous_1.view(2, 6, 24)
    mean = merged.mean(dim=1, keepdim=True)
    pooled = mean.expand(-1, 6, -1)
    cat_1 = torch.cat([context, pooled], dim=-1)
    fused = torch.nn.functional.linear(cat_1, l_merge_weight_)
    gelu = torch.nn.functional.gelu(fused)
    projected = torch.nn.functional.linear(gelu, l_out_weight_)
    add_1 = torch.add(l_x_, projected)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def fusion_layer_2(l_x_: torch.Tensor, l_aux_tokens_: torch.Tensor, l_q_weight_: torch.Tensor, l_kv_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_merge_weight_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    merged = torch.cat([l_aux_tokens_, l_x_], dim=1)
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(merged, l_kv_weight_)
    kv = linear_1.view(2, 11, 2, 4, 6)
    permute = kv.permute(2, 0, 3, 1, 4)
    kv_1 = permute.contiguous()
    getitem = kv_1[0]
    transpose_1 = getitem.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_1)
    scores = torch.true_divide(matmul, 2.449489742783178)
    view_2 = l_rel_bias_.view(1, 4, 6, 11)
    scores_1 = torch.add(scores, view_2)
    view_3 = l_mask_.view(2, 1, 1, 11)
    scores_2 = scores_1.masked_fill(view_3, -10000.0)
    probs = torch.nn.functional.softmax(scores_2, dim=-1)
    getitem_1 = kv_1[1]
    matmul_1 = torch.matmul(probs, getitem_1)
    transpose_2 = matmul_1.transpose(1, 2)
    contiguous_1 = transpose_2.contiguous()
    context = contiguous_1.view(2, 6, 24)
    mean = merged.mean(dim=1, keepdim=True)
    pooled = mean.expand(-1, 6, -1)
    cat_1 = torch.cat([context, pooled], dim=-1)
    fused = torch.nn.functional.linear(cat_1, l_merge_weight_)
    gelu = torch.nn.functional.gelu(fused)
    projected = torch.nn.functional.linear(gelu, l_out_weight_)
    add_1 = torch.add(l_x_, projected)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def model_forward(l_q_weight_, l_k_weight_, l_v_weight_, l_rel_bias_, l_mask_, l_out_weight_, l_ln_weight_, l_ln_bias_, g5_l_x_, g5_l_aux_tokens_, g5_l_q_weight_, g5_l_kv_weight_, g5_l_rel_bias_, g5_l_mask_, g5_l_merge_weight_, g5_l_out_weight_, g5_l_ln_weight_, g5_l_ln_bias_, g3_l_q_weight_, g3_l_k_weight_, g3_l_v_weight_, g3_l_rel_bias_, g3_l_mask_, g3_l_out_weight_, g3_l_ln_weight_, g3_l_ln_bias_, g6_l_q_weight_, g6_l_k_weight_, g6_l_v_weight_, g6_l_rel_bias_, g6_l_mask_, g6_l_out_weight_, g6_l_ln_weight_, g6_l_ln_bias_, g2_l_kernel_bank_, g2_l_selector_, g2_l_residual_weight_, g2_l_channel_scale_, g1_l_q_weight_, g1_l_k_weight_, g1_l_v_weight_, g1_l_rel_bias_, g1_l_mask_, g1_l_out_weight_, g1_l_ln_weight_, g1_l_ln_bias_, g4_l_aux_tokens_, g4_l_q_weight_, g4_l_kv_weight_, g4_l_rel_bias_, g4_l_mask_, g4_l_merge_weight_, g4_l_out_weight_, g4_l_ln_weight_, g4_l_ln_bias_):
    g5_call = fusion_layer_2(g5_l_x_, g5_l_aux_tokens_, g5_l_q_weight_, g5_l_kv_weight_, g5_l_rel_bias_, g5_l_mask_, g5_l_merge_weight_, g5_l_out_weight_, g5_l_ln_weight_, g5_l_ln_bias_)
    linear = torch.nn.functional.linear(g5_call, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(g5_call, l_k_weight_)
    view_1 = linear_1.view(2, 6, 4, 6)
    k = view_1.transpose(1, 2)
    linear_2 = torch.nn.functional.linear(g5_call, l_v_weight_)
    view_2 = linear_2.view(2, 6, 4, 6)
    v = view_2.transpose(1, 2)
    transpose_3 = k.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_3)
    scores = torch.true_divide(matmul, 2.449489742783178)
    arange = torch.arange(6)
    view_3 = arange.view(1, -1)
    arange_1 = torch.arange(6)
    view_4 = arange_1.view(-1, 1)
    band = torch.sub(view_3, view_4)
    abs_1 = band.abs()
    band_mask = torch.gt(abs_1, 4)
    view_5 = l_rel_bias_.view(1, 4, 6, 6)
    scores_1 = torch.add(scores, view_5)
    view_6 = l_mask_.view(2, 1, 1, 6)
    scores_2 = scores_1.masked_fill(view_6, -10000.0)
    view_7 = band_mask.view(1, 1, 6, 6)
    scores_3 = scores_2.masked_fill(view_7, -10000.0)
    probs = torch.nn.functional.softmax(scores_3, dim=-1)
    matmul_1 = torch.matmul(probs, v)
    transpose_4 = matmul_1.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    context = contiguous.view(2, 6, 24)
    projected = torch.nn.functional.linear(context, l_out_weight_)
    sigmoid = torch.nn.functional.sigmoid(g5_call)
    gated = torch.mul(projected, sigmoid)
    add_1 = torch.add(g5_call, gated)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    g3_call = attention_layer_2(linear_1, g3_l_q_weight_, g3_l_k_weight_, g3_l_v_weight_, g3_l_rel_bias_, g3_l_mask_, g3_l_out_weight_, g3_l_ln_weight_, g3_l_ln_bias_)
    g6_call = attention_layer_3(linear_1, g6_l_q_weight_, g6_l_k_weight_, g6_l_v_weight_, g6_l_rel_bias_, g6_l_mask_, g6_l_out_weight_, g6_l_ln_weight_, g6_l_ln_bias_)
    g2_call = process_patches(scores_2, g2_l_kernel_bank_, g2_l_selector_, g2_l_residual_weight_, g2_l_channel_scale_)
    g1_call = attention_layer_1(gated, g1_l_q_weight_, g1_l_k_weight_, g1_l_v_weight_, g1_l_rel_bias_, g1_l_mask_, g1_l_out_weight_, g1_l_ln_weight_, g1_l_ln_bias_)
    g4_call = fusion_layer_1(add_1, g4_l_aux_tokens_, g4_l_q_weight_, g4_l_kv_weight_, g4_l_rel_bias_, g4_l_mask_, g4_l_merge_weight_, g4_l_out_weight_, g4_l_ln_weight_, g4_l_ln_bias_)
    # Many returned tensors alias inputs (e.g. view_5, view_6) or other returned tensors (e.g. view of linear, context of contiguous).
    return arange, arange_1, view_5, view_6, view_3, view_4, g5_call, band, linear, linear_1, linear_2, sigmoid, abs_1, view, view_1, g3_call, g6_call, view_2, band_mask, q, k, v, view_7, transpose_3, matmul, scores, scores_1, scores_2, scores_3, g2_call, probs, matmul_1, transpose_4, contiguous, context, projected, gated, add_1, g1_call, layer_norm, g4_call


def get_inputs():
    l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    l_k_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    l_v_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    l_rel_bias_ = torch.zeros([4, 6, 6], dtype=torch.float32)
    l_mask_ = (torch.rand([2, 6]) > 0.5)
    l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g5_l_x_ = (torch.randn([2, 6, 24], dtype=torch.float32) * 0.1)
    g5_l_aux_tokens_ = (torch.randn([2, 5, 24], dtype=torch.float32) * 0.1)
    g5_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g5_l_kv_weight_ = (torch.randn([48, 24], dtype=torch.float32) * 0.1)
    g5_l_rel_bias_ = torch.zeros([4, 6, 11], dtype=torch.float32)
    g5_l_mask_ = (torch.rand([2, 11]) > 0.5)
    g5_l_merge_weight_ = (torch.randn([24, 48], dtype=torch.float32) * 0.1)
    g5_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g5_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g5_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g3_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g3_l_k_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g3_l_v_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g3_l_rel_bias_ = torch.zeros([4, 6, 6], dtype=torch.float32)
    g3_l_mask_ = (torch.rand([2, 6]) > 0.5)
    g3_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g3_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g3_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g6_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g6_l_k_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g6_l_v_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g6_l_rel_bias_ = torch.zeros([4, 6, 6], dtype=torch.float32)
    g6_l_mask_ = (torch.rand([2, 6]) > 0.5)
    g6_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g6_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g6_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g2_l_kernel_bank_ = (torch.randn([6, 36, 4], dtype=torch.float32) * 0.1)
    g2_l_selector_ = torch.zeros([2], dtype=torch.int64)
    g2_l_residual_weight_ = (torch.randn([4], dtype=torch.float32) * 0.1)
    g2_l_channel_scale_ = (torch.rand([4], dtype=torch.float32) * 0.1 + 1.0)
    g1_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g1_l_k_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g1_l_v_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g1_l_rel_bias_ = torch.zeros([4, 6, 6], dtype=torch.float32)
    g1_l_mask_ = (torch.rand([2, 6]) > 0.5)
    g1_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g1_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g1_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g4_l_aux_tokens_ = (torch.randn([2, 5, 24], dtype=torch.float32) * 0.1)
    g4_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g4_l_kv_weight_ = (torch.randn([48, 24], dtype=torch.float32) * 0.1)
    g4_l_rel_bias_ = torch.zeros([4, 6, 11], dtype=torch.float32)
    g4_l_mask_ = (torch.rand([2, 11]) > 0.5)
    g4_l_merge_weight_ = (torch.randn([24, 48], dtype=torch.float32) * 0.1)
    g4_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g4_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g4_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    return (l_q_weight_, l_k_weight_, l_v_weight_, l_rel_bias_, l_mask_, l_out_weight_, l_ln_weight_, l_ln_bias_, g5_l_x_, g5_l_aux_tokens_, g5_l_q_weight_, g5_l_kv_weight_, g5_l_rel_bias_, g5_l_mask_, g5_l_merge_weight_, g5_l_out_weight_, g5_l_ln_weight_, g5_l_ln_bias_, g3_l_q_weight_, g3_l_k_weight_, g3_l_v_weight_, g3_l_rel_bias_, g3_l_mask_, g3_l_out_weight_, g3_l_ln_weight_, g3_l_ln_bias_, g6_l_q_weight_, g6_l_k_weight_, g6_l_v_weight_, g6_l_rel_bias_, g6_l_mask_, g6_l_out_weight_, g6_l_ln_weight_, g6_l_ln_bias_, g2_l_kernel_bank_, g2_l_selector_, g2_l_residual_weight_, g2_l_channel_scale_, g1_l_q_weight_, g1_l_k_weight_, g1_l_v_weight_, g1_l_rel_bias_, g1_l_mask_, g1_l_out_weight_, g1_l_ln_weight_, g1_l_ln_bias_, g4_l_aux_tokens_, g4_l_q_weight_, g4_l_kv_weight_, g4_l_rel_bias_, g4_l_mask_, g4_l_merge_weight_, g4_l_out_weight_, g4_l_ln_weight_, g4_l_ln_bias_)

def flatten_tensors(obj):
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        out = []
        for item in obj:
            out.extend(flatten_tensors(item))
        return out
    if isinstance(obj, dict):
        out = []
        for key in sorted(obj, key=str):
            out.extend(flatten_tensors(obj[key]))
        return out
    return []

def map_tensors(obj, fn):
    if isinstance(obj, torch.Tensor):
        return fn(obj)
    if isinstance(obj, tuple):
        return tuple(map_tensors(item, fn) for item in obj)
    if isinstance(obj, list):
        return [map_tensors(item, fn) for item in obj]
    return obj

def loss_from_output(obj):
    terms = []
    for tensor in flatten_tensors(obj):
        if tensor.is_floating_point():
            terms.append(tensor.float().sum())
    if not terms:
        return None
    loss = terms[0]
    for term in terms[1:]:
        loss = loss + term
    return loss

def prepare_inputs(raw_inputs):
    def prepare_tensor(t):
        if t.is_floating_point() or t.is_complex():
            t = t.detach().clone()
            t.requires_grad_(True)
        return t
    return map_tensors(copy.deepcopy(raw_inputs), prepare_tensor)

def run_function(fn, raw_inputs):
    inputs = prepare_inputs(raw_inputs)
    with torch.enable_grad():
        output = fn(*inputs)
        loss = loss_from_output(output)
        if loss is not None and getattr(loss, "requires_grad", False):
            loss.backward()
        return output

def main():
    print(f"PyTorch Version: {torch.__version__}")
    
    torch.manual_seed(2077)
    
    raw_inputs = get_inputs()
    
    print("\n[1/2] Running Eager execution...")
    try:
        _ = run_function(model_forward, raw_inputs)
        print("      [PASS] Eager mode completed successfully.")
    except Exception as e:
        print("      [FAIL] Eager mode failed:")
        traceback.print_exc()
        return

    print("\n[2/2] Running Compiled execution (dynamic=True)...")
    compiled_fn = torch.compile(model_forward, dynamic=True)
    
    try:
        _ = run_function(compiled_fn, raw_inputs)
        print("      [PASS] Compiled mode completed successfully. (Bug might be patched)")
    except Exception as e:
        if "invalid for input of size" in str(e):
            print("\n" + "="*80)
            print("[BUG REPRODUCED] Garbage Shape RuntimeError triggered successfully.")
            print("="*80)
            traceback.print_exc()
            print("="*80)
        else:
            print("\n      [FAIL] Compiled mode failed with an unexpected error:")
            traceback.print_exc()

if __name__ == "__main__":
    main()

🐛 Describe the bug

When the model's forward pass runs under torch.compile(dynamic=True), it fails with a RuntimeError because a garbage shape is produced for a view (for example: shape '[93989549545813, 6741088537664025784, 24]' is invalid for input of size 288).

Eager execution completes without any issues, but compiled execution crashes at runtime in AOTAutograd's output-alias handling: gen_alias_from_base calls _functionalization.apply_view_meta_sequence, which raises the error.

Sorry, I tried several times but could not minimize the reproducer any further.
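Because the failure surfaces while AOTAutograd regenerates an output that aliases another tensor (gen_alias_from_base in the traceback), one possible first triage step is to list which returned tensors are views. The helper below is a hypothetical sketch, not part of the reproducer or of PyTorch. It relies on the private torch.Tensor._base attribute and maps each view output to the position of its base in the returned tuple, or to None when the base is an intermediate that is not returned.

```python
from typing import Dict, Optional, Sequence


def view_output_report(outputs: Sequence[object]) -> Dict[int, Optional[int]]:
    """Return {index of a view output: index of its base in `outputs`, or None}.

    Uses the private torch.Tensor._base attribute, which is None for tensors
    that are not views and otherwise points at the root base tensor (for a
    view of a view, this is the original tensor, not the intermediate view).
    """
    # Index outputs by object identity so a base can be matched to its position.
    index_by_id = {id(t): i for i, t in enumerate(outputs)}
    report: Dict[int, Optional[int]] = {}
    for i, t in enumerate(outputs):
        base = getattr(t, "_base", None)
        if base is not None:
            # None here means the base is an intermediate that is not returned.
            report[i] = index_by_id.get(id(base))
    return report
```

With the reproducer's functions in scope, calling view_output_report(model_forward(*get_inputs())) in eager mode should flag outputs such as context, which model_forward builds as contiguous.view(2, 6, 24) and returns alongside contiguous. The report only lists candidates for the alias-regeneration path; it does not establish which view actually triggers the crash.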

Error logs

(torch-nightly) xyt19@Oasis:/tmp$ python bug.py
PyTorch Version: 2.13.0.dev20260521+cu130

[1/2] Running Eager execution...
      [PASS] Eager mode completed successfully.

[2/2] Running Compiled execution (dynamic=True)...

================================================================================
[BUG REPRODUCED] Garbage Shape RuntimeError triggered successfully.
================================================================================
Traceback (most recent call last):
  File "/tmp/bug.py", line 385, in main
    _ = run_function(compiled_fn, raw_inputs)
  File "/tmp/bug.py", line 359, in run_function
    output = fn(*inputs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1141, in compile_wrapper
    result = fn(*args, **kwargs)
  File "/tmp/bug.py", line 209, in model_forward
    def model_forward(l_q_weight_, l_k_weight_, l_v_weight_, l_rel_bias_, l_mask_, l_out_weight_, l_ln_weight_, l_ln_bias_, g5_l_x_, g5_l_aux_tokens_, g5_l_q_weight_, g5_l_kv_weight_, g5_l_rel_bias_, g5_l_mask_, g5_l_merge_weight_, g5_l_out_weight_, g5_l_ln_weight_, g5_l_ln_bias_, g3_l_q_weight_, g3_l_k_weight_, g3_l_v_weight_, g3_l_rel_bias_, g3_l_mask_, g3_l_out_weight_, g3_l_ln_weight_, g3_l_ln_bias_, g6_l_q_weight_, g6_l_k_weight_, g6_l_v_weight_, g6_l_rel_bias_, g6_l_mask_, g6_l_out_weight_, g6_l_ln_weight_, g6_l_ln_bias_, g2_l_kernel_bank_, g2_l_selector_, g2_l_residual_weight_, g2_l_channel_scale_, g1_l_q_weight_, g1_l_k_weight_, g1_l_v_weight_, g1_l_rel_bias_, g1_l_mask_, g1_l_out_weight_, g1_l_ln_weight_, g1_l_ln_bias_, g4_l_aux_tokens_, g4_l_q_weight_, g4_l_kv_weight_, g4_l_rel_bias_, g4_l_mask_, g4_l_merge_weight_, g4_l_out_weight_, g4_l_ln_weight_, g4_l_ln_bias_):
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1415, in _fn
    return fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1277, in forward
    return compiled_fn(full_args)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2551, in __call__
    return self.compiled_fn(*args, **kwargs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1126, in runtime_wrapper
    result = _codegen_runtime_wrapper(
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/subclass_codegen.py:codegen(runtime_wrapper_orchestration)", line 12, in _runtime_wrapper
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 968, in _replay_alias
    return _codegen_alias_fn(orig_inputs, fw_outs)
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/subclass_codegen.py:codegen(output_alias_wrapper)", line 16, in _alias_fn
  File "/home/xyt19/miniconda3/envs/torch-nightly/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/functional_utils.py", line 333, in gen_alias_from_base
    out = _functionalization.apply_view_meta_sequence(
RuntimeError: shape '[93989549545813, 6741088537664025784, 24]' is invalid for input of size 288
================================================================================
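The numbers in the error message can be checked directly. The sketch below uses a hypothetical helper, parse_view_error, whose regular expression is written for the message format shown in this log (it is not a PyTorch API), to extract the requested shape and the input size. The requested element count vastly exceeds the 288 elements available, while 288 = 2 × 6 × 24 matches the (2, 6, 24) shape of several intermediates in model_forward, such as linear and context. That makes (2, 6, 24) one plausible intended shape, but the message alone does not say which output failed.

```python
import math
import re
from typing import List, Tuple

# Hypothetical pattern for messages like:
#   shape '[2, 3]' is invalid for input of size 5
_VIEW_SIZE_ERROR = re.compile(r"shape '\[([^\]]*)\]' is invalid for input of size (\d+)")


def parse_view_error(message: str) -> Tuple[List[int], int]:
    """Return (requested_shape, input_numel) parsed from a view-size error message."""
    match = _VIEW_SIZE_ERROR.search(message)
    if match is None:
        raise ValueError("message is not a view-size error")
    # int() tolerates the spaces that follow each comma
    shape = [int(dim) for dim in match.group(1).split(",") if dim.strip()]
    return shape, int(match.group(2))


msg = "shape '[93989549545813, 6741088537664025784, 24]' is invalid for input of size 288"
shape, numel = parse_view_error(msg)
print(math.prod(shape) > numel)        # True: far more elements requested than exist
print(math.prod([2, 6, 24]) == numel)  # True: (2, 6, 24) holds exactly 288 elements
```

If failures are classified automatically, matching the parsed message this way is somewhat stricter than the plain substring check "invalid for input of size" used in main().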

To Reproduce

Save the following code as bug.py (the file name used in the log above) and run it:

<details><summary>Click to expand code</summary>
import torch
import copy
import traceback
import warnings

warnings.filterwarnings("ignore")

def process_patches(l_image_: torch.Tensor, l_kernel_bank_: torch.Tensor, l_selector_: torch.Tensor, l_residual_weight_: torch.Tensor, l_channel_scale_: torch.Tensor) -> torch.Tensor:
    patches = torch.nn.functional.unfold(l_image_, kernel_size=3, padding=1, stride=1)
    transpose = patches.transpose(1, 2)
    patches_1 = transpose.contiguous()
    selected = torch.index_select(l_kernel_bank_, dim=0, index=l_selector_)
    selected_1 = selected.view(2, 36, 4)
    bmm = torch.bmm(patches_1, selected_1)
    transpose_1 = bmm.transpose(1, 2)
    filtered = transpose_1.contiguous()
    filtered_1 = filtered.view(2, 4, 6, 6)
    sigmoid = torch.nn.functional.sigmoid(filtered_1)
    tanh = torch.nn.functional.tanh(filtered_1)
    gated = torch.mul(sigmoid, tanh)
    pooled = torch.nn.functional.avg_pool2d(gated, kernel_size=3, stride=1, padding=1)
    view_2 = l_residual_weight_.view(1, 4, 1, 1)
    residual = torch.mul(l_image_, view_2)
    add = torch.add(pooled, residual)
    view_3 = l_channel_scale_.view(1, 4, 1, 1)
    mul_2 = torch.mul(add, view_3)
    return mul_2

def attention_layer_1(l_x_: torch.Tensor, l_q_weight_: torch.Tensor, l_k_weight_: torch.Tensor, l_v_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(l_x_, l_k_weight_)
    view_1 = linear_1.view(2, 6, 4, 6)
    k = view_1.transpose(1, 2)
    linear_2 = torch.nn.functional.linear(l_x_, l_v_weight_)
    view_2 = linear_2.view(2, 6, 4, 6)
    v = view_2.transpose(1, 2)
    transpose_3 = k.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_3)
    scores = torch.true_divide(matmul, 2.449489742783178)
    arange = torch.arange(6)
    view_3 = arange.view(1, -1)
    arange_1 = torch.arange(6)
    view_4 = arange_1.view(-1, 1)
    band = torch.sub(view_3, view_4)
    abs_1 = band.abs()
    band_mask = torch.gt(abs_1, 4)
    view_5 = l_rel_bias_.view(1, 4, 6, 6)
    scores_1 = torch.add(scores, view_5)
    view_6 = l_mask_.view(2, 1, 1, 6)
    scores_2 = scores_1.masked_fill(view_6, -10000.0)
    view_7 = band_mask.view(1, 1, 6, 6)
    scores_3 = scores_2.masked_fill(view_7, -10000.0)
    probs = torch.nn.functional.softmax(scores_3, dim=-1)
    matmul_1 = torch.matmul(probs, v)
    transpose_4 = matmul_1.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    context = contiguous.view(2, 6, 24)
    projected = torch.nn.functional.linear(context, l_out_weight_)
    sigmoid = torch.nn.functional.sigmoid(l_x_)
    gated = torch.mul(projected, sigmoid)
    add_1 = torch.add(l_x_, gated)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def attention_layer_2(l_x_: torch.Tensor, l_q_weight_: torch.Tensor, l_k_weight_: torch.Tensor, l_v_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(l_x_, l_k_weight_)
    view_1 = linear_1.view(2, 6, 4, 6)
    k = view_1.transpose(1, 2)
    linear_2 = torch.nn.functional.linear(l_x_, l_v_weight_)
    view_2 = linear_2.view(2, 6, 4, 6)
    v = view_2.transpose(1, 2)
    transpose_3 = k.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_3)
    scores = torch.true_divide(matmul, 2.449489742783178)
    arange = torch.arange(6)
    view_3 = arange.view(1, -1)
    arange_1 = torch.arange(6)
    view_4 = arange_1.view(-1, 1)
    band = torch.sub(view_3, view_4)
    abs_1 = band.abs()
    band_mask = torch.gt(abs_1, 4)
    view_5 = l_rel_bias_.view(1, 4, 6, 6)
    scores_1 = torch.add(scores, view_5)
    view_6 = l_mask_.view(2, 1, 1, 6)
    scores_2 = scores_1.masked_fill(view_6, -10000.0)
    view_7 = band_mask.view(1, 1, 6, 6)
    scores_3 = scores_2.masked_fill(view_7, -10000.0)
    probs = torch.nn.functional.softmax(scores_3, dim=-1)
    matmul_1 = torch.matmul(probs, v)
    transpose_4 = matmul_1.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    context = contiguous.view(2, 6, 24)
    projected = torch.nn.functional.linear(context, l_out_weight_)
    sigmoid = torch.nn.functional.sigmoid(l_x_)
    gated = torch.mul(projected, sigmoid)
    add_1 = torch.add(l_x_, gated)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def attention_layer_3(l_x_: torch.Tensor, l_q_weight_: torch.Tensor, l_k_weight_: torch.Tensor, l_v_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(l_x_, l_k_weight_)
    view_1 = linear_1.view(2, 6, 4, 6)
    k = view_1.transpose(1, 2)
    linear_2 = torch.nn.functional.linear(l_x_, l_v_weight_)
    view_2 = linear_2.view(2, 6, 4, 6)
    v = view_2.transpose(1, 2)
    transpose_3 = k.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_3)
    scores = torch.true_divide(matmul, 2.449489742783178)
    arange = torch.arange(6)
    view_3 = arange.view(1, -1)
    arange_1 = torch.arange(6)
    view_4 = arange_1.view(-1, 1)
    band = torch.sub(view_3, view_4)
    abs_1 = band.abs()
    band_mask = torch.gt(abs_1, 4)
    view_5 = l_rel_bias_.view(1, 4, 6, 6)
    scores_1 = torch.add(scores, view_5)
    view_6 = l_mask_.view(2, 1, 1, 6)
    scores_2 = scores_1.masked_fill(view_6, -10000.0)
    view_7 = band_mask.view(1, 1, 6, 6)
    scores_3 = scores_2.masked_fill(view_7, -10000.0)
    probs = torch.nn.functional.softmax(scores_3, dim=-1)
    matmul_1 = torch.matmul(probs, v)
    transpose_4 = matmul_1.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    context = contiguous.view(2, 6, 24)
    projected = torch.nn.functional.linear(context, l_out_weight_)
    sigmoid = torch.nn.functional.sigmoid(l_x_)
    gated = torch.mul(projected, sigmoid)
    add_1 = torch.add(l_x_, gated)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def fusion_layer_1(l_x_: torch.Tensor, l_aux_tokens_: torch.Tensor, l_q_weight_: torch.Tensor, l_kv_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_merge_weight_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    merged = torch.cat([l_aux_tokens_, l_x_], dim=1)
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(merged, l_kv_weight_)
    kv = linear_1.view(2, 11, 2, 4, 6)
    permute = kv.permute(2, 0, 3, 1, 4)
    kv_1 = permute.contiguous()
    getitem = kv_1[0]
    transpose_1 = getitem.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_1)
    scores = torch.true_divide(matmul, 2.449489742783178)
    view_2 = l_rel_bias_.view(1, 4, 6, 11)
    scores_1 = torch.add(scores, view_2)
    view_3 = l_mask_.view(2, 1, 1, 11)
    scores_2 = scores_1.masked_fill(view_3, -10000.0)
    probs = torch.nn.functional.softmax(scores_2, dim=-1)
    getitem_1 = kv_1[1]
    matmul_1 = torch.matmul(probs, getitem_1)
    transpose_2 = matmul_1.transpose(1, 2)
    contiguous_1 = transpose_2.contiguous()
    context = contiguous_1.view(2, 6, 24)
    mean = merged.mean(dim=1, keepdim=True)
    pooled = mean.expand(-1, 6, -1)
    cat_1 = torch.cat([context, pooled], dim=-1)
    fused = torch.nn.functional.linear(cat_1, l_merge_weight_)
    gelu = torch.nn.functional.gelu(fused)
    projected = torch.nn.functional.linear(gelu, l_out_weight_)
    add_1 = torch.add(l_x_, projected)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def fusion_layer_2(l_x_: torch.Tensor, l_aux_tokens_: torch.Tensor, l_q_weight_: torch.Tensor, l_kv_weight_: torch.Tensor, l_rel_bias_: torch.Tensor, l_mask_: torch.Tensor, l_merge_weight_: torch.Tensor, l_out_weight_: torch.Tensor, l_ln_weight_: torch.Tensor, l_ln_bias_: torch.Tensor) -> torch.Tensor:
    merged = torch.cat([l_aux_tokens_, l_x_], dim=1)
    linear = torch.nn.functional.linear(l_x_, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(merged, l_kv_weight_)
    kv = linear_1.view(2, 11, 2, 4, 6)
    permute = kv.permute(2, 0, 3, 1, 4)
    kv_1 = permute.contiguous()
    getitem = kv_1[0]
    transpose_1 = getitem.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_1)
    scores = torch.true_divide(matmul, 2.449489742783178)
    view_2 = l_rel_bias_.view(1, 4, 6, 11)
    scores_1 = torch.add(scores, view_2)
    view_3 = l_mask_.view(2, 1, 1, 11)
    scores_2 = scores_1.masked_fill(view_3, -10000.0)
    probs = torch.nn.functional.softmax(scores_2, dim=-1)
    getitem_1 = kv_1[1]
    matmul_1 = torch.matmul(probs, getitem_1)
    transpose_2 = matmul_1.transpose(1, 2)
    contiguous_1 = transpose_2.contiguous()
    context = contiguous_1.view(2, 6, 24)
    mean = merged.mean(dim=1, keepdim=True)
    pooled = mean.expand(-1, 6, -1)
    cat_1 = torch.cat([context, pooled], dim=-1)
    fused = torch.nn.functional.linear(cat_1, l_merge_weight_)
    gelu = torch.nn.functional.gelu(fused)
    projected = torch.nn.functional.linear(gelu, l_out_weight_)
    add_1 = torch.add(l_x_, projected)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    return layer_norm

def model_forward(l_q_weight_, l_k_weight_, l_v_weight_, l_rel_bias_, l_mask_, l_out_weight_, l_ln_weight_, l_ln_bias_, g5_l_x_, g5_l_aux_tokens_, g5_l_q_weight_, g5_l_kv_weight_, g5_l_rel_bias_, g5_l_mask_, g5_l_merge_weight_, g5_l_out_weight_, g5_l_ln_weight_, g5_l_ln_bias_, g3_l_q_weight_, g3_l_k_weight_, g3_l_v_weight_, g3_l_rel_bias_, g3_l_mask_, g3_l_out_weight_, g3_l_ln_weight_, g3_l_ln_bias_, g6_l_q_weight_, g6_l_k_weight_, g6_l_v_weight_, g6_l_rel_bias_, g6_l_mask_, g6_l_out_weight_, g6_l_ln_weight_, g6_l_ln_bias_, g2_l_kernel_bank_, g2_l_selector_, g2_l_residual_weight_, g2_l_channel_scale_, g1_l_q_weight_, g1_l_k_weight_, g1_l_v_weight_, g1_l_rel_bias_, g1_l_mask_, g1_l_out_weight_, g1_l_ln_weight_, g1_l_ln_bias_, g4_l_aux_tokens_, g4_l_q_weight_, g4_l_kv_weight_, g4_l_rel_bias_, g4_l_mask_, g4_l_merge_weight_, g4_l_out_weight_, g4_l_ln_weight_, g4_l_ln_bias_):
    g5_call = fusion_layer_2(g5_l_x_, g5_l_aux_tokens_, g5_l_q_weight_, g5_l_kv_weight_, g5_l_rel_bias_, g5_l_mask_, g5_l_merge_weight_, g5_l_out_weight_, g5_l_ln_weight_, g5_l_ln_bias_)
    linear = torch.nn.functional.linear(g5_call, l_q_weight_)
    view = linear.view(2, 6, 4, 6)
    q = view.transpose(1, 2)
    linear_1 = torch.nn.functional.linear(g5_call, l_k_weight_)
    view_1 = linear_1.view(2, 6, 4, 6)
    k = view_1.transpose(1, 2)
    linear_2 = torch.nn.functional.linear(g5_call, l_v_weight_)
    view_2 = linear_2.view(2, 6, 4, 6)
    v = view_2.transpose(1, 2)
    transpose_3 = k.transpose(-2, -1)
    matmul = torch.matmul(q, transpose_3)
    scores = torch.true_divide(matmul, 2.449489742783178)
    arange = torch.arange(6)
    view_3 = arange.view(1, -1)
    arange_1 = torch.arange(6)
    view_4 = arange_1.view(-1, 1)
    band = torch.sub(view_3, view_4)
    abs_1 = band.abs()
    band_mask = torch.gt(abs_1, 4)
    view_5 = l_rel_bias_.view(1, 4, 6, 6)
    scores_1 = torch.add(scores, view_5)
    view_6 = l_mask_.view(2, 1, 1, 6)
    scores_2 = scores_1.masked_fill(view_6, -10000.0)
    view_7 = band_mask.view(1, 1, 6, 6)
    scores_3 = scores_2.masked_fill(view_7, -10000.0)
    probs = torch.nn.functional.softmax(scores_3, dim=-1)
    matmul_1 = torch.matmul(probs, v)
    transpose_4 = matmul_1.transpose(1, 2)
    contiguous = transpose_4.contiguous()
    context = contiguous.view(2, 6, 24)
    projected = torch.nn.functional.linear(context, l_out_weight_)
    sigmoid = torch.nn.functional.sigmoid(g5_call)
    gated = torch.mul(projected, sigmoid)
    add_1 = torch.add(g5_call, gated)
    layer_norm = torch.nn.functional.layer_norm(add_1, [24], l_ln_weight_, l_ln_bias_, eps=1e-05)
    g3_call = attention_layer_2(linear_1, g3_l_q_weight_, g3_l_k_weight_, g3_l_v_weight_, g3_l_rel_bias_, g3_l_mask_, g3_l_out_weight_, g3_l_ln_weight_, g3_l_ln_bias_)
    g6_call = attention_layer_3(linear_1, g6_l_q_weight_, g6_l_k_weight_, g6_l_v_weight_, g6_l_rel_bias_, g6_l_mask_, g6_l_out_weight_, g6_l_ln_weight_, g6_l_ln_bias_)
    g2_call = process_patches(scores_2, g2_l_kernel_bank_, g2_l_selector_, g2_l_residual_weight_, g2_l_channel_scale_)
    g1_call = attention_layer_1(gated, g1_l_q_weight_, g1_l_k_weight_, g1_l_v_weight_, g1_l_rel_bias_, g1_l_mask_, g1_l_out_weight_, g1_l_ln_weight_, g1_l_ln_bias_)
    g4_call = fusion_layer_1(add_1, g4_l_aux_tokens_, g4_l_q_weight_, g4_l_kv_weight_, g4_l_rel_bias_, g4_l_mask_, g4_l_merge_weight_, g4_l_out_weight_, g4_l_ln_weight_, g4_l_ln_bias_)
    return arange, arange_1, view_5, view_6, view_3, view_4, g5_call, band, linear, linear_1, linear_2, sigmoid, abs_1, view, view_1, g3_call, g6_call, view_2, band_mask, q, k, v, view_7, transpose_3, matmul, scores, scores_1, scores_2, scores_3, g2_call, probs, matmul_1, transpose_4, contiguous, context, projected, gated, add_1, g1_call, layer_norm, g4_call


def get_inputs():
    l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    l_k_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    l_v_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    l_rel_bias_ = torch.zeros([4, 6, 6], dtype=torch.float32)
    l_mask_ = (torch.rand([2, 6]) > 0.5)
    l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g5_l_x_ = (torch.randn([2, 6, 24], dtype=torch.float32) * 0.1)
    g5_l_aux_tokens_ = (torch.randn([2, 5, 24], dtype=torch.float32) * 0.1)
    g5_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g5_l_kv_weight_ = (torch.randn([48, 24], dtype=torch.float32) * 0.1)
    g5_l_rel_bias_ = torch.zeros([4, 6, 11], dtype=torch.float32)
    g5_l_mask_ = (torch.rand([2, 11]) > 0.5)
    g5_l_merge_weight_ = (torch.randn([24, 48], dtype=torch.float32) * 0.1)
    g5_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g5_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g5_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g3_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g3_l_k_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g3_l_v_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g3_l_rel_bias_ = torch.zeros([4, 6, 6], dtype=torch.float32)
    g3_l_mask_ = (torch.rand([2, 6]) > 0.5)
    g3_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g3_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g3_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g6_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g6_l_k_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g6_l_v_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g6_l_rel_bias_ = torch.zeros([4, 6, 6], dtype=torch.float32)
    g6_l_mask_ = (torch.rand([2, 6]) > 0.5)
    g6_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g6_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g6_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g2_l_kernel_bank_ = (torch.randn([6, 36, 4], dtype=torch.float32) * 0.1)
    g2_l_selector_ = torch.zeros([2], dtype=torch.int64)
    g2_l_residual_weight_ = (torch.randn([4], dtype=torch.float32) * 0.1)
    g2_l_channel_scale_ = (torch.rand([4], dtype=torch.float32) * 0.1 + 1.0)
    g1_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g1_l_k_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g1_l_v_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g1_l_rel_bias_ = torch.zeros([4, 6, 6], dtype=torch.float32)
    g1_l_mask_ = (torch.rand([2, 6]) > 0.5)
    g1_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g1_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g1_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    g4_l_aux_tokens_ = (torch.randn([2, 5, 24], dtype=torch.float32) * 0.1)
    g4_l_q_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g4_l_kv_weight_ = (torch.randn([48, 24], dtype=torch.float32) * 0.1)
    g4_l_rel_bias_ = torch.zeros([4, 6, 11], dtype=torch.float32)
    g4_l_mask_ = (torch.rand([2, 11]) > 0.5)
    g4_l_merge_weight_ = (torch.randn([24, 48], dtype=torch.float32) * 0.1)
    g4_l_out_weight_ = (torch.randn([24, 24], dtype=torch.float32) * 0.1)
    g4_l_ln_weight_ = torch.ones([24], dtype=torch.float32)
    g4_l_ln_bias_ = torch.zeros([24], dtype=torch.float32)
    return (l_q_weight_, l_k_weight_, l_v_weight_, l_rel_bias_, l_mask_, l_out_weight_, l_ln_weight_, l_ln_bias_, g5_l_x_, g5_l_aux_tokens_, g5_l_q_weight_, g5_l_kv_weight_, g5_l_rel_bias_, g5_l_mask_, g5_l_merge_weight_, g5_l_out_weight_, g5_l_ln_weight_, g5_l_ln_bias_, g3_l_q_weight_, g3_l_k_weight_, g3_l_v_weight_, g3_l_rel_bias_, g3_l_mask_, g3_l_out_weight_, g3_l_ln_weight_, g3_l_ln_bias_, g6_l_q_weight_, g6_l_k_weight_, g6_l_v_weight_, g6_l_rel_bias_, g6_l_mask_, g6_l_out_weight_, g6_l_ln_weight_, g6_l_ln_bias_, g2_l_kernel_bank_, g2_l_selector_, g2_l_residual_weight_, g2_l_channel_scale_, g1_l_q_weight_, g1_l_k_weight_, g1_l_v_weight_, g1_l_rel_bias_, g1_l_mask_, g1_l_out_weight_, g1_l_ln_weight_, g1_l_ln_bias_, g4_l_aux_tokens_, g4_l_q_weight_, g4_l_kv_weight_, g4_l_rel_bias_, g4_l_mask_, g4_l_merge_weight_, g4_l_out_weight_, g4_l_ln_weight_, g4_l_ln_bias_)

def flatten_tensors(obj):
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        out = []
        for item in obj:
            out.extend(flatten_tensors(item))
        return out
    if isinstance(obj, dict):
        out = []
        for key in sorted(obj, key=str):
            out.extend(flatten_tensors(obj[key]))
        return out
    return []

def map_tensors(obj, fn):
    if isinstance(obj, torch.Tensor):
        return fn(obj)
    if isinstance(obj, tuple):
        return tuple(map_tensors(item, fn) for item in obj)
    if isinstance(obj, list):
        return [map_tensors(item, fn) for item in obj]
    return obj

def loss_from_output(obj):
    terms = []
    for tensor in flatten_tensors(obj):
        if tensor.is_floating_point():
            terms.append(tensor.float().sum())
    if not terms:
        return None
    loss = terms[0]
    for term in terms[1:]:
        loss = loss + term
    return loss

def prepare_inputs(raw_inputs):
    def prepare_tensor(t):
        if t.is_floating_point() or t.is_complex():
            t = t.detach().clone()
            t.requires_grad_(True)
        return t
    return map_tensors(copy.deepcopy(raw_inputs), prepare_tensor)

def run_function(fn, raw_inputs):
    inputs = prepare_inputs(raw_inputs)
    with torch.enable_grad():
        output = fn(*inputs)
        loss = loss_from_output(output)
        if loss is not None and getattr(loss, "requires_grad", False):
            loss.backward()
        return output

def main():
    print(f"PyTorch Version: {torch.__version__}")
    
    torch.manual_seed(2077)
    
    raw_inputs = get_inputs()
    
    print("\n[1/2] Running Eager execution...")
    try:
        _ = run_function(model_forward, raw_inputs)
        print("      [PASS] Eager mode completed successfully.")
    except Exception as e:
        print("      [FAIL] Eager mode failed:")
        traceback.print_exc()
        return

    print("\n[2/2] Running Compiled execution (dynamic=True)...")
    compiled_fn = torch.compile(model_forward, dynamic=True)
    
    try:
        _ = run_function(compiled_fn, raw_inputs)
        print("      [PASS] Compiled mode completed successfully. (Bug might be patched)")
    except Exception as e:
        if "invalid for input of size" in str(e):
            print("\n" + "="*80)
            print("[BUG REPRODUCED] Garbage Shape RuntimeError triggered successfully.")
            print("="*80)
            traceback.print_exc()
            print("="*80)
        else:
            print("\n      [FAIL] Compiled mode failed with an unexpected error:")
            traceback.print_exc()

if __name__ == "__main__":
    main()
</details>
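One way to attempt the further minimization mentioned above is to shrink model_forward's return tuple: compile a wrapper that returns only a subset of the outputs, and keep removing outputs while the crash persists. The sketch below is a hypothetical greedy one-at-a-time reduction (a simplified cousin of delta debugging), not an established tool. compiled_subset_fails assumes the reproducer is saved as bug.py, as in the error log, so that model_forward, get_inputs and run_function can be imported from it.

```python
from typing import Callable, List, Sequence


def shrink_failing_subset(
    items: Sequence[int], fails: Callable[[List[int]], bool]
) -> List[int]:
    """Greedily drop items while `fails` stays True; items must be distinct.

    The result is 1-minimal: removing any single remaining item makes
    `fails` return False.
    """
    current = list(items)
    if not fails(current):
        raise ValueError("predicate does not fail on the full input")
    changed = True
    while changed:
        changed = False
        for item in list(current):  # iterate over a snapshot while shrinking
            candidate = [x for x in current if x != item]
            if fails(candidate):
                current = candidate
                changed = True
    return current


def compiled_subset_fails(keep: List[int]) -> bool:
    """Hypothetical predicate: does compiling a subset of the outputs still crash?"""
    import torch
    from bug import get_inputs, model_forward, run_function  # assumes bug.py

    def subset_forward(*args):
        out = model_forward(*args)
        return tuple(out[i] for i in keep)

    # Clear Dynamo caches; otherwise repeated recompiles of subset_forward may
    # hit Dynamo's recompile limit and fall back to eager execution.
    torch._dynamo.reset()
    torch.manual_seed(2077)  # same seed as main() in the reproducer
    try:
        run_function(torch.compile(subset_forward, dynamic=True), get_inputs())
    except Exception as e:
        return "invalid for input of size" in str(e)
    return False
```

A run would look like shrink_failing_subset(range(n), compiled_subset_fails), where n is the length of model_forward's return tuple. Every predicate call recompiles, so expect it to be slow; the returned indices identify the outputs that still appear necessary for the failure.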

Expected behavior

The torch.compile(dynamic=True) execution should finish successfully without throwing a shape-related runtime error, matching the eager mode behavior.

Versions

PyTorch version: 2.13.0.dev20260521+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.10.20 (main, Mar 11 2026, 17:46:40) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.0.140
Nvidia driver version: 596.49
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_tensor_ir.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.21.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.21.1
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] nvidia-cublas==13.1.1.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.20.0.48
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.1
[pip3] nvidia-nccl-cu13==2.29.7
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] torch==2.13.0.dev20260521+cu130
[pip3] torchaudio==2.11.0.dev20260525+cu130
[pip3] torchvision==0.28.0.dev20260525+cu130
[pip3] triton==3.7.0+git88b227e2
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas 13.1.1.3 pypi_0 pypi
[conda] nvidia-cuda-cupti 13.0.85 pypi_0 pypi
[conda] nvidia-cuda-nvrtc 13.0.88 pypi_0 pypi
[conda] nvidia-cuda-runtime 13.0.96 pypi_0 pypi
[conda] nvidia-cudnn-cu13 9.20.0.48 pypi_0 pypi
[conda] nvidia-cufft 12.0.0.61 pypi_0 pypi
[conda] nvidia-curand 10.4.0.35 pypi_0 pypi
[conda] nvidia-cusolver 12.0.4.66 pypi_0 pypi
[conda] nvidia-cusparse 12.6.3.3 pypi_0 pypi
[conda] nvidia-cusparselt-cu13 0.8.1 pypi_0 pypi
[conda] nvidia-nccl-cu13 2.29.7 pypi_0 pypi
[conda] nvidia-nvjitlink 13.0.88 pypi_0 pypi
[conda] nvidia-nvtx 13.0.85 pypi_0 pypi
[conda] torch 2.13.0.dev20260521+cu130 pypi_0 pypi
[conda] torchaudio 2.11.0.dev20260525+cu130 pypi_0 pypi
[conda] torchvision 0.28.0.dev20260525+cu130 pypi_0 pypi
[conda] triton 3.7.0+git88b227e2 pypi_0 pypi

cc @chauhang @penguinwu @ezyang @bobrenjc93 @aditvenk @laithsakka @bdhirsh @aorenste
