pytorch - ✅(Solved) Fix [MPS]: nn.LSTM incorrectly applies dropout in eval() mode due to train/eval cache collision. [1 pull request, 2 comments, 2 participants]

GitHub stats
pytorch/pytorch#180744 (fetched 2026-04-19 15:03:52)
Comments: 2
Participants: 2
Timeline events: 57
Reactions: 0
Timeline (top): mentioned ×24, subscribed ×24, labeled ×5, commented ×2

Error Message

None. The bug produces silently wrong outputs, with no runtime error and no loading/saving error.

Root Cause

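_lstm_mps in RnnOps caches the compiled MPSGraph under a string key built from the input shapes but leaves the train flag out of the key. An eval() call whose input shape matches an earlier train() call therefore reuses the cached training graph, dropout included.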

Fix / Workaround

Fix: PR #180873, listed under PR fix notes.
Workarounds tried by the reporter: train on CPU, or set dropout=0.

PR fix notes

PR #180873: [MPS] fix lstm train/eval error

Description (problem / solution / changelog)

Fixes #180744

Changed files

  • aten/src/ATen/native/mps/operations/RnnOps.mm (modified, +1/-1)
  • test/test_mps.py (modified, +21/-0)


🐛 Describe the bug

UPDATE: Root cause found:

The differing batch size results are actually just a symptom of a caching collision. _lstm_mps in RnnOps uses a string key to cache the compiled MPSGraph based on input shapes, but fails to include the train flag in the key.

If a user runs a train() pass with input shape [3, 5, 2], an MPS graph with hardcoded dropout is generated and cached under that shape. If they later call eval() and pass an input with the exact same [3, 5, 2] shape, the lookup hits the cache and reuses the training graph, incorrectly applying dropout during inference. A different batch size in eval() avoids the bug only because it causes a cache miss, forcing MPS to build a correct graph without dropout. In the reproducer below, the batch=10 eval call misses the cache and is correct, while the batch=3 eval call reuses the graph built for the [3, 5, 2] training input; that mismatch is the difference reported.

LSTM models on MPS are therefore frequently evaluated incorrectly, with dropout applied, whenever an eval() input has the same shape as an earlier train() input.
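The collision can be illustrated with a minimal pure-Python model, assuming a dictionary keyed by a string built from the input shape. GraphCache, make_key, lookup, and replay_reproducer are hypothetical stand-ins, not the actual Objective-C++ cache in RnnOps.mm, and only the cache-key aspect is modeled. Replaying the reproducer's call sequence (a train-mode forward at [3, 5, 2], then eval at [10, 5, 2] and [3, 5, 2]) shows why only the batch-3 eval call picks up the dropout graph, and how adding the train flag to the key separates the two modes.

```python
class GraphCache:
    """Hypothetical model of a compiled-graph cache keyed by a string."""

    def __init__(self, include_train_flag):
        self.include_train_flag = include_train_flag
        self.graphs = {}

    def make_key(self, shape, train):
        key = "lstm_" + "x".join(str(d) for d in shape)
        if self.include_train_flag:
            # The fix in this model: train and eval graphs get distinct keys.
            key += f"_train{int(train)}"
        return key

    def lookup(self, shape, train):
        """Return whether the graph applies dropout, building it on a miss."""
        key = self.make_key(shape, train)
        if key not in self.graphs:
            # On a miss the graph is built for the *current* mode,
            # so dropout is baked in only if train is True.
            self.graphs[key] = {"dropout_enabled": train}
        return self.graphs[key]["dropout_enabled"]


def replay_reproducer(include_train_flag):
    cache = GraphCache(include_train_flag)
    cache.lookup((3, 5, 2), train=True)           # train-mode forward before opt.step()
    full = cache.lookup((10, 5, 2), train=False)  # eval, batch=10: cache miss
    part = cache.lookup((3, 5, 2), train=False)   # eval, batch=3: same shape as training
    return {"batch10_dropout": full, "batch3_dropout": part}


print(replay_reproducer(include_train_flag=False))
# {'batch10_dropout': False, 'batch3_dropout': True}
print(replay_reproducer(include_train_flag=True))
# {'batch10_dropout': False, 'batch3_dropout': False}
```

With the shape-only key, the eval call at batch=3 silently reuses the training graph; with the mode in the key, every eval lookup builds or finds a dropout-free graph.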

Original Summary (pre-root cause discovery)

On Apple Silicon MPS, nn.LSTM produces different outputs for the same samples in eval() mode, depending on the batch size of the forward call, after at least one real optimizer step.

This appears only for:

  • device='mps'
  • num_layers >= 2
  • dropout > 0
  • at least one real forward --> backward --> optimizer.step()

CPU does not show this behaviour in the same test.

Minimal reproducer

Below is minimal code to reproduce the issue. Switching the "mps" strings to "cpu" removes the issue. This issue is consistent across different seeds and hidden sizes, but the above conditions are all required to trigger it.

import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(
    input_size=2,
    hidden_size=4,
    num_layers=2,
    dropout=0.1,
    batch_first=True,
).to("mps")

# Real optimizer step
opt = torch.optim.SGD(lstm.parameters(), lr=1e-2)
lstm(torch.randn(3, 5, 2, device="mps"))[0].mean().backward()
opt.step()

lstm.eval()
probe = torch.randn(10, 5, 2, device="mps")

with torch.no_grad():
    full = lstm(probe)[0]      # batch=10
    part = lstm(probe[:3])[0]  # batch=3, same first 3 samples

print((full[:3] - part).abs().max().item())
# Observed on MPS: 0.2898319363594055
# Expected:        0.0 (or <1e-7)

Expected behaviour

The printed max difference should be 0, or within floating-point numerical-noise level (around 1e-7 to 1e-8), because both forwards use the same weights in eval() mode on the same 3 input samples.

Actual behaviour

On MPS, the difference is large (commonly 0.2 to 0.9).

In my tests with 10 seeds on this environment:

  • MPS, num_layers=2, dropout=0.1: max diffs in [0.177, 0.666]
  • MPS, control dropout=0.0: around 3e-8 to 7e-8
  • MPS, control num_layers=1, dropout=0.1: around 3e-8 to 9e-8
  • CPU, num_layers=2, dropout=0.1: around 0 to 6e-8

This can silently break training/evaluation consistency:

  • training metrics look fine when using a single fixed batch size,
  • but inference or re-evaluation at a different batch size degrades significantly,
  • with no runtime error and no loading/saving error.

Additional notes

Observed behaviors:

  • Reproduces with batch_first=True (not yet exhaustively tested with all layouts).
  • Reproduces at small and large hidden sizes (hidden=4 and hidden=96 both confirmed).
  • Reproduces for bidirectional=True and bidirectional=False; bidirectional makes the diff larger.
  • Removing either trigger condition removes the issue:
    • dropout=0 (no repro)
    • num_layers=1 (no repro)
  • Appears only after a real optimizer step; forward-only in train() mode does not reliably trigger it.
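The two required conditions line up with how nn.LSTM documents its dropout argument: a Dropout layer is applied to the outputs of each LSTM layer except the last, and only in training mode. A single-layer LSTM or dropout=0 therefore gives a stale training graph nothing to apply. The sketch below encodes that counting rule; dropout_sites and stale_graph_visible are hypothetical helpers, and the sketch does not model the optimizer-step condition or why bidirectional=True enlarges the difference.

```python
def dropout_sites(num_layers, dropout, training):
    """Count inter-layer dropout applications in one nn.LSTM forward.

    Follows the documented nn.LSTM behavior: dropout acts on the outputs
    of every layer except the last, and only in training mode.
    """
    if not training or dropout == 0.0:
        return 0
    return num_layers - 1


def stale_graph_visible(num_layers, dropout):
    """Could reusing a cached training graph in eval() change the output?

    Hypothetical helper: a stale training graph only matters if it
    actually contains dropout sites.
    """
    return dropout_sites(num_layers, dropout, training=True) > 0


for layers, p in [(2, 0.1), (2, 0.0), (1, 0.1)]:
    print(layers, p, stale_graph_visible(layers, p))
# 2 0.1 True
# 2 0.0 False
# 1 0.1 False
```

These three cases match the reporter's trigger configuration and the two controls (dropout=0.0 and num_layers=1) that stayed at numerical-noise level.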

Workarounds tried

  • Train on CPU: fully resolves it. Outputs match across batch sizes.
  • Set dropout=0: fully resolves it.

Regression / nightly

I ran the code above against multiple PyTorch versions on the same machine (Apple M2, macOS 14.7.8, Python 3.14.3, seed 0), and every available version produces the same max_diff = 2.898319e-01:

PyTorch              max_diff
2.9.0                2.898319e-01
2.10.0               2.898319e-01
2.11.0               2.898319e-01
2.13.0.dev20260418   2.898319e-01

Versions

PyTorch version: 2.11.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.7.8 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: version 4.0.3
Libc version: N/A

Python version: 3.14.3 (v3.14.3:323c59a5e34, Feb  3 2026, 11:41:37) [Clang 16.0.0 (clang-1600.0.26.6)] (64-bit runtime)
Python platform: macOS-14.7.8-arm64-arm-64bit-Mach-O
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M2

Versions of relevant libraries:
[pip3] numpy==2.4.4
[pip3] torch==2.11.0
[pip3] torchvision==0.26.0
[conda] Could not collect

cc @mikaylagawarecki @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

Extent analysis

TL;DR

The most likely fix is to include the train flag in the cache key used by _lstm_mps, so that training and inference graphs can no longer collide.

Guidance

  • Identify the caching mechanism in _lstm_mps and modify it to include the train flag in the cache key.
  • Verify that the modified caching mechanism correctly handles different batch sizes and train modes.
  • Test the fix using the provided minimal reproducer code to ensure that the issue is resolved.
  • Consider adding additional tests to cover different scenarios and edge cases.

Example

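# Hypothetical Python illustration only; the real cache key is built in
# aten/src/ATen/native/mps/operations/RnnOps.mm (Objective-C++).
def get_cache_key(input_shape, train_flag):
    # Include the train flag so train() and eval() graphs never share a key
    return f"{input_shape}_{train_flag}"

# Usage: the same input shape yields a different key per mode
input_shape = (3, 5, 2)
train_key = get_cache_key(input_shape, True)
eval_key = get_cache_key(input_shape, False)
assert train_key != eval_key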

Notes

The provided fix assumes that the caching mechanism in _lstm_mps can be modified to include the train flag in the cache key. If this is not possible, alternative solutions may be needed.

Recommendation

Apply the fix by including the train flag in the _lstm_mps cache key, as this is the most direct way to address the cache collision. Until a build containing the fix is available, the reporter's workarounds (training on CPU or setting dropout=0) avoid the problem.
