pytorch - ✅(Solved) Fix [MPS] Correctness issues in `count_nonzero`, `mean`, `nansum`, `sum`, and `trace` [1 pull request, 1 comment, 1 participant]

GitHub stats
pytorch/pytorch#178497 · Fetched 2026-04-08 01:35:56
Comments: 1
Participants: 1
Timeline: 51
Reactions: 0
Timeline (top): mentioned ×19, subscribed ×19, labeled ×7, cross-referenced ×4

Fix Action

Fixed

PR fix notes

PR #178502: [MPS] Migrate count_nonzero from MPS to native implementation

Description (problem / solution / changelog)

Related to #178497

Changed files

  • aten/src/ATen/native/TensorAdvancedIndexing.cpp (modified, +1/-1)
  • aten/src/ATen/native/mps/operations/ReduceOps.mm (modified, +0/-38)
  • aten/src/ATen/native/native_functions.yaml (modified, +1/-2)


🐛 Describe the bug

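Several reduce ops are affected by silent correctness issues when run in rapid succession: the same op on the same input occasionally returns a different, incorrect result without raising an error. This includes count_nonzero, mean, nansum, sum, and trace.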

amax, amin, max, and min have been tested, but did not display the behavior.

These issues were discovered while working on a solution for count_nonzero in #178079. The issues originate in the Apple MPS framework. The proposed solution is to rewrite the ops as Metal kernels.

MRE

import torch

torch.manual_seed(42)

g = torch.Generator(device="mps")
max_iters = 300000

x = torch.rand(16, 2100, 256, device="mps", dtype=torch.float32, generator=g)

x_trace = torch.rand(20000, 20000, device="mps", dtype=torch.float32, generator=g)
x_nansum_flat = x.flatten().clone()
x_nansum_flat[0::100] = float("nan")
x_nansum_flat = x_nansum_flat.reshape(16, 2100, 256)


def test_op(name, op):
    expected = op()
    print(f"\nTesting {name}...")
    for i in range(max_iters):
        result = op()
        if result != expected:
            print(f"FAIL: {name}")
            print(f"  expected: {expected}")
            print(f"  actual: {result}")
            print(f"  failed at iter {i}")
            return
    print(f"  {name} passed all {max_iters} iterations")


print(f"{max_iters=}")

test_op("count_nonzero", lambda: x.count_nonzero())
test_op("mean", lambda: x.mean())
test_op("nansum", lambda: x_nansum_flat.nansum())
test_op("sum", lambda: x.sum())
test_op("trace", lambda: x_trace.trace())

# Output:
# max_iters=300000
#
# Testing count_nonzero...
# FAIL: count_nonzero
#   expected: 8601599
#   actual: 12894206
#   failed at iter 18939
#
# Testing mean...
# FAIL: mean
#   expected: 0.49992606043815613
#   actual: 0.7494960427284241
#   failed at iter 94952
#
# Testing nansum...
# FAIL: nansum
#   expected: 4257176.0
#   actual: 6382371.0
#   failed at iter 115
#
# Testing sum...
# FAIL: sum
#   expected: 4300164.0
#   actual: 6446865.0
#   failed at iter 12956
#
# Testing trace...
# FAIL: trace
#   expected: 10091.3310546875
#   actual: 19146.9375
#   failed at iter 3977
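
The reported values can be compared directly. The following is a small sketch that uses only the numbers printed in the output above to show how far each failing result drifted from its expected value. It is descriptive, not a diagnosis of the root cause, and the variable names are illustrative.

```python
# Expected/actual pairs copied from the MRE output above.
reported = {
    "count_nonzero": (8601599, 12894206),
    "mean": (0.49992606043815613, 0.7494960427284241),
    "nansum": (4257176.0, 6382371.0),
    "sum": (4300164.0, 6446865.0),
    "trace": (10091.3310546875, 19146.9375),
}

# Ratio of the wrong result to the correct one, per op.
ratios = {name: actual / expected for name, (expected, actual) in reported.items()}

for name, ratio in ratios.items():
    print(f"{name:>14}: actual / expected = {ratio:.4f}")
```

Four of the ops land near 1.5x the expected value and trace lands near 1.9x, so these are gross errors rather than floating-point rounding noise.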

Versions

PyTorch version: 2.12.0a0+git61ce48e
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.4 (arm64)
GCC version: Could not collect
Clang version: 21.0.0 (clang-2100.0.123.102)
CMake version: version 4.1.2
Libc version: N/A

Python version: 3.12.8 | packaged by conda-forge | (main, Dec 5 2024, 14:19:53) [Clang 18.1.8 ] (64-bit runtime)
Python platform: macOS-26.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU: Apple M3 Max

Versions of relevant libraries:
[pip3] flake8==7.2.0
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.4.3
[pip3] onnx==1.17.0
[pip3] onnx2torch==1.5.15
[pip3] onnxruntime==1.21.1
[pip3] onnxscript==0.1.0.dev20240817
[pip3] optree==0.13.0
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] torch==2.12.0a0+git61ce48e
[pip3] torchaudio==2.6.0a0+1a8f621
[pip3] torchbench==0.1
[pip3] torchvision==0.26.0a0+6285457
[conda] numpy 2.4.3 pypi_0 pypi
[conda] onnx2torch 1.5.15 pypi_0 pypi
[conda] optree 0.13.0 pypi_0 pypi
[conda] pytorch-sphinx-theme 0.0.24 pypi_0 pypi
[conda] torch 2.12.0a0+git61ce48e pypi_0 pypi
[conda] torchaudio 2.6.0a0+1a8f621 dev_0 <develop>
[conda] torchbench 0.1 dev_0 <develop>
[conda] torchvision 0.26.0a0+6285457 dev_0 <develop>

cc @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

Fix Plan

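To address the silent correctness issues in reduce operations on the MPS backend, the issue proposes rewriting the affected operations as Metal kernels. The linked PR #178502 covers count_nonzero only, by migrating it from the MPS implementation to the native implementation. Here are the steps: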

  • Identify the affected operations: count_nonzero, mean, nansum, sum, and trace.
  • Create Metal kernels for each operation:
    • count_nonzero: Use a Metal kernel to count non-zero elements in the tensor.
    • mean: Use a Metal kernel to calculate the mean of the tensor.
    • nansum: Use a Metal kernel to calculate the sum of the tensor, ignoring NaN values.
    • sum: Use a Metal kernel to calculate the sum of the tensor.
    • trace: Use a Metal kernel to calculate the trace of the tensor.
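
Whatever form the replacement kernels take, each one has to reproduce the same result as the op it replaces. The following is a minimal pure-Python sketch of those reference semantics on small inputs; the function names are illustrative (`total` stands in for `sum` to avoid shadowing the builtin), and it is meant as a readable oracle for tests, not as the PyTorch implementation.

```python
import math


def count_nonzero(values):
    # NaN compares unequal to zero, so it counts as non-zero here too.
    return sum(1 for v in values if v != 0)


def total(values):
    # Plain sum over all elements; a NaN input propagates to the result.
    return math.fsum(values)


def nansum(values):
    # Sum that treats NaN entries as if they were absent.
    return math.fsum(v for v in values if not math.isnan(v))


def mean(values):
    # Arithmetic mean; this sketch assumes a non-empty input.
    return math.fsum(values) / len(values)


def trace(matrix):
    # Sum of the main diagonal of a 2-D matrix given as a list of rows.
    n = min(len(matrix), len(matrix[0]))
    return math.fsum(matrix[i][i] for i in range(n))
```

Because a reduction's result should not depend on how many times it is called, comparing a kernel against an oracle like this on small inputs complements the repeated-call check in the MRE.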

Example sketch of a count_nonzero Metal kernel, compiled from Python with torch.mps.compile_shader (assumed to be available in the installed PyTorch build). It illustrates the idea only and is not the code from PR #178502:

import torch

# Assumption: torch.mps.compile_shader exists in this PyTorch build.
KERNEL_SRC = """
#include <metal_stdlib>
using namespace metal;

// One thread per element; each non-zero element bumps a shared atomic counter.
kernel void count_nonzero(device const float* x,
                          device atomic_int* count,
                          uint idx [[thread_position_in_grid]]) {
    if (x[idx] != 0.0f) {
        atomic_fetch_add_explicit(count, 1, memory_order_relaxed);
    }
}
"""

# Compile the Metal source into a shader library.
lib = torch.mps.compile_shader(KERNEL_SRC)

# Example usage:
x = torch.rand(16, 2100, 256, device="mps", dtype=torch.float32)
count = torch.zeros(1, dtype=torch.int32, device="mps")
# Assumption: with no explicit thread count, one thread is launched
# per element of the first tensor argument.
lib.count_nonzero(x, count)
# Compare against the built-in op.
print(count.item(), x.count_nonzero().item())

Verification

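To verify that the fix worked, rerun the MRE from the bug report on a build that includes the fix. Each op should report that it passed all max_iters iterations, with no FAIL lines in the output.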
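
For automated checks, the loop from the MRE can be factored into a helper that returns the first deviating iteration instead of printing it. This is a sketch under assumptions: `first_mismatch` and `make_flaky` are hypothetical names, the op is expected to return a Python number (for a tensor op, something like `lambda: x.sum().item()`), and `make_flaky` only simulates an intermittent failure so the helper can be exercised without an MPS device.

```python
import math


def first_mismatch(op, iters, rel_tol=0.0):
    """Return (iteration, expected, actual) for the first deviating call, or None."""
    expected = op()  # the first call is taken as the reference value
    for i in range(iters):
        actual = op()
        # rel_tol=0.0 demands exact equality, as in the MRE; raise it to
        # tolerate rounding differences. A NaN result always counts as a mismatch.
        if not math.isclose(actual, expected, rel_tol=rel_tol, abs_tol=0.0):
            return i, expected, actual
    return None


def make_flaky(fail_at):
    """Simulated op: returns 1.0, except on call number `fail_at` (1-based)."""
    calls = {"n": 0}

    def op():
        calls["n"] += 1
        return 1.5 if calls["n"] == fail_at else 1.0

    return op


demo = first_mismatch(make_flaky(fail_at=5), iters=10)
print(demo)
```

Returning the mismatch rather than printing it lets a test assert that the result is `None` for every affected op.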

Extra Tips

  • Make sure to test the Metal kernels thoroughly to ensure they produce the correct results.
  • Consider adding additional error checking and handling to the Metal kernels to handle edge cases.
  • If you encounter any issues with the Metal kernels, try debugging them using the Metal debugger or by adding print statements to the kernel code.
