pytorch - ✅(Solved) Fix [MPS] Replace isin MPSGraph implementation with native Metal kernel and binary-search path for large inputs [1 pull requests]


Fix Action

Fix / Workaround

The current MPS implementation of aten::isin.Tensor_Tensor_out is built on MPSGraph, which adds substantial per-call overhead compared to dispatching a Metal kernel directly.

Rewrite the op as a Metal kernel implementing the same O(ne × nt) membership scan. This drops MPSGraph dispatch overhead, which is what dominates runtime for small inputs, without changing the underlying algorithm.

For large inputs the issue also proposes a binary-search fast path (sort test_elements once, then binary-search it per element) running in O((ne + nt) × log nt). Two alternatives to that path were considered and rejected:

  • The concatenate-and-sort approach used by CPU/CUDA, which concatenates elements and test_elements into a single array, sorts the combined array, and identifies matches via adjacent duplicate detection. It runs in O((ne + nt) × log(ne + nt)), which is asymptotically worse than the O((ne + nt) × log nt) binary-search path.
  • Using at::searchsorted in place of a custom binary-search kernel. This is functionally equivalent, but the binary search is only O(log nt) work per element, so the extra dispatch overhead from at::searchsorted and the operations that follow it is disproportionately costly.

PR fix notes

PR #180353: [MPS] Replace isin MPSGraph implementation with native Metal kernel

Description (problem / solution / changelog)

Part 1 of #180833

Summary

Replaces the MPSGraph-based isin implementation with a native Metal kernel.

The previous implementation broadcast-expanded both input tensors into a 2D equality matrix of shape (ne, nt) via MPSGraph, then applied an OR reduction along the test dimension. For large inputs this allocates an intermediate tensor of ne × nt elements, which becomes the primary bottleneck. This PR replaces it with a single Metal kernel in which each thread scans all of test_elements for one element and writes one bool directly to the output, with no intermediate allocation.
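To make the memory cost of the old approach concrete, here is a minimal pure-Python sketch of the equality-matrix strategy described above. It is an illustration only, not the MPSGraph code; the function name `isin_equality_matrix` and the returned cell count are hypothetical additions for this example.

```python
def isin_equality_matrix(elements, test_elements, invert=False):
    """Illustrative model of the broadcast-and-reduce strategy (not the MPSGraph code)."""
    # Materialise the full (ne, nt) equality matrix: one row per element.
    matrix = [[e == t for t in test_elements] for e in elements]
    # OR-reduce along the test dimension to get one bool per element.
    found = [any(row) for row in matrix]
    # The intermediate holds ne * nt cells, which is what grows with input size.
    cells = len(elements) * len(test_elements)
    return [f != invert for f in found], cells


result, cells = isin_equality_matrix([1, 2, 3], [2, 5])
print(result, cells)
```

With 1M elements and a test set of 4096, the same arithmetic gives roughly 4.1 billion intermediate cells, which is why removing the intermediate matters.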

Implementation

The kernel is a templated Metal shader in TensorCompare.metal. Each thread loads one element and scans test_elements in a branch-free |= loop, using uint32_t for the loop index to avoid 64-bit integer overhead on Apple Silicon. The result is written as found != invert to handle the invert flag without a branch.
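The per-thread logic described above can be modelled in a few lines of Python. This is a sketch of the idea only, not the Metal shader from the PR: the names `isin_thread` and `isin_scan` are hypothetical, and Python has no equivalent of the uint32_t index choice.

```python
def isin_thread(element, test_elements, invert):
    """Model of the work one GPU thread does for one element (illustrative only)."""
    found = False
    # No early exit: every test element is compared and OR-ed into the flag,
    # mirroring the branch-free |= loop described for the kernel.
    for i in range(len(test_elements)):
        found |= (element == test_elements[i])
    # found != invert flips the result when invert is True, without an if.
    return found != invert


def isin_scan(elements, test_elements, invert=False):
    # On the GPU each element is handled by its own thread; here we just loop.
    return [isin_thread(e, test_elements, invert) for e in elements]


print(isin_scan([1, 2, 3, 4], [2, 4]))
print(isin_scan([1, 2, 3, 4], [2, 4], invert=True))
```

The `found != invert` expression is a boolean XOR, so one code path serves both the normal and the inverted query.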

Registered for: float, half, bfloat16, int, long, short, char, uchar, bool.

Edge cases handled:

  • Empty elements: early return
  • Empty test_elements: fills output with invert value directly, no kernel dispatch
  • Non-contiguous inputs: copied to contiguous before dispatch, result copied back if out is non-contiguous
  • Unsupported dtypes (e.g. complex, uint16/32/64): guarded with TORCH_CHECK
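The two empty-input cases in the list above follow directly from the definition of membership, as this small Python sketch shows. It is an assumption-laden model, not the code in TensorCompare.mm: the function name is hypothetical, and contiguity and dtype checks are not modelled because plain Python lists have neither.

```python
def isin_with_edge_cases(elements, test_elements, invert=False):
    """Illustrative host-side wrapper; names and structure are hypothetical."""
    # Empty elements: there is nothing to compute, so return early.
    if len(elements) == 0:
        return []
    # Empty test_elements: nothing can match, so found is False everywhere
    # and the output is simply the invert value. No scan is needed.
    if len(test_elements) == 0:
        return [invert] * len(elements)
    # General case: linear membership scan per element.
    return [any(e == t for t in test_elements) != invert for e in elements]


print(isin_with_edge_cases([], [1, 2]))
print(isin_with_edge_cases([1, 2], []))
print(isin_with_edge_cases([1, 2], [], invert=True))
```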

Benchmark

Benchmarked on Apple M5 (24GB), float32, eager mode. Speedups are averaged across test set sizes of 4, 32, 256, and 4096.

Results by input size:

  • 1K elements: 0.9x average (slightly slower than the MPSGraph version)
  • 8K elements: 1.3x average
  • 65K elements: 2.0x average
  • 500K elements: 2.9x average
  • 1M elements: 3.3x average, peaking at 4.5x for large test sets

Results are consistent across float16 and bfloat16.

Tests

All existing isin MPS tests pass. A benchmark entry (bench_isin) has been added to test/bench_mps_ops.py, sweeping elements sizes from 1K to 1M against test set sizes of 4, 32, 256, and 4096.

Changed files

  • aten/src/ATen/native/mps/kernels/TensorCompare.metal (modified, +35/-0)
  • aten/src/ATen/native/mps/operations/TensorCompare.mm (modified, +32/-45)
  • test/bench_mps_ops.py (modified, +67/-0)

Motivation

The current MPS implementation of aten::isin.Tensor_Tensor_out is built on MPSGraph, which adds substantial per-call overhead compared to dispatching a Metal kernel directly.

This issue proposes a two-part rewrite of isin_Tensor_Tensor_out_mps, both targeting performance.

Part 1: Replace the MPSGraph implementation with a Metal kernel

Rewrite the op as a Metal kernel implementing the same O(ne × nt) membership scan. This drops MPSGraph dispatch overhead, which is what dominates runtime for small inputs, without changing the underlying algorithm.

Part 2: Add a binary-search fast path for large inputs

For large inputs the O(ne × nt) scan is the bottleneck. Sorting test_elements once with at::sort and then launching a second kernel that binary-searches it per element drops complexity to O((ne + nt) × log nt), which wins once ne × nt exceeds an empirically-calibrated crossover point.
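A minimal Python sketch of the sort-then-binary-search idea follows. It is not the proposed Metal kernel: `sorted` stands in for at::sort, the standard-library `bisect_left` stands in for the custom per-element search, the function name is hypothetical, and NaN handling is ignored.

```python
from bisect import bisect_left


def isin_binary_search(elements, test_elements, invert=False):
    """Illustrative model of the Part 2 fast path (not the Metal kernel)."""
    # Sort the test set once: O(nt log nt).
    sorted_tests = sorted(test_elements)
    nt = len(sorted_tests)
    out = []
    for e in elements:
        # Binary search per element: O(log nt) each, O(ne log nt) in total.
        i = bisect_left(sorted_tests, e)
        found = i < nt and sorted_tests[i] == e
        out.append(found != invert)
    return out


print(isin_binary_search([5, 1, 9, 3], [3, 1, 4, 1, 5]))
```

For a rough sense of scale, with ne = 1M and nt = 4096 the linear scan performs about 4.1 billion comparisons, while (ne + nt) × log2(nt) is about 12 million steps. Real GPU timings depend on much more than operation counts, which is why the issue calibrates the crossover empirically.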

Two alternatives were considered:

  • The concatenate-and-sort approach used by CPU/CUDA, which concatenates elements and test_elements into a single array, sorts the combined array, and identifies matches via adjacent duplicate detection. It runs in O((ne + nt) × log(ne + nt)), which is asymptotically worse than our O((ne + nt) × log nt) approach.
  • Using at::searchsorted in place of a custom binary-search kernel. This is functionally equivalent, but the binary search is only O(log nt) work per element, so the extra dispatch overhead from at::searchsorted and the operations that follow it is disproportionately costly.
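For comparison, the first rejected alternative can be sketched in Python as below. This is a simplified variant and not the CPU/CUDA implementation: instead of deduplicating the inputs first, it tags each value with its origin so that duplicate elements are not mistaken for matches. All names are hypothetical.

```python
def isin_concat_sort(elements, test_elements, invert=False):
    """Simplified concatenate-and-sort membership test (illustrative only)."""
    # Tuples are (value, origin, index). Origin 0 marks a test element and
    # origin 1 an element, so for equal values the test entries sort first.
    combined = [(t, 0, -1) for t in test_elements]
    combined += [(e, 1, i) for i, e in enumerate(elements)]
    # One sort over ne + nt items: O((ne + nt) log(ne + nt)).
    combined.sort()

    out = [invert] * len(elements)
    have_test = False
    last_test = None
    for value, origin, index in combined:
        if origin == 0:
            # Remember the most recent test value seen in sorted order.
            have_test, last_test = True, value
        else:
            # An element matches if the closest preceding test value equals it.
            out[index] = (have_test and last_test == value) != invert
    return out


print(isin_concat_sort([5, 1, 9, 3, 1], [3, 1, 4, 1, 5]))
```

The sort here covers all ne + nt values, whereas the binary-search path only sorts the nt test values, which is the source of the log(ne + nt) versus log nt difference noted above.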

cc @jerryzh168 @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

TL;DR

Rewrite the aten::isin.Tensor_Tensor_out MPS implementation as a native Metal kernel to remove MPSGraph per-call overhead, and add a binary-search fast path for large inputs to improve performance.

Guidance

  • Replace the existing MPSGraph implementation with a Metal kernel to reduce per-call overhead for small inputs.
  • Implement a binary-search fast path for large inputs by sorting test_elements and launching a second kernel to binary-search it per element, reducing complexity to O((ne + nt) × log nt).
  • Note that the issue rejects both alternatives it considered: the concatenate-and-sort method is asymptotically worse at O((ne + nt) × log(ne + nt)), and at::searchsorted adds dispatch overhead that is disproportionate to the O(log nt) work per element.
  • Calibrate the crossover point for switching between the O(ne × nt) scan and the binary-search fast path to ensure optimal performance for different input sizes.

Example

No code snippet is provided as the issue does not contain sufficient implementation details.

Notes

The proposed solution assumes that the Metal kernel and binary-search fast path can be efficiently implemented and integrated into the existing codebase. The performance benefits of this approach may vary depending on the specific use case and input characteristics.

Recommendation

Apply the proposed two-part rewrite, as it is expected to provide significant performance improvements for both small and large inputs. The Metal kernel will reduce overhead for small inputs, while the binary-search fast path will improve performance for large inputs.
