pytorch: `torch.compile` crashes when `torch.cat` uses the `axis=` keyword (`sink_cat_after_pointwise` pass)

🐛 Describe the bug

torch.compile fails with BackendCompilerFailed (wrapping a TypeError) when torch.cat is called with the NumPy-compatible axis= keyword instead of dim= and the result is consumed by a pointwise unary op (e.g., relu, tanh).

The crash occurs in the Inductor pre-grad pass sink_cat_after_pointwise, which defines a helper cat_args(tensors, dim=0) that only accepts dim. When FX captures torch.cat([...], axis=N), the node's kwargs contain {'axis': N}. The pass forwards them via **node.kwargs, and the call raises TypeError because cat_args has no axis parameter.
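The failure mechanism can be sketched without PyTorch at all. The snippet below is a minimal illustration, not the Inductor code itself: `node_args` and `node_kwargs` are hypothetical stand-ins for the `args` and `kwargs` of the captured FX node, and the strings in the list stand in for tensors.

```python
# Helper with the same signature as the one in the pass: it knows only `dim`.
def cat_args(tensors, dim=0):
    return tensors, dim

# Hypothetical stand-ins for what FX records for torch.cat([a, b], axis=1).
node_args = (["a", "b"],)
node_kwargs = {"axis": 1}

try:
    cat_args(*node_args, **node_kwargs)
    error_message = None
except TypeError as exc:
    # Python rejects the call because `axis` is not a parameter of cat_args.
    error_message = str(exc)

print(error_message)
```

The same call with `node_kwargs = {"dim": 1}` succeeds, which matches the observation that only the `axis=` spelling triggers the crash.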

Reproducer

import torch

def fn(x):
    a = x.sin()
    b = x.cos()
    c = torch.cat([a, b], axis=1)  # axis= instead of dim=
    return c.relu()  # pointwise op triggers sink_cat_after_pointwise

x = torch.randn(4, 8, device="cuda")

# Eager: works
print(fn(x).shape)  # torch.Size([4, 16])

# Compiled: crashes
torch.compile(fn)(x)
# BackendCompilerFailed: backend='inductor' raised:
# TypeError: sink_cat_after_pointwise.<locals>.cat_args() got an unexpected keyword argument 'axis'

Using dim=1 instead of axis=1 works correctly. The crash only occurs when a pointwise op follows the cat (which triggers the sink_cat_after_pointwise optimization pass).
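Until a fix lands, spelling the keyword as `dim=` is a possible user-side workaround. The sketch below only checks, in eager mode on CPU, that the two spellings produce identical results; `fn_axis` and `fn_dim` are hypothetical names, and the snippet deliberately does not call `torch.compile`, so it does not itself exercise the failing pass.

```python
import torch

def fn_axis(x):
    # Original spelling from the reproducer: crashes under torch.compile
    # on the affected build.
    c = torch.cat([x.sin(), x.cos()], axis=1)
    return c.relu()

def fn_dim(x):
    # Workaround: same computation with the dim= keyword, which the
    # cat_args helper in the pass accepts.
    c = torch.cat([x.sin(), x.cos()], dim=1)
    return c.relu()

x = torch.randn(4, 8)
same_result = torch.equal(fn_axis(x), fn_dim(x))
out_shape = tuple(fn_dim(x).shape)
print(same_result, out_shape)
```

Because both spellings are aliases in eager mode, swapping `axis=` for `dim=` should not change numerical results; it only changes the kwargs recorded on the FX node.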

Error message

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: sink_cat_after_pointwise.<locals>.cat_args() got an unexpected keyword argument 'axis'

Root cause

In torch/_inductor/fx_passes/pre_grad.py (lines 729–732):

def cat_args(tensors, dim=0):  # ← doesn't accept 'axis'
    return tensors, dim

tensors, dim = cat_args(*node.args, **node.kwargs)  # ← crash when kwargs={'axis': N}

torch.cat accepts both dim and axis as aliases (NumPy compatibility), but cat_args only accepts dim.

Suggested fix

def cat_args(tensors, dim=0, axis=None):
    if axis is not None:
        dim = axis
    return tensors, dim
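A quick way to sanity-check the suggested helper is to run it against the argument shapes an FX node might carry. This is a sketch under assumptions: the `(args, kwargs)` pairs in `cases` are hypothetical examples, and plain strings stand in for tensors.

```python
def cat_args(tensors, dim=0, axis=None):
    # Treat `axis` as an alias for `dim`, mirroring torch.cat's keywords.
    if axis is not None:
        dim = axis
    return tensors, dim

tensors = ["a", "b"]

# Hypothetical (args, kwargs) pairs, one per way of spelling the call.
cases = [
    ((tensors,), {}),             # torch.cat(ts)
    ((tensors, 1), {}),           # torch.cat(ts, 1)
    ((tensors,), {"dim": 1}),     # torch.cat(ts, dim=1)
    ((tensors,), {"axis": 1}),    # torch.cat(ts, axis=1)
]

dims = [cat_args(*args, **kwargs)[1] for args, kwargs in cases]
print(dims)
```

One design point worth noting: if both `dim` and `axis` were ever supplied, this helper would silently prefer `axis`. Whether that case can reach the pass is not established here, so an explicit check may be preferable in the actual patch.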

Versions

  • PyTorch: 2.13.0.dev20260520+cu126
  • Triton: 3.7.0
  • CUDA: 12.6
  • GPU: Tesla T4 (sm_75)
  • Python: 3.11

cc @mruberry @rgommers @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo
