pytorch - ✅ (Solved) Fix [MPS] `sum` uses saturated cast [1 pull request, 3 comments, 2 participants]

GitHub stats
pytorch/pytorch#179415 · Fetched 2026-04-08 02:51:53
Comments: 3 · Participants: 2 · Timeline events: 34 · Reactions: 0
Timeline (top): mentioned ×10, subscribed ×10, labeled ×6, commented ×3

Fix Action

Fixed

PR fix notes

PR #179407: [MPS] Enable test_reductions.py with skips

Description (problem / solution / changelog)

This PR adds @skipIfMPS decorators to test methods that fail on MPS due to unsupported dtypes (float64, complex128), unimplemented operators (hash_tensor, mode), or known issues (view size incompatibility, dimension limits, accuracy mismatches with complex types).

Follow-up work should remove these skips and instead use expectedFailure, either in test_reductions.py (e.g. 6aeb1dc624b844fd7234fcdaf6e3f4b0d91d2ff5) or, for the tests that use OpInfo, in torch/testing/_internal/common_methods_invocations.py (e.g. 72754c07f03382b1a7f304628991a26c7c79d534).

Related to #178497

Changed files

  • test/test_reductions.py (modified, +79/-3)
  • torch/testing/_internal/common_methods_invocations.py (modified, +2/-0)


🐛 Describe the bug

Discovered while working on #179407 and #178497.

import torch

example = [[-1, 2, 1], [5, 3, 6]]
x_mps = torch.tensor(example, dtype=torch.uint8, device="mps")
x_cpu = torch.tensor(example, dtype=torch.uint8, device="cpu")

y_cpu = x_cpu.sum(dtype=torch.uint8).item()
y_mps = x_mps.sum(dtype=torch.uint8).item()

print(f"{y_cpu=}")
print(f"{y_mps=}")
# Output:
# y_cpu=16
# y_mps=255

Versions

PyTorch version: 2.11.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.4 (arm64)
GCC version: Could not collect
Clang version: 21.0.0 (clang-2100.0.123.102)
CMake version: version 4.1.2
Libc version: N/A

Python version: 3.12.13 | packaged by conda-forge | (main, Mar 5 2026, 17:06:14) [Clang 19.1.7 ] (64-bit runtime)
Python platform: macOS-26.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU: Apple M3 Max

Versions of relevant libraries:
[pip3] numpy==2.4.3
[pip3] torch==2.11.0
[pip3] torchvision==0.26.0
[conda] numpy 2.4.3 pypi_0 pypi
[conda] torch 2.11.0 pypi_0 pypi
[conda] torchvision 0.26.0 pypi_0 pypi

cc @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

TL;DR

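The true sum in the example does not fit in torch.uint8: the reported outputs are consistent with -1 being stored as 255, which gives 272. The CPU backend wraps around to 16, while the MPS backend saturates at 255. As a workaround, use a larger integer type, such as torch.int32 or torch.int64, so that the sum does not overflow.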

Guidance

  • The discrepancy appears only because the sum overflows torch.uint8, whose maximum value is 255. With -1 stored as 255 the true sum is 272: the CPU result 16 equals 272 mod 256 (wraparound), while the MPS result 255 is the saturated maximum, as the issue title says.
  • To verify this, print the sum without specifying the data type and compare it with the expected sum. For integer inputs PyTorch returns a torch.int64 result by default, so the unreduced value is visible.
  • To mitigate this issue, change the data type of the tensor to a larger integer type, such as torch.int32 or torch.int64, when creating the tensor.
  • Alternatively, call torch.sum without the dtype argument and let PyTorch choose the result type.
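The difference between the two reported results can be reproduced with plain integer arithmetic, without PyTorch or an MPS device. The sketch below is illustrative only: the helper names `wrap_to_uint8` and `saturate_to_uint8` are hypothetical and not part of PyTorch, and it assumes that -1 is stored as 255 in a uint8 tensor, which is what the reported CPU output implies.

```python
UINT8_MAX = 255


def wrap_to_uint8(value: int) -> int:
    """Modulo-256 conversion (wraparound), hypothetical helper."""
    return value % (UINT8_MAX + 1)


def saturate_to_uint8(value: int) -> int:
    """Clamp into [0, 255] (saturation), hypothetical helper."""
    return max(0, min(value, UINT8_MAX))


# Assumption: each input value is wrapped on storage, so -1 becomes 255.
stored = [wrap_to_uint8(v) for v in [-1, 2, 1, 5, 3, 6]]
true_sum = sum(stored)

print(f"{true_sum=}")                     # 272
print(f"{wrap_to_uint8(true_sum)=}")      # 16, matches y_cpu in the report
print(f"{saturate_to_uint8(true_sum)=}")  # 255, matches y_mps in the report
```

The two conversions agree whenever the sum fits in 0..255, which is why the mismatch only shows up once the sum overflows.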

Example

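import torch

# Same values as the report. In int32, -1 is stored as -1 instead of wrapping to 255.
example = [[-1, 2, 1], [5, 3, 6]]
x_mps = torch.tensor(example, dtype=torch.int32, device="mps")
x_cpu = torch.tensor(example, dtype=torch.int32, device="cpu")

# No dtype argument: integer sums are returned as torch.int64.
y_cpu = x_cpu.sum().item()
y_mps = x_mps.sum().item()

print(f"{y_cpu=}")
print(f"{y_mps=}")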

Notes

The issue was reported for the torch.uint8 data type on the MPS device. The report does not cover other data types or devices, so it may not occur with them.

Recommendation

Apply the workaround: use a larger integer type, such as torch.int32 or torch.int64, so that the sum does not overflow. The CPU and MPS results differ only when the sum exceeds the maximum value that torch.uint8 can represent (255), as it does in the example.
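If the data has to stay in torch.uint8, one possible variant of the workaround is to accumulate in torch.int64 and apply the wraparound explicitly. This is a minimal sketch, not the fix that landed in PyTorch: `wrapping_sum_uint8` is a hypothetical helper name, and the sketch assumes that an int64 sum is supported on the device in use. It falls back to the CPU when MPS is not available.

```python
import torch


def wrapping_sum_uint8(x: torch.Tensor) -> int:
    """Sum a uint8 tensor with explicit modulo-256 wraparound (hypothetical helper)."""
    # Accumulate in int64 so the reduction itself does not overflow.
    total = x.sum(dtype=torch.int64).item()
    # Reduce modulo 256 in Python, independent of the backend's cast behaviour.
    return int(total) % 256


# Use MPS when present, otherwise run the same code on the CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# 255 is written directly here; the report's -1 implies the same stored value.
x = torch.tensor([[255, 2, 1], [5, 3, 6]], dtype=torch.uint8, device=device)
print(f"{wrapping_sum_uint8(x)=}")
```

Doing the modulo step in Python keeps the result the same on both backends, at the cost of a wider accumulator and a device-to-host transfer for the scalar.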
