pytorch - [inductor] cat lowering AssertionError with 1-D empty tensor and negative dim [1 pull request]

Error Message

InductorError: LoweringException: AssertionError:
  target: aten.cat.default
  args[1]: -2

Root Cause

In torch/_inductor/lowering.py, the cat lowering calls:

dim = _validate_dim(inputs[0], dim, 0)

_validate_dim normalizes a negative dim against inputs[0].ndim. When inputs[0] is 1-D (ndim=1) and dim=-2, normalization gives -2 + 1 = -1, which fails the assertion 0 <= dim < ndim.

ATen's legacy_cat_wrap_dim (aten/src/ATen/WrapDimUtils.h:133) skips tensors where dim() == 1 && numel() == 0, using the first non-skipped tensor for wrapping.
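
The arithmetic above can be checked with a small stand-alone sketch. The helper `wrap_dim` is a hypothetical stand-in written for illustration; it is not Inductor's `_validate_dim` and only mimics the normalization and range check described in this section.

```python
def wrap_dim(dim: int, ndim: int) -> int:
    """Hypothetical helper: normalize a possibly negative dim against ndim."""
    if dim < 0:
        dim += ndim
    # Same range check as the assertion quoted above.
    assert 0 <= dim < ndim, f"dim out of range for ndim={ndim}"
    return dim


# Normalized against the 4-D new_keys tensor, dim=-2 is valid.
print(wrap_dim(-2, 4))  # 2

# Normalized against the 1-D empty cache, -2 + 1 = -1, so the check fails.
try:
    wrap_dim(-2, 1)
except AssertionError as exc:
    print("AssertionError:", exc)
```

The same dim is therefore valid or invalid depending only on which input is chosen as the reference, which is why the choice of inputs[0] matters.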

Fix Action

Fixed

Code Example

import torch

def fn(cache, new_keys):
    return torch.cat([cache, new_keys], dim=-2)

cache = torch.tensor([], dtype=torch.bfloat16, device="cuda")              # shape [0]
new_keys = torch.randn(1, 8, 0, 78, dtype=torch.bfloat16, device="cuda")  # shape [1, 8, 0, 78]

# Eager: works
print(fn(cache, new_keys).shape)  # torch.Size([1, 8, 0, 78])

# Inductor: fails
torch.compile(fn, backend="inductor")(cache, new_keys)
# InductorError: LoweringException: AssertionError:
#   target: aten.cat.default
#   args[1]: -2

🐛 Describe the bug

Inductor's lowering for aten.cat.default raises AssertionError when a 1-D empty tensor (e.g. torch.tensor([])) appears among the inputs and dim is negative with magnitude exceeding that tensor's rank.

ATen handles this via cat_should_skip_tensor: 1-D tensors with numel==0 are skipped during dimension validation and concatenation. Inductor's cat lowering does not replicate this skip, so _validate_dim always uses inputs[0] as the reference for normalizing negative dim.
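
As a rough model of the eager behavior described here, the sketch below works on plain shape tuples rather than tensors. The names `should_skip` and `cat_output_shape` are hypothetical and chosen for illustration; the sketch only covers the case where at least one input is not skipped, and it is not a port of the ATen code.

```python
from math import prod
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def should_skip(shape: Shape) -> bool:
    """Skip rule as described in the report: 1-D with zero elements."""
    return len(shape) == 1 and prod(shape) == 0


def cat_output_shape(shapes: Sequence[Shape], dim: int) -> Shape:
    """Shape of the concatenation, ignoring skipped inputs (illustrative only)."""
    kept = [s for s in shapes if not should_skip(s)]
    if not kept:
        # The all-skipped case is deliberately not modeled in this sketch.
        raise ValueError("all inputs skipped; not modeled here")
    ref = kept[0]
    ndim = len(ref)
    if dim < 0:
        dim += ndim  # wrap against the first non-skipped input
    if not 0 <= dim < ndim:
        raise IndexError("dim out of range")
    for s in kept:
        # All kept inputs must agree on every dimension except `dim`.
        if len(s) != ndim or any(a != b for i, (a, b) in enumerate(zip(s, ref)) if i != dim):
            raise ValueError(f"shape mismatch: {s} vs {ref}")
    out = list(ref)
    out[dim] = sum(s[dim] for s in kept)
    return tuple(out)


# The shapes from the report: a 1-D empty cache and 4-D empty keys.
print(cat_output_shape([(0,), (1, 8, 0, 78)], dim=-2))  # (1, 8, 0, 78)
```

With the 1-D empty input skipped, dim=-2 is wrapped against the 4-D input, which matches the eager result shown in the reproduction.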

Real-world trigger

HF Transformers DynamicCache initializes the KV cache as torch.tensor([], dtype=..., device=...) (1-D, shape [0]), then concatenates via torch.cat([self.keys, key_states], dim=-2). On the first forward pass, when both the cache and the incoming keys are empty (seq_len=0), this call reaches Inductor without being constant-folded and triggers the assertion. It was observed on FrontiersMind/Nandi-Mini-600M-Early-Checkpoint, but it likely affects any model that uses this cache pattern under torch.compile.

Fix sketch

Mirror ATen's cat_should_skip_tensor in the Inductor lowering: before calling _validate_dim, filter out 1-D inputs whose size is statically known to be 0. Use the first remaining input for dimension validation. Pass only non-skipped inputs to downstream concat logic (ConcatKernel / pointwise_cat).

I have a working local fix and will open a PR.
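
The filtering step in the fix sketch could look roughly like the following. This is an illustration written for this page, not the local fix mentioned above and not Inductor code: shapes are plain tuples, a string stands in for a size that is not statically known, and the names `statically_empty_1d` and `prepare_cat` are hypothetical. Falling back to the original inputs when everything is skipped is an assumption of this sketch.

```python
from typing import List, Sequence, Tuple, Union

# An int is a statically known size; a str stands in for a symbolic size.
Size = Union[int, str]
Shape = Tuple[Size, ...]


def statically_empty_1d(shape: Shape) -> bool:
    """True only for 1-D inputs whose single size is a known literal 0."""
    return len(shape) == 1 and isinstance(shape[0], int) and shape[0] == 0


def prepare_cat(shapes: Sequence[Shape], dim: int) -> Tuple[List[Shape], int]:
    """Drop skippable inputs, then validate dim against the first remaining one."""
    kept = [s for s in shapes if not statically_empty_1d(s)]
    if not kept:
        # Assumption of this sketch: if every input is skipped, keep them all
        # so that validation behaves exactly as it did before filtering.
        kept = list(shapes)
    ndim = len(kept[0])
    wrapped = dim + ndim if dim < 0 else dim
    assert 0 <= wrapped < ndim, "dim out of range"
    # Only `kept` would be handed to the downstream concat logic.
    return kept, wrapped


kept, dim = prepare_cat([(0,), (1, 8, 0, 78)], dim=-2)
print(kept, dim)  # [(1, 8, 0, 78)] 2

# A 1-D input with a symbolic size is not filtered, since it is not known to be 0.
print(statically_empty_1d(("s0",)))  # False
```

Restricting the filter to statically known zeros keeps the decision independent of runtime values, which is the condition the fix sketch states.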

Versions

  • PyTorch: nightly (built from source, commit 4003d0c on main)
  • transformers: 5.9.0

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo
