pytorch: torch.compile wraps IndexError as non-IndexError exceptions, breaking except IndexError and fallback handlers (1 comment, 2 participants)

pytorch/pytorch#184364 (fetched 2026-05-20 03:39:04)

Describe the bug

torch.compile does not preserve the user-visible exception contract for IndexError raised by out-of-range tensor dimension operations.

In eager mode, the same code raises IndexError, so normal Python handlers such as except IndexError: work.

In compiled mode, the underlying IndexError may be raised internally during fake tensor/meta execution, but the user-visible exception escapes as a non-IndexError wrapper exception such as BackendCompilerFailed, TorchRuntimeError, or InternalTorchDynamoError. This means existing except IndexError handlers are skipped.
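One way to check whether the original IndexError is still attached to the wrapper that escapes is to walk the exception chain. The sketch below is not part of the original report. find_in_chain, WrapperError, simulated_compiled_call and capture are hypothetical names. Whether Dynamo's wrappers (BackendCompilerFailed, TorchRuntimeError, InternalTorchDynamoError) keep the IndexError reachable through __cause__, __context__ or an inner_exception attribute is an assumption to verify on your PyTorch version.

```python
from typing import Optional, Type


def find_in_chain(exc: BaseException, target: Type[BaseException]) -> Optional[BaseException]:
    """Return the first exception of type `target` reachable from `exc`, else None.

    Follows __cause__ (explicit `raise ... from e`), __context__ (implicit
    chaining when raising inside an except block) and, as an assumption,
    an `inner_exception` attribute that a wrapper type may carry.
    """
    seen = set()
    stack = [exc]
    while stack:
        cur = stack.pop()
        # Skip missing links, non-exception attribute values, and cycles.
        if not isinstance(cur, BaseException) or id(cur) in seen:
            continue
        seen.add(id(cur))
        if isinstance(cur, target):
            return cur
        stack.extend([getattr(cur, "inner_exception", None), cur.__context__, cur.__cause__])
    return None


class WrapperError(Exception):
    """Hypothetical stand-in for a wrapper such as BackendCompilerFailed."""


def simulated_compiled_call():
    # Toy model: an IndexError raised internally is replaced by a wrapper.
    # `from None` hides it in tracebacks, but __context__ still records it.
    try:
        raise IndexError("Dimension out of range")
    except IndexError:
        raise WrapperError("backend raised") from None


def capture(fn):
    """Call fn and return the exception it raised, or None if it returned."""
    try:
        fn()
    except Exception as e:
        return e
    return None


# Hypothetical usage with the reproducer's names (untested against torch):
#     err = capture(lambda: compiled_bad_softmax(x))
#     print(type(err).__name__, find_in_chain(err, IndexError))
```

Finding the IndexError this way only helps with diagnosis. It does not restore except IndexError semantics.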

This breaks common defensive/fallback patterns, not just error formatting:

  • caller-side except IndexError around a compiled call is missed;
  • fallback logic inside a compiled function is skipped;
  • nn.Module.forward fallback logic is skipped;
  • code that catches IndexError and re-raises a custom application exception changes behavior.
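The patterns above all break for the same reason. Under ordinary Python matching rules, an except IndexError: clause only matches IndexError and its subclasses, so a sibling wrapper type passes straight through it.

The toy model below is plain Python with hypothetical names (WrapperError, op_eager, op_compiled, outcome). It does not reproduce how Dynamo traces or compiles a function. It only shows how the handlers behave once a wrapper is raised instead of IndexError.

```python
class WrapperError(Exception):
    """Hypothetical stand-in for BackendCompilerFailed / TorchRuntimeError."""


def op_eager(dim):
    # Eager model: a bad dim surfaces as a plain IndexError.
    if dim != 0:
        raise IndexError(f"Dimension out of range, got {dim}")
    return "ok"


def op_compiled(dim):
    # Compiled model: the same IndexError happens internally, then escapes
    # as an unrelated wrapper type.
    try:
        return op_eager(dim)
    except IndexError as e:
        raise WrapperError("backend raised") from e


def caller_side(op, dim):
    # Pattern 1: caller-side `except IndexError` around the call.
    try:
        return op(dim)
    except IndexError:
        return "fallback"


def with_fallback(op, dim):
    # Pattern 2: retry with a known-good dim on IndexError.
    try:
        return op(dim)
    except IndexError:
        return op(0)


def outcome(fn, *args):
    """Return fn's result, or the name of the exception that escaped."""
    try:
        return fn(*args)
    except Exception as e:
        return type(e).__name__


if __name__ == "__main__":
    for name, pattern in [("caller_side", caller_side), ("with_fallback", with_fallback)]:
        print(name, outcome(pattern, op_eager, 10), outcome(pattern, op_compiled, 10))
    # caller_side fallback WrapperError
    # with_fallback ok WrapperError
```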

I am not reporting the internal aten.amax.default lowering detail itself. In the softmax(dim=10) case, stderr shows aten.amax.default because softmax lowering uses a reduction internally. The user-visible issue is that eager raises IndexError, while compiled code leaks a non-IndexError wrapper, so Python exception handling semantics change.

Reproducer covering four affected patterns

import torch
import torch._dynamo

print("torch:", torch.__version__)
print("torch_cuda:", torch.version.cuda)
print("cuda_available:", torch.cuda.is_available())
print("gpu:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, device=device)


def sync_if_needed():
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def classify_outer_call(fn):
    try:
        fn(x)
        sync_if_needed()
        return ("ok", None, None)
    except IndexError as e:
        return ("caught_IndexError", type(e).__name__, str(e).splitlines()[0])
    except Exception as e:
        return ("caught_other", type(e).__name__, str(e).splitlines()[0])


# 1. Caller-side except IndexError around compiled call.
def bad_softmax(t):
    return torch.softmax(t, dim=10)


torch._dynamo.reset()
eager = classify_outer_call(bad_softmax)

torch._dynamo.reset()
compiled_bad_softmax = torch.compile(bad_softmax, backend="inductor")
compiled = classify_outer_call(compiled_bad_softmax)

print("\n===== minimal_outer_except_indexerror =====")
print("device=", device)
print("eager=", eager)
print("compile=", compiled)
print("VERDICT=" + ("FAIL_contract_broken" if eager[0] == "caught_IndexError" and compiled[0] != "caught_IndexError" else "PASS"))


# 2. Fallback inside compiled function.
def safe_softmax(t):
    try:
        return torch.softmax(t, dim=10)
    except IndexError:
        return torch.softmax(t, dim=-1)


print("\n===== inner_safe_softmax_fallback =====")
eager_exc = None
compile_exc = None
try:
    torch._dynamo.reset()
    y = safe_softmax(x)
    sync_if_needed()
    print("eager_exception= None")
    print("eager_ok=", tuple(y.shape))
except Exception as e:
    eager_exc = type(e).__name__
    print("eager_exception=", eager_exc, str(e).splitlines()[0])

try:
    torch._dynamo.reset()
    compiled_safe_softmax = torch.compile(safe_softmax, backend="inductor")
    y = compiled_safe_softmax(x)
    sync_if_needed()
    print("compile_exception= None")
    print("compile_ok=", tuple(y.shape))
except Exception as e:
    compile_exc = type(e).__name__
    print("compile_exception=", compile_exc, str(e).splitlines()[0])

print("VERDICT=" + ("FAIL_inner_try_except_broken" if eager_exc is None and compile_exc is not None else "PASS"))


# 3. Same fallback pattern inside nn.Module.forward.
class SafeSoftmaxModule(torch.nn.Module):
    def forward(self, t):
        try:
            return torch.softmax(t, dim=10)
        except IndexError:
            return torch.softmax(t, dim=-1)


print("\n===== nn_module_forward_fallback =====")
eager_exc = None
compile_exc = None
try:
    m = SafeSoftmaxModule().to(device)
    y = m(x)
    sync_if_needed()
    print("eager_exception= None")
    print("eager_ok=", tuple(y.shape))
except Exception as e:
    eager_exc = type(e).__name__
    print("eager_exception=", eager_exc, str(e).splitlines()[0])

try:
    torch._dynamo.reset()
    cm = torch.compile(m, backend="inductor")
    y = cm(x)
    sync_if_needed()
    print("compile_exception= None")
    print("compile_ok=", tuple(y.shape))
except Exception as e:
    compile_exc = type(e).__name__
    print("compile_exception=", compile_exc, str(e).splitlines()[0])

print("VERDICT=" + ("FAIL_nn_module_try_except_broken" if eager_exc is None and compile_exc is not None else "PASS"))


# 4. Catch IndexError and re-raise a custom application exception.
class MyError(Exception):
    pass


def reraise_custom(t):
    try:
        return torch.softmax(t, dim=10)
    except IndexError as e:
        raise MyError("softmax fallback failed") from e


print("\n===== reraise_custom_exception =====")
eager_exc = None
compile_exc = None
try:
    reraise_custom(x)
    sync_if_needed()
    print("eager_exception= None")
except Exception as e:
    eager_exc = type(e).__name__
    print("eager_exception=", eager_exc)

try:
    torch._dynamo.reset()
    cr = torch.compile(reraise_custom, backend="inductor")
    cr(x)
    sync_if_needed()
    print("compile_exception= None")
except Exception as e:
    compile_exc = type(e).__name__
    print("compile_exception=", compile_exc)

print("VERDICT=" + ("FAIL_reraise_contract_broken" if eager_exc == "MyError" and compile_exc != "MyError" else "PASS"))

Actual behavior

Verified on Colab:

PyTorch: 2.10.0+cu128
CUDA: 12.8
GPU: Tesla T4
contract_failures: 4
passes: 0
errors_or_timeouts: 0

Observed output summary:

===== minimal_outer_except_indexerror =====
returncode: 0
--- stdout ---
device= cuda
eager= ('caught_IndexError', 'IndexError', 'Dimension out of range (expected to be in range of [-1, 0], but got 10)')
compile= ('caught_other', 'BackendCompilerFailed', "backend='inductor' raised:")
VERDICT=FAIL_contract_broken

===== inner_safe_softmax_fallback =====
returncode: 0
--- stdout ---
compile_exception= BackendCompilerFailed backend='inductor' raised:
VERDICT=FAIL_inner_try_except_broken

===== nn_module_forward_fallback =====
returncode: 0
--- stdout ---
compile_exception= BackendCompilerFailed backend='inductor' raised:
VERDICT=FAIL_nn_module_try_except_broken

===== reraise_custom_exception =====
returncode: 0
--- stdout ---
eager_exception= MyError
compile_exception= BackendCompilerFailed
VERDICT=FAIL_reraise_contract_broken

=== SUMMARY ===
contract_failures: 4
passes: 0
errors_or_timeouts: 0
expected_if_bug_reproduces: contract_failures >= 1 and errors_or_timeouts == 0
SystemExit: 0

The stderr contains internal fake tensor/meta logs that point to the underlying error source, e.g. raise IndexError(msg) inside canonicalize_dim in torch/_prims_common/__init__.py. In the softmax case the log mentions aten.amax.default, which appears to be an internal lowering/reduction detail rather than the user-facing operation being reported.
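The eager message for the 1-D input x ("expected to be in range of [-1, 0], but got 10") follows from wrapping dim against the tensor's rank. The function below is a simplified model of that check, written for illustration only. canonicalize_dim_sketch is a hypothetical name, not the code in torch/_prims_common/__init__.py. Only its error text is matched against the output in this report, and treating 0-d tensors as rank 1 is an assumption of the sketch.

```python
def canonicalize_dim_sketch(rank: int, dim: int) -> int:
    """Simplified model of wrapping `dim` for a tensor of the given rank.

    A 0-d tensor is treated as rank 1 here (assumption), so the valid
    range is [-max(rank, 1), max(rank, 1) - 1].
    """
    n = max(rank, 1)
    lo, hi = -n, n - 1
    if dim < lo or dim > hi:
        raise IndexError(
            f"Dimension out of range (expected to be in range of [{lo}, {hi}], but got {dim})"
        )
    # Negative dims count from the end.
    return dim + n if dim < 0 else dim
```

For x = torch.randn(4), the rank is 1, so only dim in {-1, 0} is accepted and dim=10 is rejected before any kernel runs.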

Additional batch evidence

A separate batch of 9 IndexError-contract repros shows the same broad pattern across multiple dimension/index APIs. In each case, eager raises IndexError, but compiled mode exposes a non-IndexError exception, so except IndexError does not catch it.

softmax_dim_oor    eager: IndexError -> compile: BackendCompilerFailed
tensor_size_neg    eager: IndexError -> compile: InternalTorchDynamoError
squeeze_oor_dim    eager: IndexError -> compile: TorchRuntimeError
unsqueeze_oor_dim  eager: IndexError -> compile: TorchRuntimeError
amax_oor_dim       eager: IndexError -> compile: TorchRuntimeError
argmax_oor_dim     eager: IndexError -> compile: TorchRuntimeError
flip_oor_dim       eager: IndexError -> compile: TorchRuntimeError
cat_oor_dim        eager: IndexError -> compile: TorchRuntimeError
stack_oor_dim      eager: IndexError -> compile: TorchRuntimeError

Result: all 4 confirmation cases above failed the exception-contract check, and a separate 9-case verifier reproduced all 9 original IndexError cases.

This is why the issue title/body uses “non-IndexError exceptions” rather than only BackendCompilerFailed.
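A batch like the nine-case table above can be tabulated with a small harness. It runs each case eagerly and compiled and applies one rule: the contract is broken if eager raises IndexError and compiled does not. The report does not include its batch verifier, so this is a hypothetical sketch (raised and compare_cases are illustrative names). It uses isinstance rather than comparing type names, so an IndexError subclass raised in compiled mode counts as preserving the contract.

```python
from typing import Callable, Dict, Optional, Tuple

Case = Tuple[Callable[[], object], Callable[[], object]]


def raised(fn: Callable[[], object]) -> Optional[Exception]:
    """Run fn and return the exception it raised, or None on success."""
    try:
        fn()
    except Exception as e:
        return e
    return None


def compare_cases(cases: Dict[str, Case]) -> Dict[str, Tuple[Optional[str], Optional[str], bool]]:
    """Map case name -> (eager exception type, compiled exception type, contract_broken).

    The contract is broken when eager raises IndexError but the compiled call
    raises something that is not an IndexError (subclasses count as preserved).
    """
    report = {}
    for name, (eager_fn, compiled_fn) in cases.items():
        e_exc, c_exc = raised(eager_fn), raised(compiled_fn)
        broken = isinstance(e_exc, IndexError) and not isinstance(c_exc, IndexError)
        report[name] = (
            type(e_exc).__name__ if e_exc is not None else None,
            type(c_exc).__name__ if c_exc is not None else None,
            broken,
        )
    return report


# Hypothetical usage with names from the reproducer (untested here):
#     torch._dynamo.reset()
#     compare_cases({
#         "softmax_dim_oor": (lambda: bad_softmax(x),
#                             lambda: torch.compile(bad_softmax, backend="inductor")(x)),
#     })
```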

Expected behavior

Compiled mode should preserve the Python exception contract for user code. In particular, when eager execution raises IndexError, compiled execution should either:

  1. raise IndexError or a subclass of IndexError, so except IndexError: continues to work; or
  2. otherwise preserve behavior so that the original IndexError reaches user handlers.

It should not expose a non-IndexError wrapper such as BackendCompilerFailed, TorchRuntimeError, or InternalTorchDynamoError for user-level tensor bounds errors in a way that bypasses application fallback code.
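Option 1 does not require giving up a compiler-specific exception type. A class can inherit from both a wrapper base and IndexError, so existing except IndexError: handlers and code that expects the wrapper type both keep working.

This is a plain-Python illustration only. WrapperError and CompiledIndexError are hypothetical names and do not describe how PyTorch implements, or would implement, such a change.

```python
class WrapperError(Exception):
    """Hypothetical stand-in for a compiler wrapper exception type."""


class CompiledIndexError(WrapperError, IndexError):
    """Hypothetical wrapper that is also an IndexError (option 1 above)."""


def compiled_call():
    # Models a compiled call that surfaces the bounds error through the dual-typed wrapper.
    raise CompiledIndexError(
        "Dimension out of range (expected to be in range of [-1, 0], but got 10)"
    )


def caller():
    # An unchanged application handler written for eager mode.
    try:
        return compiled_call()
    except IndexError:
        return "fallback taken"


def caller_expecting_wrapper():
    # Code that catches the wrapper type also still works.
    try:
        return compiled_call()
    except WrapperError:
        return "wrapper caught"
```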

Why this matters

This can break real application safety/fallback code. A common pattern is:

def safe_apply(fn, x, dim):
    try:
        return fn(x, dim=dim)
    except IndexError:
        return fn(x, dim=-1)

In eager mode this pattern works. Under torch.compile, the fallback path can be skipped and the application sees a non-IndexError compiler/Dynamo wrapper instead.
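If the only goal is to fall back when dim is out of range, the code can avoid depending on which exception type escapes. It checks dim against the tensor's rank before calling fn, instead of relying on IndexError.

This sketch is not from the report. safe_apply_checked and default_dim are hypothetical names, and the code assumes x has a .dim() method returning its rank, as torch.Tensor does. Unlike safe_apply, it does not catch IndexError raised inside fn for other reasons. Whether it sidesteps the problem under torch.compile has not been tested here.

```python
def safe_apply_checked(fn, x, dim, default_dim=-1):
    """Variant of safe_apply that validates dim up front instead of catching IndexError.

    Assumes x.dim() returns the rank (as torch.Tensor does). A 0-d input is
    treated as rank 1 here, so dim must lie in [-n, n - 1] with n = max(rank, 1).
    """
    n = max(x.dim(), 1)
    if not -n <= dim < n:
        # Out-of-range dim: use the fallback dim rather than waiting for an error.
        dim = default_dim
    return fn(x, dim=dim)
```

With the 1-D x from the reproducer, safe_apply_checked(torch.softmax, x, 10) would be expected to call torch.softmax(x, dim=-1).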

Related issues / prior art

I found related issues, but I did not find an exact duplicate focused on preserving the typed IndexError contract across eager vs compiled execution:

  • https://github.com/pytorch/pytorch/issues/174166 — this is the nearest general parent issue I found. It reports that torch.compile does not support Python try/except and that an AttributeError fallback in nn.Module.forward fails. The repro here specifically covers tensor dimension IndexError becoming non-IndexError wrappers and breaking except IndexError contract patterns.
  • https://github.com/pytorch/pytorch/issues/153605 — related try/except behavior around AttributeError under compiled execution.
  • https://github.com/pytorch/pytorch/issues/95277 — related to dimension-out-of-range being wrapped as BackendCompilerFailed, but it does not cover application-level except IndexError fallback semantics.
  • https://github.com/pytorch/pytorch/issues/167900 — related to exception handling / traceback fidelity in torch.compile, but focuses on traceback/exception-handler fidelity rather than tensor bounds IndexError being converted into a non-IndexError user-visible exception.

Environment

Python: 3.12.13
PyTorch: 2.10.0+cu128
CUDA: 12.8
GPU: Tesla T4
Source: Colab re-test
Result: 4/4 safe confirmation cases failed the exception-contract check; 9/9 original IndexError cases also reproduced.

Original fuzzing environment

PyTorch: 2.9.1+cu128
CUDA: 12.8
GPU: NVIDIA L4 24GB
Source: original fuzzing artifact / generated findings
Result: same IndexError contract pattern observed across the 9 original cases.

Notes

SystemExit: 0 in the notebook output is expected because the verifier exits with success when the bug reproduces. The long stderr is not necessary to reproduce the issue; the important part is the stdout summary plus the short note that fake tensor/meta logs show an underlying IndexError.
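The verifier script itself is not included in the report. What follows is only a sketch of the summary rule it prints (expected_if_bug_reproduces: contract_failures >= 1 and errors_or_timeouts == 0), using hypothetical names. Exit code 0 when the bug reproduces matches the note above; returning 1 otherwise is an assumption. Passing such a code to sys.exit inside a notebook typically surfaces as SystemExit: 0.

```python
from typing import Dict, List


def summarize(verdicts: List[str], errors_or_timeouts: int = 0) -> Dict[str, int]:
    """Tally VERDICT values and derive an exit code.

    Any verdict starting with "FAIL" counts as a contract failure. The exit
    code is 0 when the bug reproduces (contract_failures >= 1 and
    errors_or_timeouts == 0) and 1 otherwise (assumed).
    """
    failures = sum(1 for v in verdicts if v.startswith("FAIL"))
    passes = sum(1 for v in verdicts if v == "PASS")
    reproduced = failures >= 1 and errors_or_timeouts == 0
    return {
        "contract_failures": failures,
        "passes": passes,
        "errors_or_timeouts": errors_or_timeouts,
        "exit_code": 0 if reproduced else 1,
    }
```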

cc @chauhang @penguinwu
