pytorch - 💡(How to fix) Fix `torch.nn.functional.linear` produces incorrect results on MPS with AMD RDNA2 GPUs (RX 6000 series) [1 comments, 2 participants]

GitHub stats
pytorch/pytorch#178697 (fetched 2026-04-08 01:45:04)
Comments: 1 · Participants: 2 · Reactions: 0
Timeline events: 43 (mentioned ×17, subscribed ×17, labeled ×6, closed ×1)

Error Message

torch.nn.functional.linear produces catastrophically wrong results on MPS when running on AMD RDNA2 GPUs (e.g. Radeon RX 6900 XT in a Mac Pro). The error is enormous (max difference of 30-90+ vs expected ~0.0001) and makes all neural network inference produce garbage output.

Root Cause

The RDNA2 GPU architecture was never shipped in any Mac by Apple. Apple's Metal compute shaders for the MPS backend were developed and tested against GCN-era GPUs (Vega 56/64 in Mac Pro, various Polaris in iMacs). The Metal kernel used by F.linear likely takes a different code path than torch.mm and that path has an accumulator or dispatch bug specific to RDNA2's compute units.

Key evidence:

  • torch.mm → correct results (uses one Metal kernel)
  • F.linear → broken results for dim>=256 (uses a different Metal kernel)
  • Both float16 and float32 are affected
  • Data transfer is fine (CPU→MPS→CPU roundtrip is exact)
  • Element-wise ops (add, mul, exp) are fine
  • torch.bmm is fine
  • The same system worked perfectly with a Vega 64 GPU (GCN architecture)

Fix Action

Workaround

Monkey-patching F.linear to use torch.mm with a contiguous transposed weight completely fixes the issue:

import torch

_original_linear = torch.nn.functional.linear

def patched_linear(input, weight, bias=None):
    if input.device.type == "mps" and input.shape[-1] >= 256:
        orig_shape = input.shape
        input_2d = input.reshape(-1, orig_shape[-1])
        output = torch.mm(input_2d, weight.t().contiguous())
        output = output.reshape(*orig_shape[:-1], weight.shape[0])
        if bias is not None:
            output = output + bias
        return output
    return _original_linear(input, weight, bias)

torch.nn.functional.linear = patched_linear

With this patch, Stable Diffusion XL generates correct images on MPS at ~1.8 it/s. Without it, every image is noise/garbage.

Code Example

import torch

torch.manual_seed(42)

x = torch.randn(1, 320)
w = torch.randn(1280, 320)

# F.linear: BROKEN on MPS/RDNA2
r_cpu = torch.nn.functional.linear(x, w)
r_mps = torch.nn.functional.linear(x.to("mps"), w.to("mps")).cpu()
print(f"F.linear max diff: {(r_cpu - r_mps).abs().max().item():.6f}")
# Expected: ~0.000010
# Actual on RDNA2: ~45.0 (!!!)

# torch.mm with same data — WORKS FINE
r_cpu2 = torch.mm(x, w.t().contiguous())
r_mps2 = torch.mm(x.to("mps"), w.t().contiguous().to("mps")).cpu()
print(f"torch.mm max diff: {(r_cpu2 - r_mps2).abs().max().item():.6f}")
# Result: ~0.000010 (correct)


🐛 Describe the bug

Bug Description

torch.nn.functional.linear produces catastrophically wrong results on MPS when running on AMD RDNA2 GPUs (e.g. Radeon RX 6900 XT in a Mac Pro). The error is enormous (max difference of 30-90+ vs expected ~0.0001) and makes all neural network inference produce garbage output.

torch.mm with identical matrices works correctly. The bug is isolated to the specific Metal kernel dispatched by F.linear.

This completely breaks Stable Diffusion, LLM inference, and any model using Linear layers on MPS with RDNA2 GPUs.

Reproduction

import torch

torch.manual_seed(42)

x = torch.randn(1, 320)
w = torch.randn(1280, 320)

# F.linear — BROKEN on MPS/RDNA2
r_cpu = torch.nn.functional.linear(x, w)
r_mps = torch.nn.functional.linear(x.to("mps"), w.to("mps")).cpu()
print(f"F.linear max diff: {(r_cpu - r_mps).abs().max().item():.6f}")
# Expected: ~0.000010
# Actual on RDNA2: ~45.0 (!!!)

# torch.mm with same data — WORKS FINE
r_cpu2 = torch.mm(x, w.t().contiguous())
r_mps2 = torch.mm(x.to("mps"), w.t().contiguous().to("mps")).cpu()
print(f"torch.mm max diff: {(r_cpu2 - r_mps2).abs().max().item():.6f}")
# Result: ~0.000010 (correct)

Size dependency

The bug only manifests when input_features >= 256:

| Shape (in→out) | F.linear max_diff | Status |
|---|---|---|
| 128→128 | 0.000004 | OK |
| 256→256 | 32.16 | BROKEN |
| 320→320 | 35.80 | BROKEN |
| 320→1280 | 41.17 | BROKEN |
| 1280→1280 | 94.39 | BROKEN |
| 768→2048 | 83.59 | BROKEN |
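
A sweep like the one in the table can be reproduced with a small script. This is a minimal sketch, not the script the reporter used: the helper name `linear_max_diff`, the fixed seed, and the batch size of 1 are assumptions made for illustration, and the script falls back to the CPU when MPS is unavailable so it still runs on other machines.

```python
import torch
import torch.nn.functional as F


def linear_max_diff(in_features, out_features, device, seed=0):
    """Max absolute difference between F.linear on CPU and on `device`.

    Hypothetical helper: name, seed, and batch size are illustrative choices.
    """
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(1, in_features, generator=gen)
    w = torch.randn(out_features, in_features, generator=gen)
    ref = F.linear(x, w)                               # CPU reference
    out = F.linear(x.to(device), w.to(device)).cpu()   # result under test
    return (ref - out).abs().max().item()


# Fall back to CPU so the script still runs on machines without MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Same (in, out) shapes as the table in the report.
shapes = [(128, 128), (256, 256), (320, 320), (320, 1280), (1280, 1280), (768, 2048)]
results = {shape: linear_max_diff(*shape, device=device) for shape in shapes}

for (n_in, n_out), diff in results.items():
    print(f"{n_in}->{n_out}: max_diff={diff:.6f}")
```

On an unaffected device every row should stay near float32 rounding error; per the report, rows with 256 or more input features are the ones that diverge on RDNA2.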

Workaround

Monkey-patching F.linear to use torch.mm with a contiguous transposed weight completely fixes the issue:

import torch

_original_linear = torch.nn.functional.linear

def patched_linear(input, weight, bias=None):
    if input.device.type == "mps" and input.shape[-1] >= 256:
        orig_shape = input.shape
        input_2d = input.reshape(-1, orig_shape[-1])
        output = torch.mm(input_2d, weight.t().contiguous())
        output = output.reshape(*orig_shape[:-1], weight.shape[0])
        if bias is not None:
            output = output + bias
        return output
    return _original_linear(input, weight, bias)

torch.nn.functional.linear = patched_linear

With this patch, Stable Diffusion XL generates correct images on MPS at ~1.8 it/s. Without it, every image is noise/garbage.
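
If a permanent global patch is undesirable, the same workaround can be scoped to a block of code. The following is a sketch under assumptions: `mm_linear`, `linear_via_mm`, and the `min_features` / `device_type` parameters are hypothetical names that are not part of PyTorch or of the issue, and the demo passes `device_type="cpu"` only so that it runs on any machine.

```python
import contextlib

import torch
import torch.nn.functional as F


def mm_linear(input, weight, bias=None):
    """F.linear computed through torch.mm, mirroring the workaround above."""
    out = torch.mm(input.reshape(-1, input.shape[-1]), weight.t().contiguous())
    out = out.reshape(*input.shape[:-1], weight.shape[0])
    return out if bias is None else out + bias


@contextlib.contextmanager
def linear_via_mm(min_features=256, device_type="mps"):
    """Temporarily route large F.linear calls on `device_type` through torch.mm."""
    original = F.linear

    def patched(input, weight, bias=None):
        if input.device.type == device_type and input.shape[-1] >= min_features:
            return mm_linear(input, weight, bias)
        return original(input, weight, bias)

    F.linear = patched
    try:
        yield
    finally:
        F.linear = original  # always restore, even if the body raises


# Demo on CPU (an assumption for portability; use the default "mps" in practice).
original_linear = F.linear
layer = torch.nn.Linear(320, 64)
x = torch.randn(2, 3, 320)
with torch.no_grad():
    reference = layer(x)
    with linear_via_mm(device_type="cpu"):
        patched_out = layer(x)
print(torch.allclose(reference, patched_out, atol=1e-4))
```

Restoring the original function in `finally` keeps the patch from leaking into unrelated code, which makes it easier to remove once a fixed PyTorch or macOS release is available.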

Environment

  • GPU: AMD Radeon RX 6900 XT (RDNA2 architecture, Device ID 0x73bf)
  • System: Mac Pro (2019), Intel i9-14900K, 128 GB RAM
  • macOS: 15.7.4 (Sequoia)
  • PyTorch: 2.2.2 (also reproduced with 2.7.1 built from source)
  • Python: 3.11.12

Root Cause Analysis

The RDNA2 GPU architecture was never shipped in any Mac by Apple. Apple's Metal compute shaders for the MPS backend were developed and tested against GCN-era GPUs (Vega 56/64 in Mac Pro, various Polaris in iMacs). The Metal kernel used by F.linear likely takes a different code path than torch.mm and that path has an accumulator or dispatch bug specific to RDNA2's compute units.

Key evidence:

  • torch.mm → correct results (uses one Metal kernel)
  • F.linear → broken results for dim>=256 (uses a different Metal kernel)
  • Both float16 and float32 are affected
  • Data transfer is fine (CPU→MPS→CPU roundtrip is exact)
  • Element-wise ops (add, mul, exp) are fine
  • torch.bmm is fine
  • The same system worked perfectly with a Vega 64 GPU (GCN architecture)
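
The evidence list can be turned into a single comparison script that runs each operation on the CPU and on the device and reports the largest deviation. This is an illustrative sketch: the function name `compare_ops` and the shapes used for `torch.bmm` are assumptions, since the report does not say which shapes were used for those checks.

```python
import torch
import torch.nn.functional as F


def compare_ops(device, n_in=320, n_out=1280, seed=42):
    """Return {op name: max abs difference between CPU and `device`}.

    Hypothetical helper; the bmm shapes are an assumption, not from the report.
    """
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(1, n_in, generator=gen)
    w = torch.randn(n_out, n_in, generator=gen)
    wt = w.t().contiguous()

    xd, wd, wtd = x.to(device), w.to(device), wt.to(device)

    def diff(ref, out):
        return (ref - out.cpu()).abs().max().item()

    return {
        "roundtrip": diff(x, xd),                       # CPU -> device -> CPU
        "add": diff(x + x, xd + xd),                    # element-wise ops
        "mul": diff(x * x, xd * xd),
        "exp": diff(x.exp(), xd.exp()),
        "mm": diff(torch.mm(x, wt), torch.mm(xd, wtd)),
        "bmm": diff(torch.bmm(x[None], wt[None]), torch.bmm(xd[None], wtd[None])),
        "linear": diff(F.linear(x, w), F.linear(xd, wd)),
    }


# Fall back to CPU so the script still runs on machines without MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"
report = compare_ops(device)
for name, value in report.items():
    print(f"{name:>9}: {value:.6f}")
```

If the report's description holds, only the `linear` row should stand out on an affected GPU, while the other rows stay at or near zero.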

Impact

Any Mac Pro (or Hackintosh) user with an aftermarket AMD RDNA/RDNA2 GPU (RX 5000/6000 series) is affected. MPS appears to work (basic ops pass, torch.backends.mps.is_available() returns True), but all neural network inference silently produces wrong results. This is extremely difficult to diagnose because:

  1. No errors are thrown
  2. Individual small ops appear to work
  3. The failure only manifests at scale (dim >= 256)
  4. Output looks like "noise" which users may attribute to wrong models/settings

Versions

Collecting environment information...
PyTorch version: 2.2.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.7.4 (x86_64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.0.13.5)
CMake version: version 4.3.0
Libc version: N/A

Python version: 3.11.12 (main, May 30 2025, 06:05:03) [Clang 20.1.4 ] (64-bit runtime)
Python platform: macOS-15.7.4-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Intel(R) Core(TM) i9-14900K

Versions of relevant libraries:
[pip3] clip-anytorch==2.6.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnxruntime==1.19.2
[pip3] optree==0.19.0
[pip3] pytorch-lightning==2.1.3
[pip3] torch==2.2.2
[pip3] torchmetrics==1.9.0
[pip3] torchsde==0.2.6
[pip3] torchvision==0.17.2
[conda] Could not collect

cc @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

Fix Plan

To fix the issue with torch.nn.functional.linear on MPS with RDNA2 GPUs, we can use the provided monkey patch. Here are the steps:

  • Replace the original torch.nn.functional.linear function with the patched version.
  • The patched version checks whether the input tensor is on the MPS device and whether its last dimension (input.shape[-1], the number of input features) is at least 256.
  • If the conditions are met, it uses torch.mm with a contiguous transposed weight to compute the linear transformation.
import torch

_original_linear = torch.nn.functional.linear

def patched_linear(input, weight, bias=None):
    if input.device.type == "mps" and input.shape[-1] >= 256:
        orig_shape = input.shape
        input_2d = input.reshape(-1, orig_shape[-1])
        output = torch.mm(input_2d, weight.t().contiguous())
        output = output.reshape(*orig_shape[:-1], weight.shape[0])
        if bias is not None:
            output = output + bias
        return output
    return _original_linear(input, weight, bias)

torch.nn.functional.linear = patched_linear

Verification

To verify that the fix worked, you can run the reproduction code provided in the issue body:

import torch

torch.manual_seed(42)

x = torch.randn(1, 320)
w = torch.randn(1280, 320)

r_cpu = torch.nn.functional.linear(x, w)
r_mps = torch.nn.functional.linear(x.to("mps"), w.to("mps")).cpu()
print(f"F.linear max diff: {(r_cpu - r_mps).abs().max().item():.6f}")

The output should be close to the expected value of ~0.000010.

Extra Tips

  • Make sure to apply the monkey patch before running any code that uses torch.nn.functional.linear.
  • If you are using a different version of PyTorch, you may need to modify the patch accordingly.
  • This fix is specific to the RDNA2 GPU architecture and may not be necessary for other architectures.
  • It's recommended to test the fix thoroughly to ensure it works correctly for your specific use case.
