pytorch - ✅ (Solved) Fix [MPS] `nn.BatchNorm2d` tensors are not broadcast compatible [1 pull request, 1 comment, 2 participants]

GitHub stats
pytorch/pytorch#178770 · Fetched 2026-04-08 01:52:14
Comments: 1 · Participants: 2 · Timeline: 78 · Reactions: 0
Timeline (top): mentioned ×30, subscribed ×30, labeled ×6, referenced ×6

Error Message

import torch
import torch.nn as nn

device="mps"

x = torch.rand([2000, 32, 15, 15], device=device, dtype=torch.float16)

model = nn.BatchNorm2d(32, device=device, dtype=torch.float32)

y_mps = model(x)
y_cpu = model.cpu()(x.cpu())

print(torch.allclose(y_cpu, y_mps.cpu(), atol=1e-3, rtol=1e-3))

Output:

loc("mps_normalization"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~CKOougDLIumVT7yTlHqeQahAOsLoT_pOj6L2VoE/Library/Caches/com.apple.xbs/TemporaryDirectory.ZOD9bx/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":45:0)): error: output type 'tensor<2000x32x15x15xf16>' and gamma type 'tensor<1x32x1x1xf32>' are not broadcast compatible

LLVM ERROR: Failed to infer result type(s):

"mps.normalization"(...) {} : (tensor<2000x32x15x15xf16>, tensor<1x32x1x1xf16>, tensor<1x32x1x1xf16>, tensor<1x32x1x1xf32>, tensor<1x32x1x1xf32>) -> ( ??? )

Fix Action

Fixed

PR fix notes

PR #178775: [MPS] Fix BatchNorm with mixed input/weight dtypes (#178770)

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • #178781
  • -> #178775

MPSGraph's normalization ops require all tensors to have matching data types. When input is float16 and weights are float32 (common in mixed-precision training), the operation would crash. Cast weight, bias, and running stats to the appropriate dtype using castMPSTensor before passing to MPSGraph normalization ops, and cast gradients back to weight dtype in the backward pass.

TODO: Remove new test as it's already covered by test_instancenorm_mixed_dtype_backward in test_nn.py

Fixes https://github.com/pytorch/pytorch/issues/178770

Co-authored-by: Claude [email protected]

Changed files

  • aten/src/ATen/native/mps/operations/Normalization.mm (modified, +33/-18)
  • test/test_mps.py (modified, +13/-0)
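
The PR resolves the mismatch inside the MPS backend by casting tensors to a common dtype before the normalization op. On builds without that change, a similar alignment can be approximated from Python. The sketch below is an illustration under stated assumptions, not the PR's code: `batch_norm_aligned` is a hypothetical helper, and it upcasts the input to the weight dtype (the PR instead casts weight, bias, and running stats), then casts the result back to the input dtype. It uses a smaller batch than the MRE and falls back to CPU when MPS is unavailable.

```python
import torch
import torch.nn as nn


def batch_norm_aligned(bn: nn.BatchNorm2d, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run `bn` with the input cast to the module's dtype.

    Assumption: this casts the input, whereas the PR casts the weights
    inside the MPS backend. The output is returned in the input's dtype.
    """
    # Pick a reference tensor that carries the module's floating dtype.
    ref = bn.weight if bn.weight is not None else bn.running_mean
    if ref is None:
        # No affine parameters and no running stats: nothing to align.
        return bn(x)
    y = bn(x.to(ref.dtype))
    return y.to(x.dtype)


# Fall back to CPU so the sketch also runs on machines without MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"

x = torch.rand([8, 32, 15, 15], device=device).to(torch.float16)
bn = nn.BatchNorm2d(32, device=device, dtype=torch.float32)

y = batch_norm_aligned(bn, x)
print(y.dtype, tuple(y.shape))
```

The trade-off is a temporary float32 copy of the input; the module's float32 parameters and running statistics are left unchanged.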


🐛 Describe the bug

Calling nn.BatchNorm2d with different dtypes for the weights and input results in a crash inside MPSGraph. This works on CPU and CUDA. Discovered while working on cifar10 speedrun.

MRE

import torch
import torch.nn as nn

device="mps"

x = torch.rand([2000, 32, 15, 15], device=device, dtype=torch.float16)

model = nn.BatchNorm2d(32, device=device, dtype=torch.float32)

y_mps = model(x)
y_cpu = model.cpu()(x.cpu())

print(torch.allclose(y_cpu, y_mps.cpu(), atol=1e-3, rtol=1e-3))

# Output:
# loc("mps_normalization"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/4~CKOougDLIumVT7yTlHqeQahAOsLoT_pOj6L2VoE/Library/Caches/com.apple.xbs/TemporaryDirectory.ZOD9bx/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":45:0)): error: output type 'tensor<2000x32x15x15xf16>' and gamma type 'tensor<1x32x1x1xf32>' are not broadcast compatible
# LLVM ERROR: Failed to infer result type(s):
# "mps.normalization"(...) {} : (tensor<2000x32x15x15xf16>, tensor<1x32x1x1xf16>, tensor<1x32x1x1xf16>, tensor<1x32x1x1xf32>, tensor<1x32x1x1xf32>) -> ( ??? )

Versions

PyTorch version: 2.12.0a0+gitfd1d1b0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.4 (arm64)
GCC version: Could not collect
Clang version: 21.0.0 (clang-2100.0.123.102)
CMake version: version 4.1.2
Libc version: N/A

Python version: 3.12.8 | packaged by conda-forge | (main, Dec 5 2024, 14:19:53) [Clang 18.1.8 ] (64-bit runtime)
Python platform: macOS-26.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU: Apple M3 Max

Versions of relevant libraries:
[pip3] flake8==7.2.0
[pip3] mypy==1.13.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.2.6
[pip3] onnx==1.17.0
[pip3] onnx2torch==1.5.15
[pip3] onnxruntime==1.21.1
[pip3] onnxscript==0.1.0.dev20240817
[pip3] optree==0.13.0
[pip3] pytorch_sphinx_theme==0.0.24
[pip3] torch==2.12.0a0+gitfd1d1b0
[pip3] torchaudio==2.8.0
[pip3] torchbench==0.1
[pip3] torchvision==0.27.0a0+9bf794d
[conda] numpy 2.2.6 pypi_0 pypi
[conda] onnx2torch 1.5.15 pypi_0 pypi
[conda] optree 0.13.0 pypi_0 pypi
[conda] pytorch-sphinx-theme 0.0.24 pypi_0 pypi
[conda] torch 2.12.0a0+gitfd1d1b0 pypi_0 pypi
[conda] torchaudio 2.8.0 pypi_0 pypi
[conda] torchbench 0.1 dev_0 <develop>
[conda] torchvision 0.27.0a0+9bf794d dev_0 <develop>

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

Fix Plan

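PR #178775 addresses the crash inside PyTorch's MPS backend. On builds that do not include it, the workaround is to make the input and the weights of nn.BatchNorm2d share a dtype before the forward pass. The snippets below continue from the MRE above (same imports and device).

  • Step 1: Cast the input to the same dtype as the model.

```python
# Create the input in float32 so it matches the float32 BatchNorm weights.
x = torch.rand([2000, 32, 15, 15], device=device, dtype=torch.float32)
```

  • Step 2: Alternatively, cast the model to the same dtype as the input.

```python
# Create the module in float16 so it matches a float16 input.
model = nn.BatchNorm2d(32, device=device, dtype=torch.float16)
```

  • Step 3: Verify that the dtypes of the input and model are the same.

```python
# Prints True when the input dtype matches the first parameter's dtype.
print(x.dtype == next(model.parameters()).dtype)
```

Verification

After applying Step 1 or Step 2, the code should run without errors. Confirm this by checking that the output on the MPS device matches the output on the CPU.

```python
y_mps = model(x)
# model.cpu() moves the module in place, so run the MPS forward pass first.
y_cpu = model.cpu()(x.cpu())
print(torch.allclose(y_cpu, y_mps.cpu(), atol=1e-3, rtol=1e-3))
```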

Extra Tips

  • On affected builds, check that the input dtype matches the dtype of the module's weights and running statistics before calling it on MPS; a float16 input with float32 weights is what triggers this crash.
  • Use torch.allclose with explicit atol and rtol (the MRE uses 1e-3 for both) to compare the outputs of the model on different devices.
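
The dtype check from the tips above can be made a little more thorough than inspecting only the first parameter. The following is a minimal sketch, assuming a hypothetical helper named `dtype_mismatches`; it lists every floating-point parameter or buffer whose dtype differs from the input. Integer buffers such as `num_batches_tracked` are skipped because they are never expected to match the input dtype. It runs on CPU so it does not depend on MPS being available.

```python
import torch
import torch.nn as nn


def dtype_mismatches(module: nn.Module, x: torch.Tensor) -> list[str]:
    """Hypothetical helper: names of floating-point parameters and buffers
    in `module` whose dtype differs from the dtype of `x`."""
    tensors = list(module.named_parameters()) + list(module.named_buffers())
    return [
        name
        for name, t in tensors
        if t.is_floating_point() and t.dtype != x.dtype
    ]


bn = nn.BatchNorm2d(4)  # parameters and running stats default to float32
x_half = torch.rand(2, 4, 3, 3).half()
x_full = torch.rand(2, 4, 3, 3)

print(dtype_mismatches(bn, x_half))  # ['weight', 'bias', 'running_mean', 'running_var']
print(dtype_mismatches(bn, x_full))  # []
```

An empty list means the module and input agree on dtype; a non-empty list names the tensors that would need casting on builds affected by this issue.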
