pytorch - 💡(How to fix) Fix [MPS] Col2Im.metal silent uint32 overflow corrupts large outputs on MPS [1 pull request]


🐛 Describe the bug

at::col2im on MPS silently produces incorrect results when K * L > 2³², where K = C·KH·KW and L = H_out·W_out are the channel and spatial extents of the input col tensor. The corrupted output positions are those that depend on a col element whose flat col_index is 2³² or larger, because that value no longer fits in the 32-bit uint used to compute it.
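A quick way to tell whether a given call falls in the affected range is to compute K and L from the col2im arguments. The sketch below is not part of the issue or the PR: it assumes a contiguous (N, K, L) input, so the largest flat index within one batch element is K * L - 1, and it uses the standard Fold/Unfold formula for the number of sliding-window positions. The helper names `col2im_extents` and `overflows_uint32` are hypothetical, and the argument order mirrors aten::col2im (output_size, kernel_size, dilation, padding, stride).

```python
# Sketch: compute the col-tensor extents K and L for a 2-D col2im call and check
# whether its largest flat index (K * L - 1, assuming a contiguous input) fits in uint32.
from math import prod

UINT32_MAX = 2**32 - 1


def col2im_extents(channels, output_size, kernel_size, dilation, padding, stride):
    """Return (K, L) for a 2-D col2im with the given (hypothetical helper) parameters."""
    k = channels * prod(kernel_size)
    # Sliding-window positions per spatial dim (standard Fold/Unfold formula).
    blocks = [
        (o + 2 * p - d * (ks - 1) - 1) // s + 1
        for o, ks, d, p, s in zip(output_size, kernel_size, dilation, padding, stride)
    ]
    return k, prod(blocks)


def overflows_uint32(channels, output_size, kernel_size, dilation, padding, stride):
    """True when the largest flat col index cannot be represented as a 32-bit uint."""
    k, l = col2im_extents(channels, output_size, kernel_size, dilation, padding, stride)
    return k * l - 1 > UINT32_MAX


# Parameters from the repro: C=1, image (1, 131073), kernel (1, 65537), dilation 1, padding 0, stride 1.
K, L = col2im_extents(1, (1, 131073), (1, 65537), (1, 1), (0, 0), (1, 1))
print(K, L, K * L - 2**32)  # 65537 65537 131073
print(overflows_uint32(1, (1, 131073), (1, 65537), (1, 1), (0, 0), (1, 1)))  # True
print(overflows_uint32(1, (1, 131073), (1, 32769), (1, 1), (0, 0), (1, 1)))  # False
```

For the repro parameters this gives K = L = 65537 and K * L = 2³² + 131073, just past the 32-bit limit; shrinking the kernel width to 32769 keeps the call within range.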

The overflow happens in this line of Col2Im.metal:

  uint col_index = dim1_idx * col_channel_stride + dim2_idx * col_spatial_stride;
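To see what that line does with the repro's shapes, the expression can be emulated in Python by masking to the low 32 bits, since Metal's `uint` is 32-bit. This is an illustrative sketch, not the kernel itself: it assumes a contiguous input, so `col_channel_stride` equals L and `col_spatial_stride` equals 1, and that `dim1_idx` and `dim2_idx` index the K and L dimensions respectively. The function names are hypothetical.

```python
# Emulates: uint col_index = dim1_idx * col_channel_stride + dim2_idx * col_spatial_stride;
# Assumption: contiguous (1, K, L) input, so the channel stride is L and the spatial stride is 1.
L = 65537
COL_CHANNEL_STRIDE, COL_SPATIAL_STRIDE = L, 1
UINT32_MASK = 0xFFFFFFFF


def col_index_exact(dim1_idx, dim2_idx):
    """Mathematically exact index (Python ints never overflow)."""
    return dim1_idx * COL_CHANNEL_STRIDE + dim2_idx * COL_SPATIAL_STRIDE


def col_index_uint32(dim1_idx, dim2_idx):
    """Same expression reduced modulo 2**32, as a 32-bit Metal `uint` would store it."""
    return col_index_exact(dim1_idx, dim2_idx) & UINT32_MASK


for dim1, dim2 in [(65534, 65536), (65535, 0), (65535, 1), (65536, 0)]:
    exact = col_index_exact(dim1, dim2)
    wrapped = col_index_uint32(dim1, dim2)
    status = "WRAPS" if exact != wrapped else "ok"
    print(f"dim1={dim1:>5} dim2={dim2:>5} exact={exact:>10} uint32={wrapped:>10} {status}")
```

The first two pairs stay at or below 2³² - 1, while `(65535, 1)` wraps to 0 and `(65536, 0)` wraps to 65536, so those reads silently land on an unrelated col element near the start of the buffer instead of failing.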

Opening a PR to switch the relevant variables to ulong (64-bit). This mirrors what the counterpart op, at::im2col, already does.
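One detail worth keeping in mind when moving to ulong: in Metal, as in C++, `uint * uint` is evaluated in 32 bits, so widening only the variable that receives the result would still store a wrapped value. The sketch below emulates both variants with `ctypes` fixed-width integers for the worst-case pair from the repro; it illustrates the pitfall and is not taken from the PR. The helpers `as_uint` and `as_ulong` are hypothetical.

```python
import ctypes

# Worst-case pair from the repro: last kernel offset, first window position.
dim1_idx, dim2_idx = 65536, 0
col_channel_stride, col_spatial_stride = 65537, 1


def as_uint(x):
    """Truncate to 32 bits, like Metal `uint` (ctypes does no overflow checking)."""
    return ctypes.c_uint32(x).value


def as_ulong(x):
    """Truncate to 64 bits, like Metal `ulong`."""
    return ctypes.c_uint64(x).value


# Variant 1: only the destination is ulong; the uint products wrap before widening.
widen_after = as_ulong(
    as_uint(as_uint(dim1_idx * col_channel_stride) + as_uint(dim2_idx * col_spatial_stride))
)

# Variant 2: the operands are ulong, so the whole expression stays in 64 bits.
widen_before = as_ulong(
    as_ulong(dim1_idx * col_channel_stride) + as_ulong(dim2_idx * col_spatial_stride)
)

print(widen_after, widen_before)  # 65536 4295032832
```

`widen_after` comes out as the wrapped value 65536, while `widen_before` is the correct 4295032832, so at least one operand of each product needs the 64-bit type, not just `col_index`.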

Repro script:

import torch

KH, KW = 1, 65537
H_OUT, W_OUT = 1, 65537
H_IN, W_IN = 1, 131073           # so W_OUT = W_IN - KW + 1
K = KH * KW                      # 65537 (with C_in=1)
L = H_OUT * W_OUT                # 65537
# K * L = 65537 * 65537 = 2^32 + 131073  -> uint32 overflow inside col2im

torch.manual_seed(0)
cols = torch.randn(1, K, L)

out_cpu = torch.ops.aten.col2im(cols, [H_IN, W_IN], [KH, KW], [1, 1], [0, 0], [1, 1])

print("Running aten::col2im on MPS", flush=True)
out_mps = torch.ops.aten.col2im(
    cols.to("mps"), [H_IN, W_IN], [KH, KW], [1, 1], [0, 0], [1, 1]
).cpu()
torch.mps.synchronize()

print()
print(f"{'w':>10}  {'cpu':>14}  {'mps':>14}  {'abs_err':>14}")
for i in [0, 32768, 65535, 65536, 65537, 65538, 98304, 131072]:
    c = out_cpu.flatten()[i].item()
    m = out_mps.flatten()[i].item()
    print(f"{i:>10}  {c:>14.4f}  {m:>14.4f}  {abs(c - m):>14.4f}")

left_err = (out_cpu[..., :65536] - out_mps[..., :65536]).abs().max().item()
right_err = (out_cpu[..., 65536:] - out_mps[..., 65536:]).abs().max().item()
print(f"positions [0, 65535]      max abs err = {left_err:.4e}")
print(f"positions [65536, 131072] max abs err = {right_err:.4e}")

Printout

           w             cpu             mps         abs_err
           0         -1.1258         -1.1258          0.0000
       32768        214.7137        214.7132          0.0005
       65535       -256.6318       -256.6317          0.0000
       65536        -33.2780        -34.9529          1.6750
       65537       -300.6619       -302.2877          1.6257
       65538       -298.3085       -297.8763          0.4323
       98304       -224.1576       -222.6533          1.5043
      131072          1.0863          1.5554          0.4691

  positions [0, 65535]      max abs err = 1.0498e-02
  positions [65536, 131072] max abs err = 8.8287e+00
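The boundary at column 65536 in this printout follows directly from the index arithmetic. The sketch below is an illustrative model rather than the MPS kernel: it assumes the 1-D layout of the repro (stride 1, padding 0, dilation 1, so kernel offset kw and window position l land on output column w = l + kw) and a contiguous (1, K, L) col tensor, and it counts how many wrapped reads land on each output column.

```python
# Model which output columns receive at least one col element whose flat index
# kw * L + l is >= 2**32 (assumed contiguous (1, K, L) layout, w = l + kw).
from collections import Counter

K = L = 65537
LIMIT = 2**32  # first flat index a 32-bit `uint` cannot represent

wrapped_reads = Counter()
for kw in range(K):
    # Smallest window position l whose flat index kw * L + l reaches 2**32.
    l_min = max(0, LIMIT - kw * L)
    for l in range(l_min, L):
        wrapped_reads[l + kw] += 1

affected = sorted(wrapped_reads)
print(affected[0], affected[-1], len(affected))  # 65536 131072 65537
print(wrapped_reads[65536], wrapped_reads[98304], wrapped_reads[131072])  # 2 2 1
```

Under these assumptions the affected columns are exactly 65536 through 131072, each receiving two wrapped reads except the last, which receives one. That matches the split between the small error on positions [0, 65535] and the large error on [65536, 131072] above.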

Versions

PyTorch nightly, M2 Max, macOS 26.4

cc @kulinseth @malfet @DenisVieriu97 @aditvenk
