pytorch - ✅(Solved) Fix `torch.compile` crashes with `CantSplit` TypeError on valid model using `split_with_sizes` + `reshape` + `cat` pattern [1 pull request, 1 participant]

GitHub stats
pytorch/pytorch#178676 · Fetched 2026-04-08 01:40:22
Comments: 0 · Participants: 1 · Timeline events: 67 · Reactions: 0
Timeline (top): mentioned ×25, subscribed ×25, labeled ×8, referenced ×5

Error Message

InductorError: TypeError: CantSplit.__init__() missing 2 required positional arguments: 'expr' and 'remaining'

The model runs correctly in eager mode (output shape [2, 2060, 8]) but raises this error when compiled with torch.compile and the default Inductor backend. The full reproducer is included in the original issue report below.

Root Cause

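The Inductor scheduler's _codegen method raises CantSplit when attempting to split a kernel partition, but the exception is constructed without its required arguments (expr and remaining). Because the bare raise CantSplit statement calls CantSplit() with no arguments, Python reports a TypeError from CantSplit.__init__() instead of the intended CantSplit. The error path is: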

torch/_inductor/scheduler.py:7501 → _codegen
torch/_inductor/scheduler.py:7378 → _codegen_partitions
torch/_inductor/scheduler.py:7238 → codegen
torch/_inductor/graph.py:2493 → codegen
torch/_inductor/graph.py:2557 → _compile_to_module
→ raise CantSplit  (missing required args)
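The TypeError itself is ordinary Python behavior: raise SomeClass evaluates SomeClass() with no arguments, so when __init__ requires arguments, constructing the exception fails and a TypeError escapes in its place. The sketch below reproduces this with a hypothetical stand-in class, FakeCantSplit, whose signature only mirrors the expr and remaining arguments named in the error message; it is not Inductor's actual CantSplit definition.

```python
class FakeCantSplit(Exception):
    """Hypothetical stand-in for Inductor's CantSplit; only the signature mirrors the error."""

    def __init__(self, expr, remaining):
        super().__init__(f"cannot split {expr!r}, remaining={remaining!r}")
        self.expr = expr
        self.remaining = remaining


def bare_raise():
    # Buggy pattern: Python evaluates FakeCantSplit() with no arguments,
    # so __init__ fails and a TypeError propagates instead.
    raise FakeCantSplit


def raise_with_args():
    # Constructing the exception with its required arguments raises it as intended.
    # The argument values are made up for illustration.
    raise FakeCantSplit("x0 + 96", [32, 64])


def describe(fn):
    """Call fn and return (exception type name, message) for whatever it raises."""
    try:
        fn()
    except Exception as e:
        return type(e).__name__, str(e)
    return None, ""


for fn in (bare_raise, raise_with_args):
    print(fn.__name__, describe(fn))
```

Running it shows bare_raise producing a TypeError that mentions the two missing positional arguments, while raise_with_args raises FakeCantSplit with its message intact.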

The model's split_with_sizes call produces chunks of sizes [32, 64, combined.size(1) - 96], where the third size depends on the Conv2d output dimensions (16 × 32 × 32 = 16384 for the reproducer's 2×3×32×32 input). When these chunks are reshaped to different target shapes (the last reshape uses an inferred -1 dimension) and concatenated again, Inductor fails to handle the resulting symbolic expression for the output shape.
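Tracing the shapes by hand for the reproducer's inputs makes this dependency concrete. The sketch is plain Python arithmetic with no PyTorch needed; the function name and parameters are illustrative, and it assumes the reproducer's layer sizes (a stride-1 Conv2d with 16 output channels, kernel 3 and padding 1, and a Linear layer with 96 outputs).

```python
def repro_shapes(batch=2, height=32, width=32, conv_channels=16, linear_out=96):
    """Trace the tensor shapes of the reproducer's forward() by hand."""
    # A stride-1 Conv2d with kernel 3 and padding 1 preserves height and width,
    # so conv_out.view(batch, -1) has conv_channels * height * width columns.
    conv_flat = conv_channels * height * width
    # seq_out (the last Linear output step) is concatenated with conv_flat on dim 1.
    combined = linear_out + conv_flat
    # The third split size depends on the conv output dimensions.
    split_sizes = [32, 64, combined - 96]
    # Each chunk is reshaped to (batch, rows, 8); the last reshape infers rows via -1.
    rows = [size // 8 for size in split_sizes]
    return {
        "combined": (batch, combined),
        "split_sizes": split_sizes,
        "output": (batch, sum(rows), 8),
    }


print(repro_shapes())
# {'combined': (2, 16480), 'split_sizes': [32, 64, 16384], 'output': (2, 2060, 8)}
```

Because 96 = 32 + 64, the third split size is exactly the flattened conv output, and the final shape (2, 2060, 8) matches the eager result reported in the issue.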

PR fix notes

PR #178886: [inductor] fix CantSplit raise error

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #178886

Fixes https://github.com/pytorch/pytorch/issues/178676

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @mlazos

Changed files

  • test/inductor/test_cuda_repro.py (modified, +42/-0)
  • torch/_inductor/codegen/simd.py (modified, +4/-1)

🐛 Describe the bug

torch.compile with the Inductor backend crashes with InductorError: TypeError: CantSplit.__init__() missing 2 required positional arguments: 'expr' and 'remaining' when compiling a valid model that combines aten.split_with_sizes, aten.reshape (with an inferred -1 dimension), and aten.cat. The same model runs correctly in eager mode.

The crash occurs in the Inductor scheduler's code generation phase when it tries to split a fused kernel node. The CantSplit exception is raised without its required constructor arguments (expr and remaining), which points to a bug in Inductor's internal error-handling path.

Affected files

| File | Source | Pattern |
|---|---|---|
| move_view_after_cat-2.py | E9 (struct+route+repair), round-1 | move_view_after_cat |

Full model-level reproducer

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=128, embedding_dim=32)
        self.linear = nn.Linear(32, 96)
        self.conv = nn.Conv2d(3, 16, 3, padding=1)

    def forward(self, x, indices):
        embedded = self.embedding(indices)
        linear_out = self.linear(embedded)
        conv_out = self.conv(x)
        batch_size = conv_out.shape[0]
        conv_flat = conv_out.view(batch_size, -1)
        seq_out = linear_out[:, -1, :]
        combined = torch.cat([seq_out, conv_flat], dim=1)
        split_sizes = [32, 64, combined.size(1) - 96]
        chunks = torch.ops.aten.split_with_sizes.default(
            combined, split_sizes=split_sizes, dim=1
        )
        chunk0 = chunks[0]
        chunk1 = chunks[1]
        chunk2 = chunks[2]
        chunk0_reshaped = torch.ops.aten.reshape.default(chunk0, (batch_size, 4, 8))
        chunk1_reshaped = torch.ops.aten.reshape.default(chunk1, (batch_size, 8, 8))
        chunk2_reshaped = torch.ops.aten.reshape.default(chunk2, (batch_size, -1, 8))
        output = torch.ops.aten.cat.default(
            [chunk0_reshaped, chunk1_reshaped, chunk2_reshaped], dim=1
        )
        return output


model = Model().cuda()
x = torch.randn(2, 3, 32, 32, dtype=torch.float32).cuda()
indices = torch.randint(0, 128, (2, 10), dtype=torch.long).cuda()

# Eager: succeeds
with torch.no_grad():
    eager_out = model(x, indices)
    print(f"eager: OK — shape={eager_out.shape}")  # [2, 2060, 8]

# Compiled: crashes
torch._dynamo.reset()
compiled_model = torch.compile(model)
try:
    with torch.no_grad():
        compiled_out = compiled_model(x, indices)
        print(f"compile: OK — shape={compiled_out.shape}")
except Exception as e:
    print(f"compile: ERROR — {type(e).__name__}: {e}")

Behavior summary

| Mode | Result | Output |
|---|---|---|
| Eager | Success | torch.Size([2, 2060, 8]) |
| torch.compile | Crash | InductorError: TypeError: CantSplit.__init__() missing 2 required positional arguments |

Error logs

Eager mode (correct behavior):

eager: OK — shape=torch.Size([2, 2060, 8])

torch.compile (crashes):

InductorError: TypeError: CantSplit.__init__() missing 2 required positional arguments: 'expr' and 'remaining'

Traceback (most recent call last):
  File "torch/_inductor/compile_fx.py", line 1035, in _compile_fx_inner
  File "torch/_inductor/compile_fx.py", line 1796, in fx_codegen_and_compile
  File "torch/_inductor/compile_fx.py", line 1568, in codegen_and_compile
  File "torch/_inductor/graph.py", line 2551, in compile_to_module
  File "torch/_inductor/graph.py", line 2493, in codegen
  File "torch/_inductor/scheduler.py", line 7238, in codegen
  File "torch/_inductor/scheduler.py", line 7378, in _codegen_partitions
  File "torch/_inductor/scheduler.py", line 7501, in _codegen
    raise CantSplit

Versions

PyTorch version: 2.12.0.dev20260315+cu126
OS: Ubuntu 22.04.5 LTS (x86_64)
Python version: 3.10.12
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
CUDA: 12.6

cc @chauhang @penguinwu @ezyang @bobrenjc93 @aditvenk @laithsakka @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @msaroufim @bdhirsh @anijain2305

topic: fuzzer

Fix Plan

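A user-side workaround, until the Inductor fix from PR #178886 is available, is to rewrite the model so that the split and reshape operations avoid inferred or dynamically computed shapes. This may sidestep the crash but does not fix the underlying Inductor bug.

Here are the steps:

  • Call the public torch.split, Tensor.view and torch.cat APIs instead of the torch.ops.aten overloads directly.
  • Keep the split sizes static: they are derived from combined.shape[1], which torch.compile specializes to a constant on the first compilation unless dynamic shapes are enabled.
  • Replace the inferred -1 dimension in the last reshape with an explicit size.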

Code Changes

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=128, embedding_dim=32)
        self.linear = nn.Linear(32, 96)
        self.conv = nn.Conv2d(3, 16, 3, padding=1)

    def forward(self, x, indices):
        embedded = self.embedding(indices)
        linear_out = self.linear(embedded)
        conv_out = self.conv(x)
        batch_size = conv_out.shape[0]
        conv_flat = conv_out.view(batch_size, -1)
        seq_out = linear_out[:, -1, :]
        combined = torch.cat([seq_out, conv_flat], dim=1)
        
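        # Split sizes are plain integers derived from the input shape; the
        # third chunk holds the flattened conv output (combined.shape[1] - 96).
        split_sizes = [32, 64, combined.shape[1] - 96]

        # Use the public torch.split API instead of aten.split_with_sizes.
        chunks = torch.split(combined, split_sizes, dim=1)
        chunk0 = chunks[0]
        chunk1 = chunks[1]
        chunk2 = chunks[2]

        # Reshape with explicit sizes; the last chunk's row count is computed
        # instead of being inferred with -1.
        chunk0_reshaped = chunk0.view(batch_size, 4, 8)
        chunk1_reshaped = chunk1.view(batch_size, 8, 8)
        chunk2_reshaped = chunk2.view(batch_size, chunk2.shape[1] // 8, 8)

        output = torch.cat([chunk0_reshaped, chunk1_reshaped, chunk2_reshaped], dim=1)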
        return output


model = Model().cuda()
x = torch.randn(2, 3, 32, 32, dtype=torch.float32).cuda()
indices = torch.randint(0, 128, (2, 10), dtype=torch.long).cuda()

# Eager: succeeds
with torch.no_grad():
    eager_out = model(x, indices)
    print(f"eager: OK — shape={eager_out.shape}")  # [2, 2060, 8]

# Compiled: expected to succeed if the workaround avoids the buggy Inductor path
torch._dynamo.reset()
compiled_model = torch.compile(model)
try:
    with torch.no_grad():
        compiled_out = compiled_model(x, indices)
        print(f"compile: OK — shape={compiled_out.shape}")
except Exception as e:
    print(f"compile: ERROR — {type(e).__name__}: {e}")

Verification

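To verify that the fix worked, run the script above and confirm that it prints compile: OK with the same shape as eager mode, torch.Size([2, 2060, 8]).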
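Comparing values, not just the printed shape, gives a stronger check. The helper below is a sketch: its name and default tolerances are our own choices, while the backend argument of torch.compile and torch.testing.assert_close are standard PyTorch APIs. It defaults to the Inductor backend used in the report.

```python
import torch
import torch._dynamo


def check_compiled_matches_eager(model, *inputs, backend="inductor", rtol=1e-4, atol=1e-4):
    """Run model eagerly and via torch.compile, assert the outputs match, return the shape."""
    torch._dynamo.reset()  # start from a clean compile cache, as in the reproducer
    with torch.no_grad():
        expected = model(*inputs)
        actual = torch.compile(model, backend=backend)(*inputs)
    # Raises AssertionError if shapes, dtypes or values differ beyond tolerance.
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
    return tuple(actual.shape)
```

With the reproducer's model, x and indices on CUDA, the call returns (2, 2060, 8) when the compiled output matches eager mode, and raises otherwise. Passing backend="eager" exercises only Dynamo tracing, which can help tell a graph-capture problem apart from an Inductor codegen problem such as this one.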
