pytorch: `torch.cumprod` / `torch.cumsum` / `torch.sort` kill the Python process with SIGABRT on MPS for tensors with `ndim > 4` and `dim < ndim - 4`



🐛 Describe the bug

torch.cumprod, torch.cumsum, and torch.sort crash the Python process with SIGABRT (exit code -6) on MPS when given a tensor with more than 4 dimensions and a dim value that maps to an axis index ≥ 4 from the right. No Python exception is raised — the process is killed by a fatal Metal assertion and cannot be caught with try/except.

The crash rule is: dim < ndim - 4 (i.e. the requested axis is not in the rightmost 4 dimensions).
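The rule above can be written as a small pure-Python predicate, which is handy for guarding calls or generating test cases. This is only a sketch of the rule as reported in this issue: the function names are hypothetical, and treating negative `dim` values as wrapping to `dim + ndim` before the check is an assumption based on PyTorch's usual dim handling, not something the report tested.

```python
def axis_from_right(ndim: int, dim: int) -> int:
    """Return the position of `dim` counted from the last dimension (0 = last)."""
    if ndim < 1 or not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a tensor with {ndim} dimensions")
    normalized = dim % ndim  # assumption: negative dims wrap, as elsewhere in PyTorch
    return ndim - 1 - normalized


def is_crashing_combination(ndim: int, dim: int) -> bool:
    """True when the reported rule `dim < ndim - 4` predicts a SIGABRT on MPS."""
    # The Metal assertion only accepts axes 0..3 counted from the right.
    return axis_from_right(ndim, dim) >= 4


if __name__ == "__main__":
    for ndim in (4, 5, 6):
        crashing = [d for d in range(ndim) if is_crashing_combination(ndim, d)]
        print(ndim, crashing)
    # 4 []
    # 5 [0]
    # 6 [0, 1]
```

The printed combinations match the crashing rows listed later in this report.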

Crash origins in the Metal runtime:

MPSNDArrayScan.mm:251:  failed assertion `(null)'  Axis = 4.
  This class only supports axis = 0, 1, 2, 3
  (torch.cumprod, torch.cumsum)

MPSNDArraySort.mm:252:  failed assertion `(null)'  Axis = 4.
  This class only supports axis = 0, 1, 2, 3
  (torch.sort)

PyTorch does not validate this constraint before dispatching to MPS, so the Metal runtime assertion fires and kills the host process.

Minimal reproducer

import torch

assert torch.backends.mps.is_available(), "MPS not available"
device = torch.device("mps")

# Safe: ndim=4 — always ok
x4 = torch.ones(2, 2, 2, 2, device=device)
print(torch.cumprod(x4, dim=0).shape)   # ok

# Safe: ndim=5 but dim=1 (axis from right = 3) — ok
x5 = torch.ones(2, 2, 2, 2, 2, device=device)
print(torch.cumprod(x5, dim=1).shape)   # ok

# CRASH: ndim=5, dim=0 (axis from right = 4) — kills process
x5 = torch.ones(2, 2, 2, 2, 2, device=device)
torch.cumprod(x5, dim=0)   # SIGABRT — cannot be caught with try/except

Observed output

# Process killed: Assertion failed: (null), function MPSNDArrayScan::initWithDevice,
# file MPSNDArrayScan.mm, line 251.
# Abort trap: 6
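Because the abort cannot be caught in-process, one way to observe it without losing the calling interpreter is to run the reproducer in a child process and inspect its return code. The sketch below uses only the standard library; `run_isolated` and `died_from_sigabrt` are hypothetical helper names, and the `-6` check relies on POSIX behavior, where `subprocess` reports death by signal N as return code `-N`.

```python
import signal
import subprocess
import sys

# The crashing call from the reproducer, kept as source text for the child process.
REPRO_SOURCE = """
import torch
x5 = torch.ones(2, 2, 2, 2, 2, device="mps")
torch.cumprod(x5, dim=0)
"""


def run_isolated(source: str, timeout: float = 120.0) -> subprocess.CompletedProcess:
    """Run `source` in a fresh interpreter so a fatal abort cannot kill the caller."""
    return subprocess.run(
        [sys.executable, "-c", source],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


def died_from_sigabrt(proc: subprocess.CompletedProcess) -> bool:
    """On POSIX, a child killed by SIGABRT has return code -6."""
    return proc.returncode == -signal.SIGABRT
```

On an affected machine, `died_from_sigabrt(run_isolated(REPRO_SOURCE))` would be expected to return `True`, with the Metal assertion text available in `proc.stderr`.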

Expected output

RuntimeError: MPS supports at most 4 dimensions for this operation (or similar),
raised at the Python boundary before dispatch.

Why this is a bug

torch.cumprod(x5, dim=0) on a 5-D tensor is a fully valid call: PyTorch tensors can have any number of dimensions, and any dim in the range [-ndim, ndim) is valid. There is no documented restriction limiting MPS cumprod/cumsum/sort to the rightmost four axes. A SIGABRT is never an acceptable outcome: it kills the Python process, prevents cleanup, and cannot be caught with try/except. The correct behavior is to raise a RuntimeError at the Python boundary (before MPS dispatch) or fall back to CPU automatically.
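Until the check exists inside PyTorch, the CPU fallback described above can be approximated in user code. The following is a minimal sketch, not PyTorch's fix: `run_with_cpu_fallback` is a hypothetical wrapper, and the condition it uses is the `dim < ndim - 4` rule from this report, with negative dims assumed to wrap.

```python
import torch


def run_with_cpu_fallback(op, x: torch.Tensor, dim: int):
    """Call op(x, dim=dim), routing the crashing MPS combinations through the CPU.

    `op` is expected to be torch.cumsum, torch.cumprod, or torch.sort.
    """
    # Rule from the report: the axis must lie in the rightmost four dimensions.
    # Assumption: negative dims wrap (dim % ndim) before reaching the MPS kernel.
    unsupported = x.ndim > 4 and (dim % x.ndim) < x.ndim - 4
    if x.device.type != "mps" or not unsupported:
        return op(x, dim=dim)

    result = op(x.cpu(), dim=dim)  # compute on CPU, then copy back to the MPS device
    if isinstance(result, torch.Tensor):
        return result.to(x.device)
    # torch.sort returns (values, indices); this returns a plain tuple,
    # so the named fields of the original return type are not preserved.
    return tuple(t.to(x.device) for t in result)


if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    x5 = torch.ones(2, 2, 2, 2, 2, device=device)
    print(run_with_cpu_fallback(torch.cumprod, x5, 0).shape)
```

The fallback costs two device transfers per call, so it is best reserved for the combinations that would otherwise abort.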

Exhaustive failing combinations (all with torch.ones on MPS):

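| Op | ndim | dim | Result |
| --- | --- | --- | --- |
| cumprod / cumsum | 5 | 0 | CRASH (axis from right = 4) |
| cumprod / cumsum | 6 | 0 | CRASH (axis from right = 5) |
| cumprod / cumsum | 6 | 1 | CRASH (axis from right = 4) |
| sort | 5 | 0 | CRASH (axis from right = 4) |
| cumprod / cumsum | 5 | 1–4 | ok |
| any op | ≤ 4 | any | ok |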

torch.argsort is also likely affected via MPSNDArraySort.

Versions

PyTorch version: 2.13.0.dev20260512+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.3 (main, Mar 23 2026, 19:04:32) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.14.0-37-generic-x86_64-with-glibc2.39
Is CUDA available: True
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 5090
Nvidia driver version: 590.48.01

[pip3] numpy==2.4.4
[pip3] torch==2.13.0.dev20260512+cu130
[pip3] triton==3.7.0+git88b227e2

cc @malfet @kulinseth @DenisVieriu97 @jhavukainen @aditvenk
