pytorch - ✅(Solved) Fix MPS autograd returns zero/corrupted batched input gradients for `torch.autograd.grad(y.sum(), x)` while `grad_outputs=torch.ones_like(y)` is correct [1 pull requests, 5 comments, 2 participants]

GitHub stats
pytorch/pytorch#180201 · Fetched 2026-04-15 06:19:36
Comments: 5 · Participants: 2 · Timeline events: 107 · Reactions: 0
Timeline (top): mentioned ×40, subscribed ×40, referenced ×10, labeled ×6

Root Cause

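The reporter found this when a neural SDF training pipeline became unstable on MPS on this machine, while CUDA and another Apple MPS machine behaved correctly. After reduction, the failure reproduces with raw tensor ops and no project-specific code. The linked PR #180236 traces it to MPSGraph's matrixMultiplication producing incorrect results for stride-0 inputs (from tensor.expand()) on macOS >= 15.0, < 26.4.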

Fix Action

Workaround

This avoids the issue for me:

torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))

instead of:

torch.autograd.grad(y.sum(), x)

The same workaround also fixes higher-order derivatives in my downstream code.
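
A minimal sketch of wrapping the workaround so call sites do not slip back to the reduction form. The helper name input_grad and its create_graph parameter are illustrative assumptions, not part of PyTorch or the original report. It runs on CPU here so it can be tried on any machine.

```python
import torch


def input_grad(y, x, create_graph=False):
    """Gradient of y.sum() w.r.t. x, written without the scalar reduction.

    Hypothetical helper: passes an explicit ones tensor as grad_outputs
    instead of differentiating y.sum().
    """
    (grad,) = torch.autograd.grad(
        y, x, grad_outputs=torch.ones_like(y), create_graph=create_graph
    )
    return grad


torch.manual_seed(0)
x = torch.randn(64, 3, requires_grad=True)
w1, b1 = torch.randn(16, 3), torch.randn(16)
w2, b2 = torch.randn(1, 16), torch.randn(1)
y = (torch.sin(x @ w1.T + b1) @ w2.T + b2).squeeze(-1)

# create_graph=True keeps the graph so a second derivative can be taken.
g = input_grad(y, x, create_graph=True)
# Second order: differentiate the first gradient with the same helper.
gg = input_grad(g, x)
print(g.shape, gg.shape)
```

Using one helper for both first and higher-order derivatives keeps the grad_outputs form consistent across a codebase, which matches the reporter's note that the workaround also fixed higher-order derivatives downstream.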

PR fix notes

PR #180236: [MPS] Fix mm with stride-0 inputs on macOS < 26.4

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #180236

MPSGraph's matrixMultiplication produces incorrect results when given stride-0 NDArray inputs (from tensor.expand()) on macOS >=15.0,<26.4. Only every 16th row of the output is computed correctly; the rest are zeroed.
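
To make "stride-0 inputs" concrete, the sketch below shows how tensor.expand() produces one and how .contiguous() materializes it. It runs on CPU, where both layouts give the same matmul result, so it illustrates the layout only and does not reproduce the MPS bug. It may also explain why the y.sum() form is affected while torch.ones_like(y) is not (a reduction's backward can hand a broadcast gradient to later matmuls), but that link is an inference and is not stated in the PR text.

```python
import torch

# expand() broadcasts without copying: the expanded dimension gets stride 0.
expanded = torch.ones(1).expand(64)
# contiguous() materializes a real 64-element buffer with stride 1.
materialized = expanded.contiguous()
print(expanded.stride(), materialized.stride())  # (0,) (1,)

# The same layout as a matmul operand: 64 rows that all alias one row.
a = torch.ones(1, 1).expand(64, 1)
w = torch.randn(1, 16)
out_view = a @ w               # stride-0 operand
out_copy = a.contiguous() @ w  # materialized operand
print(torch.allclose(out_view, out_copy))
```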

The fix disables the MPS strided API specifically for mm inputs that have stride 0, so they are materialized contiguously via the gather/clone path before being passed to MPSGraph. The check is skipped on macOS >= 26.4, where the underlying issue has apparently been fixed.
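
For readers who want to know whether their own machine falls in the range the PR names, here is a rough Python-side check. The function names parse_version and in_affected_range are hypothetical, and this is not how the PR itself gates the fix (that happens inside the MPS backend); it only compares the OS version string against macOS >= 15.0, < 26.4.

```python
import platform


def parse_version(text):
    """Turn '26.3' or '15.0.1' into a tuple of ints; '' gives ()."""
    parts = [int(p) for p in text.split(".") if p.isdigit()]
    if not parts:
        return ()
    # Pad '15' to (15, 0) so tuple comparison behaves as expected.
    return tuple(parts + [0] * (2 - len(parts)))


def in_affected_range(macos_version):
    """True for macOS >= 15.0 and < 26.4, the range named in the PR."""
    version = parse_version(macos_version)
    return bool(version) and (15, 0) <= version < (26, 4)


# platform.mac_ver() returns an empty release string on non-macOS systems.
release = platform.mac_ver()[0]
print(release or "not macOS", in_affected_range(release))
```

The reporter's macOS 26.3 falls inside the range, which is consistent with the issue reproducing there.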

Fixes https://github.com/pytorch/pytorch/issues/180201

Co-authored-by: Claude [email protected]

Changed files

  • aten/src/ATen/mps/MPSDevice.h (modified, +1/-0)
  • aten/src/ATen/mps/MPSDevice.mm (modified, +3/-0)
  • aten/src/ATen/native/mps/OperationUtils.mm (modified, +0/-2)
  • aten/src/ATen/native/mps/operations/LinearAlgebra.mm (modified, +12/-4)
  • test/test_mps.py (modified, +10/-0)


🐛 Describe the bug

I found what looks like an MPS autograd bug in PyTorch 2.11.0 on macOS 26.3 running on an Apple M2 Ultra.

For a small standalone SIREN-like network with batched inputs:

y = (torch.sin(x @ w1.T + b1) @ w2.T + b2).squeeze(-1)
g = torch.autograd.grad(y.sum(), x)[0]

the forward pass is fine, but the MPS input gradient is wrong. Most rows become exactly zero. On the same machine:

  • CPU gives the expected nonzero gradient for every batch row.
  • MPS becomes correct if I replace the reduction form with:
g = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))[0]

This was originally discovered because a neural SDF training pipeline on this machine became unstable on MPS while CUDA and another Apple MPS machine behaved correctly. After reducing it, the failure reproduces with raw tensor ops and no project-specific code.

Minimal Reproduction

import json
import torch


def zero_rows(grad):
    return torch.where(grad.abs().sum(dim=1) == 0)[0].detach().cpu().tolist()


def build_case():
    dtype = torch.float32
    batch = 64
    hidden = 16

    torch.manual_seed(123)
    w1 = torch.randn(hidden, 3, dtype=dtype)
    b1 = torch.randn(hidden, dtype=dtype)
    w2 = torch.randn(1, hidden, dtype=dtype)
    b2 = torch.randn(1, dtype=dtype)

    torch.manual_seed(0)
    x = torch.randn(batch, 3, dtype=dtype)
    return x, w1, b1, w2, b2


def gradient(device, use_grad_outputs, x_cpu, w1_cpu, b1_cpu, w2_cpu, b2_cpu):
    x = x_cpu.to(device).requires_grad_(True)
    w1 = w1_cpu.to(device)
    b1 = b1_cpu.to(device)
    w2 = w2_cpu.to(device)
    b2 = b2_cpu.to(device)
    y = (torch.sin(x @ w1.T + b1) @ w2.T + b2).squeeze(-1)

    if use_grad_outputs:
        grad = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))[0]
    else:
        grad = torch.autograd.grad(y.sum(), x)[0]
    return grad.detach().cpu()


assert torch.backends.mps.is_available()

x_cpu, w1_cpu, b1_cpu, w2_cpu, b2_cpu = build_case()

cpu_sum = gradient("cpu", False, x_cpu, w1_cpu, b1_cpu, w2_cpu, b2_cpu)
cpu_fix = gradient("cpu", True, x_cpu, w1_cpu, b1_cpu, w2_cpu, b2_cpu)
mps_sum = gradient("mps", False, x_cpu, w1_cpu, b1_cpu, w2_cpu, b2_cpu)
mps_fix = gradient("mps", True, x_cpu, w1_cpu, b1_cpu, w2_cpu, b2_cpu)

print(json.dumps({
    "torch_version": torch.__version__,
    "torch_git_version": getattr(torch.version, "git_version", None),
    "cpu_sum_zero_rows": zero_rows(cpu_sum),
    "cpu_fix_zero_rows": zero_rows(cpu_fix),
    "mps_sum_zero_rows": zero_rows(mps_sum),
    "mps_fix_zero_rows": zero_rows(mps_fix),
    "cpu_sum_vs_cpu_fix_mae": float((cpu_sum - cpu_fix).abs().mean()),
    "cpu_sum_vs_mps_sum_mae": float((cpu_sum - mps_sum).abs().mean()),
    "cpu_sum_vs_mps_fix_mae": float((cpu_sum - mps_fix).abs().mean()),
    "cpu_sum_vs_mps_sum_max_abs": float((cpu_sum - mps_sum).abs().max()),
    "cpu_sum_vs_mps_fix_max_abs": float((cpu_sum - mps_fix).abs().max()),
}, indent=2))

Actual Behavior

On this machine, torch.autograd.grad(y.sum(), x) on MPS produces zero gradients for most batch rows. The pattern is structured: only every 16th row survives.

Observed output:

{
  "torch_version": "2.11.0",
  "torch_git_version": "70d99e998b4955e0049d13a98d77ae1b14db1f45",
  "cpu_sum_zero_rows": [],
  "cpu_fix_zero_rows": [],
  "mps_sum_zero_rows": [
    1, 2, 3, 4, 5, 6, 7, 8,
    9, 10, 11, 12, 13, 14, 15,
    17, 18, 19, 20, 21, 22, 23, 24,
    25, 26, 27, 28, 29, 30, 31,
    33, 34, 35, 36, 37, 38, 39, 40,
    41, 42, 43, 44, 45, 46, 47,
    49, 50, 51, 52, 53, 54, 55, 56,
    57, 58, 59, 60, 61, 62, 63
  ],
  "mps_fix_zero_rows": [],
  "cpu_sum_vs_cpu_fix_mae": 0.0,
  "cpu_sum_vs_mps_sum_mae": 1.7952971458435059,
  "cpu_sum_vs_mps_fix_mae": 1.5258167707088433e-07,
  "cpu_sum_vs_mps_sum_max_abs": 7.265512943267822,
  "cpu_sum_vs_mps_fix_max_abs": 9.5367431640625e-07
}
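
The "every 16th row" structure can be checked directly from the reported list. The short sketch below transcribes mps_sum_zero_rows from the output above as ranges and derives the surviving rows; the variable names are illustrative.

```python
batch = 64

# mps_sum_zero_rows as reported: 1-15, 17-31, 33-47, 49-63.
reported_zero_rows = (
    list(range(1, 16))
    + list(range(17, 32))
    + list(range(33, 48))
    + list(range(49, 64))
)

zero = set(reported_zero_rows)
surviving_rows = [row for row in range(batch) if row not in zero]
print(surviving_rows)  # [0, 16, 32, 48]

# Every surviving row index is a multiple of 16.
print(all(row % 16 == 0 for row in surviving_rows))  # True
```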

Expected Behavior

torch.autograd.grad(y.sum(), x) and torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) should be equivalent here, and MPS should agree with CPU up to normal float32 tolerance.

Instead:

  • CPU behaves correctly for both forms.
  • MPS is wrong for the .sum() form.
  • MPS is correct for the grad_outputs=torch.ones_like(y) form.

Why I Think This Is an MPS Bug

  • The repro uses only raw tensor ops and torch.autograd.grad.
  • Forward outputs are fine.
  • CPU float32 is fine.
  • MPS is fixed by changing only the reduction form passed into autograd.grad.

That suggests the issue is in the MPS backward path for the scalarized reduction case rather than in the forward ops themselves.

Workaround

This avoids the issue for me:

torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))

instead of:

torch.autograd.grad(y.sum(), x)

The same workaround also fixes higher-order derivatives in my downstream code.

Additional Context

I first noticed this through a neural signed-distance-function training workload that depends heavily on input gradients and Hessians. On this machine, the broken autograd.grad(y.sum(), x) path caused unstable training and incorrect geometry derivatives. After reducing the problem, the bug reproduces without any project code.

Versions

PyTorch version: 2.11.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.3 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.6.3.2)
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.11 (main, Dec 11 2024, 10:25:04) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-26.3-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M2 Ultra

Versions of relevant libraries:
[pip3] numpy==2.4.4
[pip3] torch==2.11.0
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.26.0
[conda] numpy                     2.2.6                    pypi_0    pypi
[conda] numpydoc                  1.7.0           py312hca03da5_0
[conda] tbb                       2021.8.0             h48ca7d4_0

cc @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

TL;DR

The quickest workaround for this MPS autograd bug is to use torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) instead of torch.autograd.grad(y.sum(), x). The upstream fix is PR #180236, which materializes stride-0 mm inputs on affected macOS versions.

Guidance

  • The issue seems to be specific to the MPS backend on macOS, so using the CPU backend or a different reduction form may work around the problem.
  • The provided workaround torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) fixes the issue for the reporter, and it may be a viable solution for others experiencing similar problems.
  • To verify the fix, compare the gradients computed using the original and workaround methods to ensure they are equivalent within a reasonable tolerance.
  • If the issue persists, try a PyTorch build that includes PR #180236, or macOS 26.4 or later, where the PR notes the underlying issue has apparently been fixed.
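
One way to carry out the verification step above is to compare the reduction-form gradient on the target device against CPU. This is a sketch under stated assumptions: the helper names are hypothetical, the network mirrors the shapes in the report but uses its own seed, and the 1e-5 threshold is an assumed float32 tolerance (the report shows about 1e-6 for a correct MPS result). On a machine without MPS it falls back to comparing CPU with itself.

```python
import torch


def sum_form_grad(device):
    """Input gradient of the y.sum() form on `device`, returned on CPU."""
    torch.manual_seed(0)
    x_cpu = torch.randn(64, 3)
    w1, b1 = torch.randn(16, 3), torch.randn(16)
    w2, b2 = torch.randn(1, 16), torch.randn(1)
    x = x_cpu.to(device).requires_grad_(True)
    hidden = torch.sin(x @ w1.to(device).T + b1.to(device))
    y = (hidden @ w2.to(device).T + b2.to(device)).squeeze(-1)
    return torch.autograd.grad(y.sum(), x)[0].cpu()


def max_abs_diff_vs_cpu(device):
    """Largest elementwise gap between the CPU and `device` gradients."""
    return float((sum_form_grad("cpu") - sum_form_grad(device)).abs().max())


device = "mps" if torch.backends.mps.is_available() else "cpu"
diff = max_abs_diff_vs_cpu(device)
print(device, diff, "OK" if diff < 1e-5 else "MISMATCH")
```

For scale, the report shows a max absolute difference of about 7.27 for the broken MPS reduction form on its own inputs, so an affected build should be easy to tell apart from a healthy one.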

Example

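y = (torch.sin(x @ w1.T + b1) @ w2.T + b2).squeeze(-1)
# retain_graph=True keeps the graph alive for the second grad call below.
g_original = torch.autograd.grad(y.sum(), x, retain_graph=True)[0]
g_workaround = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))[0]
# True on CPU or a fixed build; False on an affected MPS build, where g_original has zeroed rows.
print(torch.allclose(g_original, g_workaround, atol=1e-5))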

Notes

  • Per PR #180236, the root cause is MPSGraph's matrixMultiplication returning incorrect results for stride-0 inputs (from tensor.expand()) on macOS >= 15.0, < 26.4: only every 16th output row is computed correctly.
  • The workaround's performance impact was not measured in the report, so test it in your own use case.

Recommendation

Apply the workaround torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) to avoid the MPS autograd bug, as it has been shown to fix the issue for the reporter and may be a viable solution for others.
