pytorch - 💡(How to fix) Fix MPS: catastrophically wrong gradients in backward pass (>32K elements) [6 comments, 3 participants]

GitHub stats: pytorch/pytorch#177116 (fetched 2026-04-08 00:22:14)
Comments: 6 · Participants: 3 · Timeline events: 96 · Reactions: 0
Timeline (top): mentioned ×41, subscribed ×41, labeled ×7, commented ×6


🐛 Describe the bug

The MPS backend produces catastrophically wrong gradients (1,000x–100,000x too large) during loss.backward() when:

  1. A prior MPS forward+backward pass has occurred in the same process at a different batch size, and
  2. The total number of elements (batch_size × seq_len) exceeds ~32,768 (2^15).

The forward pass is always correct: loss values match CPU within float32 tolerance. Only the backward pass is affected. Without the prior MPS operation at a different shape, gradients are correct. In quick experiments, calling torch.mps.empty_cache() between calls appeared to reduce the failure rate, suggesting MPS buffer pool corruption when tensor shapes change between backward passes.

Possibly related to #116769 (int16 overflow in Metal matmul shaders, fixed for torch.mm/torch.bmm via tiling in PR #117549 but not for all code paths), #122045 (F.linear wrong results for large inputs), and #117826 (incorrect einsum gradient on MPS).
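The reported threshold of 2^15 elements lines up with the range of a signed 16-bit integer, which is the failure mode described in #116769. The sketch below uses only the standard library (ctypes.c_int16, which wraps silently instead of raising on overflow) to show where such an index would wrap for the reproduction's shape. It illustrates the arithmetic only: the issue does not establish that this backward bug has an int16 cause, and `as_int16` and `first_overflowing_batch` are hypothetical helpers.

```python
import ctypes

INT16_MAX = 2 ** 15 - 1  # 32767, the largest value a signed 16-bit integer can hold


def as_int16(value: int) -> int:
    """Reinterpret an integer as a signed 16-bit value (wraps like a C int16_t)."""
    return ctypes.c_int16(value).value


def first_overflowing_batch(seq_len: int) -> int:
    """Smallest batch_size whose batch_size * seq_len element count exceeds 2**15."""
    return (2 ** 15) // seq_len + 1


if __name__ == "__main__":
    n = 8194 * 4          # total elements in the reproduction (32776)
    last_index = n - 1    # largest flat index that must be addressed (32775)
    print(n, last_index > INT16_MAX, as_int16(last_index))
    print(first_overflowing_batch(4))
```

With seq_len=4, any batch of 8193 or more crosses the threshold, so the reproduction's 8194 sits just past it, and its last flat index would wrap to a negative value if stored in an int16.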

Minimal reproduction

The model is Embedding → 2 residual Linear blocks → output Linear, with a CrossEntropyLoss objective. Step 1 primes MPS at a different batch size; step 2 tests at batch=8194 (32,776 total elements > 32,768).

import torch
import torch.nn as nn


class ResidualModel(nn.Module):
    def __init__(self, vocab_size=122, d_model=128):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb = nn.Embedding(vocab_size, d_model)
        self.fc1 = nn.Linear(d_model, d_model, bias=False)
        self.fc2 = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        h = self.emb(x)
        h = h + self.fc1(h)  # residual 1
        h = h + self.fc2(h)  # residual 2
        return self.out(h)


V = 122
seq_len = 4
criterion = nn.CrossEntropyLoss()

# Step 1: prime MPS with a forward+backward at a DIFFERENT batch size.
# Without this step, MPS produces correct gradients.
torch.manual_seed(0)
prime_x = torch.randint(0, V, (4096, seq_len)).to("mps")
prime_y = torch.randint(0, V, (4096, seq_len)).view(-1).to("mps")
torch.manual_seed(0)
prime_m = ResidualModel().to("mps")
prime_loss = criterion(prime_m(prime_x).view(-1, V), prime_y)
prime_loss.backward()
del prime_m, prime_loss, prime_x, prime_y

# Step 2: test at batch=8194 — total elements = 8194*4 = 32776 > 32768
batch_size = 8194
torch.manual_seed(0)
x = torch.randint(0, V, (batch_size, seq_len))
y = torch.randint(0, V, (batch_size, seq_len))

for device in ["cpu", "mps"]:
    results = []
    for trial in range(5):
        torch.manual_seed(0)
        model = ResidualModel().to(device)
        logits = model(x.to(device))
        loss = criterion(logits.view(-1, V), y.view(-1).to(device))
        loss.backward()
        gnorm = sum(
            p.grad.norm().item() ** 2 for p in model.parameters()
        ) ** 0.5
        results.append((loss.item(), gnorm))

    print(f"\n{device.upper()}:")
    for i, (l, g) in enumerate(results):
        print(f"  trial {i}: loss={l:.6f}  grad_norm={g:.4f}")

Expected output

CPU and MPS should produce the same loss and gradient norm (within float32 tolerance) across all trials.

Actual output

CPU:
  trial 0: loss=5.089585  grad_norm=0.2398
  trial 1: loss=5.089585  grad_norm=0.2398
  trial 2: loss=5.089585  grad_norm=0.2398
  trial 3: loss=5.089585  grad_norm=0.2398
  trial 4: loss=5.089585  grad_norm=0.2398

MPS:
  trial 0: loss=5.089585  grad_norm=0.2414       ← within float32 tolerance
  trial 1: loss=5.089585  grad_norm=3529.6575     ← 14,719x too large
  trial 2: loss=5.089585  grad_norm=16290.2734    ← 67,932x too large
  trial 3: loss=5.089585  grad_norm=1996.1491     ← 8,324x too large
  trial 4: loss=5.089585  grad_norm=16290.2734    ← 67,932x too large

Loss is identical. Gradient norms on MPS are wrong by 8,324x to 67,932x in this run. Trial 0 is correct; trials 1 through 4 are catastrophically wrong (each backward pass appears to corrupt buffer state used by subsequent trials). Without step 1, all MPS trials match CPU.
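To make the "within tolerance" versus "too large" judgement mechanical, the recorded norms can be checked against the CPU reference. This is a minimal sketch: `flag_trials` is a hypothetical helper, and the 1% relative tolerance is an assumption chosen because the printed norms are rounded to four decimals.

```python
def flag_trials(cpu_norm, mps_norms, rtol=0.01):
    """Compare each MPS gradient norm with the CPU reference.

    Returns a list of (trial, ratio, ok) tuples, where ratio = mps / cpu and
    ok is True when the relative error is within rtol.
    """
    report = []
    for trial, g in enumerate(mps_norms):
        ratio = g / cpu_norm
        ok = abs(g - cpu_norm) / abs(cpu_norm) <= rtol
        report.append((trial, ratio, ok))
    return report


if __name__ == "__main__":
    # Gradient norms copied from the "Actual output" above.
    cpu_norm = 0.2398
    mps_norms = [0.2414, 3529.6575, 16290.2734, 1996.1491, 16290.2734]
    for trial, ratio, ok in flag_trials(cpu_norm, mps_norms):
        status = "ok" if ok else f"{int(ratio):,}x too large"
        print(f"trial {trial}: {status}")
```

Truncating the ratio with int() reproduces the issue's annotations (14,719x, 67,932x, 8,324x), while trial 0 differs from CPU by under 1%.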

We posted additional analysis (sweep data across batch sizes and sequence lengths, empty_cache observations, and diagnostic experiments pointing to MPS buffer pool reuse) as a follow-up comment below.

Versions

PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 4.2.0
Libc version: N/A

Python version: 3.12.8 (main, Aug 4 2025, 07:08:45) [Clang 15.0.0 (clang-1500.3.9.4)] (64-bit runtime)
Python platform: macOS-26.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU: Apple M2

Versions of relevant libraries:
[pip3] numpy==2.3.5
[pip3] torch==2.10.0
[conda] Could not collect

cc @ezyang @gchanan @kadeng @msaroufim @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk


Fix Plan

1. Empty MPS Cache Between Calls

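Call torch.mps.empty_cache() between forward+backward calls. In the reporter's quick experiments this reduced the failure rate, but it is not confirmed to prevent the corruption.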

# Before calling loss.backward()
torch.mps.empty_cache()
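In the context of the reproduction loop, the cache would be emptied before each trial. This is a sketch: `run_trials` and `grad_norm` are hypothetical helpers, the extra torch.mps.synchronize() call is an added precaution rather than something the issue tested, and the issue reports only a reduced failure rate, not a fix.

```python
import torch
import torch.nn as nn


def grad_norm(model: nn.Module) -> float:
    # Global L2 norm over all parameter gradients, matching the reproduction script.
    return sum(
        p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None
    ) ** 0.5


def run_trials(make_model, x, y, vocab_size, device, n_trials=5, clear_cache=True):
    """Repeat the reproduction's forward+backward n_trials times and return grad norms.

    When clear_cache is True and device is "mps", cached MPS memory is released
    before every trial.
    """
    criterion = nn.CrossEntropyLoss()
    norms = []
    for _ in range(n_trials):
        if clear_cache and device == "mps":
            torch.mps.synchronize()  # wait for queued GPU work to finish
            torch.mps.empty_cache()  # release unoccupied cached memory held by the MPS allocator
        torch.manual_seed(0)
        model = make_model().to(device)
        logits = model(x.to(device))
        loss = criterion(logits.view(-1, vocab_size), y.view(-1).to(device))
        loss.backward()
        norms.append(grad_norm(model))
    return norms
```

With the names from the reproduction, this would be called as run_trials(ResidualModel, x, y, V, "mps") and its output compared with run_trials(ResidualModel, x, y, V, "cpu").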

2. Do Not Rely on Resizing the Buffer Pool

torch.mps has no set_cache_size() or set_buffer_pool_size() function, so the MPS buffer pool cannot be enlarged from Python this way, and a larger pool would not obviously prevent corruption anyway. The documented allocator controls, torch.mps.set_per_process_memory_fraction() and the PYTORCH_MPS_HIGH_WATERMARK_RATIO environment variable, cap memory use rather than enlarge the pool; their effect on this bug has not been tested.

3. Fall Back to CPU for Large Tensor Shapes

Run the model on CPU when batch_size × seq_len exceeds ~32,768; CPU produced correct, consistent gradients in every trial of the reproduction.

# Choose the device before building the model; inputs and targets must use the same device
device = torch.device("cpu")
model = ResidualModel().to(device)
logits = model(x.to(device))
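Rather than disabling MPS for every run, a workload can be routed to CPU only when it crosses the reported threshold. This is a minimal sketch: `pick_device` and `MAX_MPS_ELEMENTS` are hypothetical names, and the strict 2^15 cutoff comes from the issue's approximate (~32,768) observation, not from any documented limit.

```python
MAX_MPS_ELEMENTS = 2 ** 15  # hypothetical cutoff taken from the issue's "~32,768" observation


def pick_device(batch_size: int, seq_len: int, mps_available: bool) -> str:
    """Return "mps" for small workloads, "cpu" if MPS is unavailable or the workload is large."""
    if not mps_available:
        return "cpu"
    if batch_size * seq_len > MAX_MPS_ELEMENTS:
        return "cpu"  # reported to produce wrong gradients on MPS above this size
    return "mps"
```

In a script this would be used as device = pick_device(batch_size, seq_len, torch.backends.mps.is_available()). The rule is deliberately conservative: the bug also requires a prior MPS pass at a different shape, but the issue does not establish that fixed-shape workloads above the threshold are safe.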

Verification

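  1. Run the reproduction script with the fix applied.
  2. Verify that the MPS gradient norm matches the CPU value within float32 tolerance in every trial, not only trial 0.
  3. Compare loss values between CPU and MPS as well, but note that matching losses alone do not confirm the fix: the loss already matched in the failing runs.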
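Because the loss matches even when gradients are wrong, a per-parameter gradient comparison is a more direct check, and it also shows which layers (emb, fc1, fc2, out) receive bad gradients. A sketch assuming two models built with the same seed, one on CPU and one on MPS, each after its own loss.backward(); `compare_grads` is a hypothetical helper and the rtol/atol defaults are assumptions.

```python
import torch
import torch.nn as nn


def compare_grads(reference: nn.Module, candidate: nn.Module, rtol=1e-3, atol=1e-5):
    """Compare gradients parameter by parameter.

    Returns {parameter name: (max absolute difference, allclose result)}.
    Both gradients are moved to CPU so models on different devices can be compared.
    """
    candidate_params = dict(candidate.named_parameters())
    report = {}
    for name, ref_param in reference.named_parameters():
        ref_grad = ref_param.grad.detach().cpu()
        cand_grad = candidate_params[name].grad.detach().cpu()
        max_diff = (ref_grad - cand_grad).abs().max().item()
        close = torch.allclose(ref_grad, cand_grad, rtol=rtol, atol=atol)
        report[name] = (max_diff, close)
    return report
```

For example, build cpu_model and mps_model with torch.manual_seed(0) before each, run the reproduction's forward and backward on each device, then print compare_grads(cpu_model, mps_model).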

Extra Tips

  • Empty the MPS cache between calls with torch.mps.empty_cache(); in the reporter's experiments this reduced, but did not eliminate, the wrong gradients.
  • torch.mps has no set_cache_size() or set_buffer_pool_size(), so the buffer pool cannot be enlarged to work around the bug.
  • Run workloads above ~32,768 elements (batch_size × seq_len) on CPU until the underlying MPS bug is fixed.
  • Allocator usage can be monitored with torch.mps.current_allocated_memory() and torch.mps.driver_allocated_memory(); these report memory in use but will not detect corrupted buffers.
