pytorch - ✅(Solved) Fix [MPS] Remove Unnecessary Tensor Memory Gathers and Contiguous Calls for Stride-Aware MPS Backend Inputs [2 pull requests, 1 comments, 1 participants]

GitHub stats
pytorch/pytorch#181946 · Fetched 2026-04-30 06:17:41
Comments: 1 · Participants: 1 · Timeline: 34 · Reactions: 0
Timeline (top): mentioned ×12, subscribed ×12, labeled ×6, cross-referenced ×2

Fix Action

Fix / Workaround

When the MPS backend was first implemented, Metal kernels and MPSGraph dispatch were written assuming contiguous memory layout. To handle non-contiguous inputs correctly without rewriting every kernel, many ops were patched with a gatherViewTensor() or .contiguous() call at the entry point as a band-aid fix to ensure correctness at the cost of performance.

PR fix notes

PR #181949: [MPS] Enhance Col2Im tensor op to avoid .contiguous() call and make corresponding Metal kernel stride-aware

Description (problem / solution / changelog)

Col2Im Tensor OP Enhancement

  • Removing the .contiguous() call on col_tensor in col2im_out_mps_template() operation
  • Updating col2im_kernel() kernel to take in inner strides and manually calculate strided col_index
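
The strided index calculation described above can be sketched in plain Python. This is a minimal illustration, not the Metal kernel from the PR: the function names, the argument order, and the choice of two inner strides (`stride_c` for the channel-times-kernel dimension, `stride_l` for the flattened block dimension) are assumptions made for the example.

```python
def col_index_contiguous(c_col, h_col, w_col, height_col, width_col):
    """Flat index into one batch item, assuming a contiguous (C*kh*kw, L) layout."""
    return (c_col * height_col + h_col) * width_col + w_col


def col_index_strided(c_col, h_col, w_col, width_col, stride_c, stride_l):
    """Same logical element, located through explicit strides (in elements)."""
    l = h_col * width_col + w_col  # position along the flattened L dimension
    return c_col * stride_c + l * stride_l


# Per-item logical shape in the benchmark below is (27, 16384), with L = 128 * 128.
height_col = width_col = 128
L = height_col * width_col

# Contiguous layout has inner strides (L, 1), so both formulas agree.
contiguous = col_index_strided(5, 7, 9, width_col, stride_c=L, stride_l=1)
assert contiguous == col_index_contiguous(5, 7, 9, height_col, width_col)

# After transpose(1, 2) the inner strides become (1, 27); no copy is needed
# as long as the kernel uses them when computing the index.
transposed = col_index_strided(5, 7, 9, width_col, stride_c=1, stride_l=27)
print(contiguous, transposed)
```

With the strides passed in, the same kernel logic serves both layouts, which is what makes the `.contiguous()` copy in the caller unnecessary.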

Benchmarked using the following script:

import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

# col2im_out_mps_template() calls input.contiguous() before PR #181949.
# Transpose dims 1 and 2 to make the input non-contiguous and force the copy.
x = torch.randn(32, 16384, 27, device="mps").transpose(1, 2)
assert not x.is_contiguous()

m = benchmark.Timer(
    stmt="F.fold(x, output_size=(128, 128), kernel_size=3, padding=1, stride=1)",
    globals={"F": F, "x": x},
    num_threads=1,
).blocked_autorange(min_run_time=10.0)

print(f"col2im: mean={m.mean * 1e6:.2f} us  median={m.median * 1e6:.2f} us  iqr={m.iqr * 1e6:.2f} us")

On an M3 MacBook Pro with 48 GB of memory we observe a 3.1x speedup.

Baseline (µs)   Update (µs)   Speedup
802.07          258.65        3.1x
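
The shapes in the benchmark script can be checked by hand. The sketch below uses the block-count formula from the `torch.nn.Fold` documentation and a small helper for transposed strides; both helper names are hypothetical and written in plain Python so the arithmetic can be followed without an MPS device.

```python
def fold_num_blocks(output_size, kernel_size, padding, stride, dilation=1):
    """Number of sliding blocks L that F.fold expects for the given geometry."""
    blocks = 1
    for size in output_size:
        blocks *= (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return blocks


def transposed_strides(shape, dim0, dim1):
    """Shape and element strides of a contiguous tensor after swapping two dims."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    shape = list(shape)
    shape[dim0], shape[dim1] = shape[dim1], shape[dim0]
    strides[dim0], strides[dim1] = strides[dim1], strides[dim0]
    return tuple(shape), tuple(strides)


L = fold_num_blocks((128, 128), kernel_size=3, padding=1, stride=1)
shape, strides = transposed_strides((32, 16384, 27), 1, 2)
print(L, shape, strides)
```

The input therefore has the `(N, C * 3 * 3, L)` layout that `F.fold` expects, with `C = 3` and `L = 16384`, but its innermost stride is 27 rather than 1, which is why the pre-PR code path had to copy it.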

Issue #181946

Changed files

  • aten/src/ATen/native/mps/kernels/Col2Im.metal (modified, +7/-4)
  • aten/src/ATen/native/mps/operations/Col2Im.mm (modified, +6/-2)

PR #181951: [MPS] Enhance HistogramKernel tensor op to avoid unnecessary .contiguous() call

Description (problem / solution / changelog)

  • Removing the .contiguous() call for bin_edges and the is_contiguous() assertion in histogramdd_out_mps_template()
  • The tensor is never passed into a kernel; it is only consumed on the CPU through .item(), which already supports non-contiguous tensors.
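
Why element-wise reads do not need a contiguous copy can be illustrated with a plain-Python model of a strided 1-D view. This is only a sketch of the idea: `linspace` and `strided_item` are hypothetical stand-ins, not the ATen implementation of `.item()`.

```python
def linspace(start, end, steps):
    """Plain-Python stand-in for torch.linspace, returning a flat list."""
    step = (end - start) / (steps - 1)
    return [start + i * step for i in range(steps)]


def strided_item(buffer, index, stride, storage_offset=0):
    """Read one logical element of a 1-D view without copying the buffer."""
    return buffer[storage_offset + index * stride]


bins_buf = linspace(-4.0, 4.0, 514)
num_edges = len(range(0, 514, 2))  # length of bins_buf[::2]

# Each edge is fetched individually through the stride, one read per element.
edges = [strided_item(bins_buf, i, stride=2) for i in range(num_edges)]
print(num_edges, edges[0])
```

Because every edge is located by `index * stride`, the reads touch the original buffer directly and the result matches what a contiguous copy would have held.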

Benchmarked using following script:

import torch
import torch.utils.benchmark as benchmark

device = "mps"

# Take every second element of a 514-point linspace:
# 257 edges → 256 bins, stride=2, non-contiguous.
bins_buf = torch.linspace(-4.0, 4.0, 514, device=device)
bins = bins_buf[::2]
assert not bins.is_contiguous()
x = torch.randn(1_000_000, device=device)

m = benchmark.Timer(
    stmt="torch.histogram(x, bins=bins)",
    globals={"torch": torch, "x": x, "bins": bins},
    num_threads=1,
).blocked_autorange(min_run_time=10.0)

print(f"histogramkernel: mean={m.mean * 1e6:.2f} us  median={m.median * 1e6:.2f} us  iqr={m.iqr * 1e6:.2f} us")

On an M3 MacBook Pro with 48 GB of memory we observe a 1.08x speedup.

Baseline (µs)   Update (µs)   Speedup
65242.76        60402.98      1.08x
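
The speedup column in both PR tables appears to be the baseline time divided by the updated time. The short check below recomputes it from the reported numbers; the `speedup` helper and the dictionary layout are illustrative only.

```python
def speedup(baseline_us, update_us):
    """Ratio of baseline time to updated time; values above 1 mean faster."""
    return baseline_us / update_us


# (baseline, update) in microseconds, as reported in the two PR descriptions.
reported = {
    "col2im": (802.07, 258.65),
    "histogramkernel": (65242.76, 60402.98),
}

for name, (baseline_us, update_us) in reported.items():
    saved_us = baseline_us - update_us
    print(f"{name}: {speedup(baseline_us, update_us):.2f}x, {saved_us:.2f} us saved per call")
```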

Issue #181946

Changed files

  • aten/src/ATen/native/mps/operations/HistogramKernel.mm (modified, +1/-7)

🚀 The feature, motivation and pitch

When the MPS backend was first implemented, Metal kernels and MPSGraph dispatch were written assuming contiguous memory layout. To handle non-contiguous inputs correctly without rewriting every kernel, many ops were patched with a gatherViewTensor() or .contiguous() call at the entry point as a band-aid fix to ensure correctness at the cost of performance.

Since then MPSGraph has been updated to support native strided tensors via useMPSStridedAPI (macOS 15+, with automatic gatherViewTensor fallback on older OS through the Placeholder constructor).

In addition, custom Metal shaders in the MPS backend that still rely on contiguous input tensors could be made stride-aware: they would receive explicit sizes and strides buffers and compute physical offsets manually. Each case should be examined to see whether the added kernel complexity costs less than the .contiguous() call in the caller.

The following Tensor Ops in the MPS backend use the band-aid .contiguous() fix and should be examined and/or updated:

Tensor Op         .contiguous() Call                            Metal Shader
Attention         _scaled_dot_product_attention_math_mps()      sdpa_vector_fast_mps()
Col2Im            col2im_out_mps_template()                     col2im_kernel()
HistogramKernel   histogramdd_out_mps_template()                N/A
Im2Col            im2col_out_mps_template()                     N/A
Indexing          nonzero_impl_mps()                            count_nonzero_prefix_sum()
Linear            _mps_linear_backward_input()                  N/A
LossOps           huber_loss_backward_out_mps()                 N/A
Pooling           mps_max_pool2d_backward()                     N/A
RangeFactories    linspace_out_mps()                            N/A

Alternatives

No response

Additional context

No response

cc @jerryzh168 @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

TL;DR

Update the listed Tensor Ops in the MPS backend to use native strided tensors via useMPSStridedAPI or make custom Metal shaders stride-aware.

Guidance

  • Examine each Tensor Op using the .contiguous() fix and consider updating them to support native strided tensors.
  • For custom Metal shaders, update them to receive explicit sizes and strides buffers and compute physical offsets manually.
  • Evaluate the performance overhead of the added kernel complexity versus the .contiguous() calls.
  • Prioritize updating ops with high performance impact, such as Attention and Col2Im.

Example

No explicit code example is provided, but updating a custom Metal shader to be stride-aware might involve modifying the kernel function to accept sizes and strides buffers and calculate physical offsets.
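
As a rough illustration of that idea, the following plain-Python emulation maps each "thread" index to a physical offset through sizes and strides. It is a sketch under stated assumptions, not code from the MPS backend: `physical_offset` and `strided_copy_kernel` are hypothetical names, and a real Metal kernel would receive the sizes and strides as buffers.

```python
def physical_offset(linear_index, sizes, strides, storage_offset=0):
    """Map a thread's linear index to an element offset in the underlying buffer."""
    offset = storage_offset
    # Peel off coordinates from the innermost dimension outwards.
    for size, stride in zip(reversed(sizes), reversed(strides)):
        linear_index, coord = divmod(linear_index, size)
        offset += coord * stride
    return offset


def strided_copy_kernel(buffer, sizes, strides):
    """Emulate one thread per output element, each reading through sizes/strides."""
    numel = 1
    for size in sizes:
        numel *= size
    return [buffer[physical_offset(i, sizes, strides)] for i in range(numel)]


# A 2x3 row-major buffer viewed as its 3x2 transpose: sizes (3, 2), strides (1, 3).
buffer = [0, 1, 2, 3, 4, 5]
result = strided_copy_kernel(buffer, sizes=(3, 2), strides=(1, 3))
print(result)
```

The extra `divmod` work per element is the "added kernel complexity" the issue refers to; whether it costs less than a `.contiguous()` copy has to be measured per op.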

Notes

The feasibility of updating each Tensor Op may vary, and some ops might require more significant changes than others. Additionally, the performance benefits of using native strided tensors will depend on the specific use case and hardware.

Recommendation

Apply workaround: Update the listed Tensor Ops to use native strided tensors or make custom Metal shaders stride-aware, as this approach allows for more efficient handling of non-contiguous inputs without requiring significant rewrites of existing kernels.
