pytorch - 💡 (How to fix) [MPS] Non-deterministic backward pass for F.linear [1 participant]

GitHub stats
pytorch/pytorch#181936 · Fetched 2026-04-30 06:17:45
Comments: 0 · Participants: 1 · Timeline: 129 · Reactions: 0
Timeline (top): mentioned ×61, subscribed ×61, labeled ×7


🐛 Describe the bug

The backward pass for F.linear(x, w) produces different x.grad across consecutive calls.

This is a follow-up to #180776, which was determined to be specific to the M5. Testing the latest nightly build containing @malfet's fix from #181466, I noticed that the forward pass seems fine, but I am now seeing the same non-deterministic behavior for the backward pass. My assumption is that this occurs under the same conditions as the original issue (BFloat16/Float16, >2D input, M5).

The fix from #181466 should probably be applied to _mps_linear_backward_input to resolve this.

MRE

import torch
import torch.nn.functional as F

x = torch.randn(2, 13, 1024, device="mps", dtype=torch.bfloat16)
w = torch.randn(1024, 1024, device="mps", dtype=torch.bfloat16)
grad = torch.randn(2, 13, 1024, device="mps", dtype=torch.bfloat16)

def backward_test():
    x0, w0 = x.clone().requires_grad_(), w.clone()
    y = F.linear(x0, w0)
    y.backward(grad)
    return x0.grad.clone()

first = backward_test()
second = backward_test()
print("Diff between runs:", (second - first).abs().max().item())

Output: Diff between runs: 130.0
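The run-to-run difference above shows that at least one gradient is wrong, but not which one. A minimal sketch for checking each run against a reference follows, relying on the identity that for y = F.linear(x, w) the input gradient is grad @ w. The helper names (linear_input_grad, reference_input_grad, max_error_vs_reference) are illustrative and do not come from the issue. A small nonzero error is expected for BFloat16 because the device result is rounded; a run whose error is far larger than the others is the suspect one.

```python
import torch
import torch.nn.functional as F


def linear_input_grad(x, w, grad):
    """Return d(loss)/dx for y = F.linear(x, w) as computed by autograd."""
    x0 = x.clone().requires_grad_()
    F.linear(x0, w).backward(grad)
    return x0.grad.clone()


def reference_input_grad(w, grad):
    """Closed-form d(loss)/dx = grad @ w, evaluated on the CPU in float32."""
    return grad.float().cpu() @ w.float().cpu()


def max_error_vs_reference(x, w, grad):
    """Largest absolute deviation of the device gradient from the CPU reference."""
    got = linear_input_grad(x, w, grad).float().cpu()
    return (got - reference_input_grad(w, grad)).abs().max().item()


if torch.backends.mps.is_available():
    # Same shapes and dtype as the MRE in the report.
    x = torch.randn(2, 13, 1024, device="mps", dtype=torch.bfloat16)
    w = torch.randn(1024, 1024, device="mps", dtype=torch.bfloat16)
    grad = torch.randn(2, 13, 1024, device="mps", dtype=torch.bfloat16)
    for run in range(2):
        print(f"run {run}: max error vs CPU reference:",
              max_error_vs_reference(x, w, grad))
else:
    print("MPS is not available; skipping the device check.")
```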

Versions

PyTorch version: 2.13.0.dev20260429
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.4.1 (arm64)
GCC version: Could not collect
Clang version: 21.0.0 (clang-2100.0.123.102)
CMake version: Could not collect
Libc version: N/A

Python version: 3.14.3 (v3.14.3:323c59a5e34, Feb 3 2026, 11:41:37) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-26.4.1-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU: Apple M5

Versions of relevant libraries:
[pip3] numpy==2.4.4
[pip3] torch==2.13.0.dev20260429
[pip3] torchvision==0.27.0.dev20260429
[conda] Could not collect

cc @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @mruberry @kurtamohler @snadampal @milpuz01 @aditew01 @nikhil-arm @fadara01 @robert-hardwick @nWEIdia @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

TL;DR

Apply the fix from #181466 to _mps_linear_backward_input to resolve the non-deterministic behavior in the backward pass for F.linear(x, w).

Guidance

  • The likely cause is that _mps_linear_backward_input did not receive the fix that #181466 applied to the forward pass; the symptoms match the original issue #180776.
  • To verify the fix, run the provided MRE and check that the difference between consecutive runs is 0.0 (the affected build reports 130.0). Identical inputs should produce identical gradients.
  • Apply the fix from #181466 to _mps_linear_backward_input to resolve the issue.
  • Test the fix with different input sizes and data types to ensure the non-deterministic behavior is resolved.
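The last bullet can be sketched as a small sweep over shapes and dtypes. This is an illustration only: the helper max_backward_diff, the chosen shapes, and the features and runs parameters are assumptions layered on the MRE, not part of the issue or of PyTorch's test suite. The 2D shape is included because the reporter assumes the problem needs an input with more than two dimensions.

```python
import itertools

import torch
import torch.nn.functional as F


def max_backward_diff(shape, dtype, device, features=1024, runs=3):
    """Repeat the same F.linear backward pass `runs` times (runs >= 2) and
    return the largest absolute difference between the first x.grad and
    any later one. A deterministic backward pass yields 0.0."""
    x = torch.randn(*shape, features, device=device, dtype=dtype)
    w = torch.randn(features, features, device=device, dtype=dtype)
    grad = torch.randn(*shape, features, device=device, dtype=dtype)

    def backward_once():
        x0 = x.clone().requires_grad_()
        F.linear(x0, w).backward(grad)
        return x0.grad.clone()

    first = backward_once()
    return max(
        (backward_once() - first).abs().max().item() for _ in range(runs - 1)
    )


if torch.backends.mps.is_available():
    # Leading dimensions only; the feature dimension is appended inside the helper.
    shapes = [(26,), (2, 13), (2, 3, 13)]
    dtypes = [torch.bfloat16, torch.float16, torch.float32]
    for shape, dtype in itertools.product(shapes, dtypes):
        diff = max_backward_diff(shape, dtype, "mps")
        print(f"{dtype}, input {tuple(shape) + (1024,)}: max diff {diff}")
else:
    print("MPS is not available; skipping the sweep.")
```

Any nonzero value in the output points at a shape and dtype combination that is still non-deterministic on the build under test.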

Example

The fix itself belongs in PyTorch's internal MPS implementation, so there is no user-level patch to show here.

Notes

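The reporter assumes, but has not confirmed, that the backward-pass issue occurs under the same conditions as #180776: an Apple M5, BFloat16/Float16 data types, and inputs with more than two dimensions. The fix from #181466 should probably be applied to _mps_linear_backward_input to resolve it.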

Recommendation

Apply the fix from #181466 to _mps_linear_backward_input; per the report, this is likely to resolve the non-deterministic behavior in the backward pass.
