pytorch - ✅(Solved) Fix [MPS] Add support for multi-dimensional all / any reductions [1 pull request, 4 comments, 2 participants]

GitHub stats
pytorch/pytorch#176826, fetched 2026-04-08 00:24:18
Comments: 4
Participants: 2
Timeline events: 60
Reactions: 0
Timeline (top): mentioned ×21, subscribed ×21, labeled ×7, unsubscribed ×5

Fix Action

Fixed

PR fix notes

PR #176890: [MPS] Add native multi-dimensional reduction support for torch.all and torch.any

Description (problem / solution / changelog)

Fixes #176826

Summary

Adds native MPS support for multi-dimensional reductions in torch.all and torch.any when dim is provided as multiple dimensions.

Previously, this case relied on the composite fallback implementation, which performs sequential reductions across dimensions. This change performs the reduction across multiple axes directly using a single MPSGraph reduction operation.

While the existing implementation may internally use MPS kernels for each reduction step, there is no dedicated implementation for reducing across multiple dimensions in a single operation. This PR adds native support for that case.
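The difference between the two strategies can be sketched without any backend at all. The following is a minimal pure-Python illustration, not the PyTorch or MPSGraph code: the tensor is modelled as a dict from index tuples to booleans, and the function names `all_single_pass` and `all_sequential` are hypothetical.

```python
from itertools import product

def all_single_pass(data, shape, dims):
    """Reduce over every axis in `dims` in one pass; returns {kept_index: bool}."""
    dims = {d % len(shape) for d in dims}  # normalize negative axes
    kept = [a for a in range(len(shape)) if a not in dims]
    out = {}
    for idx in product(*(range(n) for n in shape)):
        key = tuple(idx[a] for a in kept)
        out[key] = out.get(key, True) and data[idx]  # True is the identity for "all"
    return out

def all_sequential(data, shape, dims):
    """Reduce one axis at a time, building an intermediate result per step."""
    current = dict(data)
    # Highest axis first, so the remaining axis numbers stay valid.
    for d in sorted({d % len(shape) for d in dims}, reverse=True):
        nxt = {}
        for idx, value in current.items():
            key = idx[:d] + idx[d + 1:]
            nxt[key] = nxt.get(key, True) and value
        current = nxt
    return current

shape = (2, 3, 4)
data = {idx: True for idx in product(*(range(n) for n in shape))}
data[(1, 2, 3)] = False  # a single False element

single = all_single_pass(data, shape, [0, 1])
stepwise = all_sequential(data, shape, [0, 1])
```

Both functions produce the same mapping; the sequential version simply materializes one intermediate dict per reduced axis, which mirrors the intermediate tensors of the composite path.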

The primary motivation of this change is to improve backend feature parity between MPS and other backends.

Implementation

The implementation introduces native MPS kernels for:

  • all_dims_out_mps
  • any_dims_out_mps

These kernels perform reductions across multiple axes using the corresponding MPSGraph operations.

Edge cases handled:

  • dim=None
  • dim=[]
  • empty tensors
  • keepdim=True and keepdim=False

The implementation falls back to the default behavior for unsupported cases.

Benchmark

Initial benchmarking was performed comparing this implementation against the previous composite reduction path.

In practice, performance differences between the two approaches were minimal, with results generally falling within normal run-to-run variance. In some cases the new implementation showed a slight improvement, but overall performance is largely comparable.

One potential advantage of the new implementation is reduced intermediate tensor creation, since the reduction can be expressed as a single operation rather than a sequence of reductions. In theory this could reduce memory overhead.

Given these observations, the primary benefit of this change is backend feature parity and a more direct mapping of multi-dimensional reductions to the underlying MPSGraph operations.

Tests

Added comprehensive test coverage for multi-dimensional all and any operations:

  • test_all_any_multi_dims: Tests various combinations of multi-dimensional reductions across different tensor shapes (2D, 3D, 4D) with different dimension combinations
  • test_all_any_dims_none_and_empty: Tests edge cases with dim=None and empty dimension lists
  • test_all_any_negative_dims: Tests negative dimension indexing for multi-dimensional reductions
  • test_all_any_duplicate_dims_error: Tests error handling for duplicate dimensions

All tests verify correctness by comparing MPS results against CPU reference implementations.

Notes

This change improves backend feature parity for the MPS backend while providing a native implementation for multi-dimensional logical reductions.

Changed files

  • aten/src/ATen/native/mps/operations/ReduceOps.mm (modified, +159/-0)
  • aten/src/ATen/native/native_functions.yaml (modified, +2/-0)
  • test/bench_mps_ops.py (modified, +27/-0)
  • test/test_mps.py (modified, +139/-2)

Code Example

import torch

x = torch.ones(2, 3, 4, device="mps", dtype=torch.bool)

torch.all(x, dim=[0, 1])  # reduces axes 0 and 1; result shape (4,)
torch.any(x, dim=[1, 2])  # reduces axes 1 and 2; result shape (2,)

The MPS backend does not currently provide a dedicated native implementation for multi-dimensional reductions in torch.all and torch.any when dim is provided as multiple dimensions.

Example:

x = torch.ones(2, 3, 4, device="mps", dtype=torch.bool)

torch.all(x, dim=[0, 1])
torch.any(x, dim=[1, 2])

These operations currently work through the composite implementation, which performs reductions sequentially across dimensions. While each individual reduction may use an existing MPS kernel, there is no native MPS implementation that reduces across multiple dimensions in a single operation.

Adding a native implementation for this case would improve backend feature parity with CPU and CUDA and could allow the reduction to be performed more efficiently on MPS.

I have a working implementation with tests and would be happy to open a PR for this.

cc @jerryzh168 @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

Fix Plan

Native Implementation for Multi-Dimensional Reductions

Step 1: Add the native MPS kernels

  • Add all_dims_out_mps and any_dims_out_mps to aten/src/ATen/native/mps/operations/ReduceOps.mm, the file the PR modifies. No new kernel file or directory is needed.
  • Each kernel reduces the input tensor over the given dimensions; the two differ only in the operation applied (all or any).

Step 2: Implement the reduction

The PR expresses the reduction as a single MPSGraph operation over all requested axes. The loop below is not that implementation. It is a plain C++ reference for the value each output element must hold, assuming the reduced dimensions have been flattened into contiguous groups of reduce_size elements.

// Illustrative reference only; the function name and layout are assumptions.
// op_type: 0 = all (logical AND), 1 = any (logical OR).
void multi_dim_reduction_reference(
    const bool* input,
    int num_outputs,
    int reduce_size,
    bool* result,
    int op_type
) {
    for (int out = 0; out < num_outputs; ++out) {
        // Start from the identity: true for all, false for any.
        bool acc = (op_type == 0);
        for (int i = 0; i < reduce_size; ++i) {
            bool value = input[out * reduce_size + i];
            acc = (op_type == 0) ? (acc && value) : (acc || value);
        }
        result[out] = acc;
    }
}

Step 3: Register the kernels with PyTorch

  • Register both kernels for the MPS dispatch key in aten/src/ATen/native/native_functions.yaml (the PR changes two lines in that file).
  • The exact entries are not shown on this page; they are assumed to look like the following, placed in the dispatch sections of the all.dims_out and any.dims_out declarations.

    MPS: all_dims_out_mps
    MPS: any_dims_out_mps

Step 4: Update the
