pytorch - ✅(Solved) Fix [MPS] MPSGraph normalization is significantly slower for rank-5 tensors [1 pull request, 1 participant]

GitHub stats
pytorch/pytorch#180334 (fetched 2026-04-16 06:35:24)
Comments: 0 · Participants: 1 · Timeline events: 68 · Reactions: 0
Timeline (top): mentioned ×29, subscribed ×29, labeled ×6, referenced ×3

Fix Action: Fixed

PR fix notes

PR #180335: [MPS] Flatten 5D tensors to 4D in batch_norm for performance

Description (problem / solution / changelog)

Summary

MPSGraph normalization ops (normalizationWithTensor:, normalizationGradientWithIncomingGradientTensor:) are significantly slower for rank-5 tensors (BatchNorm3d) compared to rank-4 (BatchNorm2d) with the same number of elements.

Since BatchNorm reduces over all dimensions except the channel dim, spatial dimensions can be safely merged: [N, C, D, H, W] → [N, C, D*H, W]. This PR adds a 5D→4D reshape at the top of batch_norm_mps_out and batch_norm_backward_mps, then recurses into the existing 4D path.
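The merge is safe because each channel's mean and variance are computed over the same set of elements whether they are laid out as D × H × W or (D*H) × W. The following is a minimal CPU sketch with a small, arbitrary shape (it does not exercise the PR's MPS code path or its tests) that checks this using training-mode F.batch_norm without running statistics:

```python
import torch
import torch.nn.functional as F

# Small illustrative shape; the PR benchmarks [4, 64, 64, 64, 64] on MPS.
N, C, D, H, W = 2, 3, 4, 5, 6
x5 = torch.randn(N, C, D, H, W, dtype=torch.float64)

# Merge D and H: [N, C, D, H, W] -> [N, C, D*H, W]
x4 = x5.reshape(N, C, D * H, W)

# BatchNorm statistics reduce over every dim except the channel dim (dim 1).
mean5 = x5.mean(dim=(0, 2, 3, 4))
var5 = x5.var(dim=(0, 2, 3, 4), correction=0)
mean4 = x4.mean(dim=(0, 2, 3))
var4 = x4.var(dim=(0, 2, 3), correction=0)

# Training-mode batch norm on the 4D view, reshaped back to 5D for comparison.
out5 = F.batch_norm(x5, None, None, training=True)
out4 = F.batch_norm(x4, None, None, training=True).reshape(N, C, D, H, W)

print("stats match: ", torch.allclose(mean5, mean4), torch.allclose(var5, var4))
print("output match:", torch.allclose(out5, out4))
```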

Benchmark (M4 Pro, shape [4, 64, 64, 64, 64])

Operation                       Before    After
nn.BatchNorm3d fwd+bwd          8.7 ms    3.5 ms (2.4x)
native_batch_norm (forward)     4.5 ms    4.5 ms (parity with 4D)
native_batch_norm_backward      6.7 ms    6.8 ms (parity with 4D)

Test plan

All 5 existing MPS batch_norm tests pass:

test/test_mps.py::TestMPS::test_batch_norm PASSED
test/test_mps.py::TestMPS::test_batch_norm_backward PASSED
test/test_mps.py::TestMPS::test_batch_norm_backward_weight_bias_gradients PASSED
test/test_mps.py::TestMPS::test_batch_norm_mixed_dtype PASSED
test/test_mps.py::TestMPS::test_batch_norm_slices PASSED

Correctness: max diff between 5D and equivalent 4D is 2.38e-07 (on the order of float32 machine epsilon).
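As a sanity check on that figure: float32 machine epsilon is 2^-23 ≈ 1.19e-07, so the reported 2.38e-07 is about 2^-22, or twice epsilon. That is consistent with rounding differences rather than a real divergence. The arithmetic, using torch.finfo:

```python
import torch

# float32 machine epsilon: gap between 1.0 and the next representable float32.
eps32 = torch.finfo(torch.float32).eps
reported_max_diff = 2.38e-07  # value quoted in the PR description

print(f"float32 eps    = {eps32:.3e}")
print(f"2**-22         = {2**-22:.3e}")
print(f"reported / eps = {reported_max_diff / eps32:.2f}")
```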

Fixes #180334

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • aten/src/ATen/native/mps/operations/Normalization.mm (modified, +38/-0)

Issue description

MPSGraph's normalization ops (normalizationWithTensor:, normalizationGradientWithIncomingGradientTensor:) are significantly slower for rank-5 (BatchNorm3d) inputs compared to rank-4 (BatchNorm2d) with the same number of elements.

Since BatchNorm reduces over all dimensions except the channel dim, spatial dimensions can be safely merged: [N, C, D, H, W] → [N, C, D*H, W].

Benchmark (M4 Pro, PyTorch main @ 0b9170d)

native_batch_norm forward (shape: [4, 64, 64, 64, 64]):
  5D [N,C,D,H,W]:     4.5 ms
  4D [N,C,D*H,W]:     4.5 ms  ← after flatten

native_batch_norm_backward:
  5D [N,C,D,H,W]:     6.7 ms
  4D [N,C,D*H,W]:     6.8 ms  ← after flatten

nn.BatchNorm3d fwd+bwd (before fix):  8.7 ms
nn.BatchNorm2d fwd+bwd (equivalent):  3.5 ms  ← 2.4x faster

Reproducer

import time
import torch
import torch.nn as nn

assert torch.backends.mps.is_available()
device = torch.device("mps")

N, C, D, H, W = 4, 64, 64, 64, 64
steps = 20

def bench(label, fn):
    # Warm-up so one-time setup (such as MPSGraph construction) is not timed.
    for _ in range(5):
        fn()
    # MPS work is asynchronous: synchronize before and after the timed loop.
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(steps):
        fn()
    torch.mps.synchronize()
    ms = (time.perf_counter() - t0) / steps * 1000
    print(f"{label}: {ms:.1f} ms")
    return ms

# Both modules start from identical parameters and running statistics.
bn3d = nn.BatchNorm3d(C).to(device).train()
bn2d = nn.BatchNorm2d(C).to(device).train()
bn2d.load_state_dict(bn3d.state_dict())

# Same element count; x4 has D and H merged into one dimension.
x5 = torch.randn(N, C, D, H, W, device=device, requires_grad=True)
x4 = torch.randn(N, C, D*H, W, device=device, requires_grad=True)

t5 = bench("BatchNorm3d (5D) fwd+bwd", lambda: bn3d(x5).sum().backward())
t4 = bench("BatchNorm2d (4D) fwd+bwd", lambda: bn2d(x4).sum().backward())
print(f"Ratio: {t5/t4:.1f}x")

cc @jerryzh168 @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @kulinseth @malfet @DenisVieriu97 @jhavukainen @aditvenk

extent analysis

TL;DR

Flatten the 5D input tensor to 4D by merging spatial dimensions to improve the performance of MPSGraph's normalization ops for BatchNorm3d.

Guidance

  • The performance difference between BatchNorm3d and BatchNorm2d can be mitigated by flattening the 5D input tensor to 4D, as shown in the benchmark results.
  • To verify the fix, run the reproducer and compare the BatchNorm3d and BatchNorm2d timings; on a build that includes PR #180335, the printed ratio should drop from about 2.4x toward 1.0x.
  • The flattening reshapes the input from [N, C, D, H, W] to [N, C, D*H, W]. This is safe because BatchNorm computes its statistics per channel over all other dimensions, so merging D and H leaves every channel's mean and variance unchanged while letting MPSGraph take its faster rank-4 path.
  • In the issue's benchmark (M4 Pro, shape [4, 64, 64, 64, 64]), this brings nn.BatchNorm3d fwd+bwd from 8.7 ms to 3.5 ms, a 2.4x speedup.
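On PyTorch builds that predate the fix, the same reshape can be applied at the module level. The sketch below defines a hypothetical FlattenedBatchNorm3d (not part of PyTorch) that subclasses nn.BatchNorm3d and reshapes to 4D before delegating to the inherited forward, so running-statistics updates and train/eval behavior stay unchanged. It assumes that overriding the private _check_input_dim hook, which nn.BatchNorm3d uses to reject non-5D input, still works in your PyTorch version. The check at the end runs on CPU against a stock nn.BatchNorm3d.

```python
import torch
import torch.nn as nn


class FlattenedBatchNorm3d(nn.BatchNorm3d):
    """Hypothetical workaround: nn.BatchNorm3d that normalizes a 4D view.

    Merges D and H ([N, C, D, H, W] -> [N, C, D*H, W]) before calling the
    inherited _BatchNorm.forward, then restores the 5D shape.
    """

    def _check_input_dim(self, input):
        # Assumption: private hook called by _BatchNorm.forward; here it
        # receives the already-merged 4D tensor.
        if input.dim() != 4:
            raise ValueError(f"expected 4D input after merging (got {input.dim()}D)")

    def forward(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")
        n, c, d, h, w = input.shape
        out = super().forward(input.reshape(n, c, d * h, w))
        return out.reshape(n, c, d, h, w)


# Compare against the stock module on CPU with identical parameters.
torch.manual_seed(0)
ref = nn.BatchNorm3d(8)
flat = FlattenedBatchNorm3d(8)
flat.load_state_dict(ref.state_dict())

x_ref = torch.randn(2, 8, 4, 5, 6, requires_grad=True)
x_flat = x_ref.detach().clone().requires_grad_()
grad_out = torch.randn(2, 8, 4, 5, 6)

y_ref = ref(x_ref)
y_flat = flat(x_flat)
# Weighted sum so the input gradient is not trivially zero.
(y_ref * grad_out).sum().backward()
(y_flat * grad_out).sum().backward()

print("train output match:", torch.allclose(y_ref, y_flat, atol=1e-5))
print("input grad match:  ", torch.allclose(x_ref.grad, x_flat.grad, atol=1e-5))
print("running_mean match:", torch.allclose(ref.running_mean, flat.running_mean, atol=1e-6))

# Eval mode uses the running statistics updated by the step above.
ref.eval()
flat.eval()
with torch.no_grad():
    x_eval = torch.randn(2, 8, 4, 5, 6)
    print("eval output match: ", torch.allclose(ref(x_eval), flat(x_eval), atol=1e-5))
```

Delegating to super().forward, rather than calling F.batch_norm directly, avoids re-implementing the momentum and num_batches_tracked bookkeeping.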

Example

import torch
N, C, D, H, W = 4, 64, 64, 64, 64  # shape from the issue
x5 = torch.randn(N, C, D, H, W, device="mps", requires_grad=True)
x4 = x5.reshape(N, C, D*H, W)  # flatten the 5D tensor to 4D (merge D and H)
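A CPU-only sketch with a small illustrative shape checks two properties this reshape relies on. First, for a contiguous input, merging D and H returns a view that shares storage with the 5D tensor, so no copy is made. Second, gradients computed on the 4D view flow back to the 5D leaf with its original shape.

```python
import torch

N, C, D, H, W = 2, 3, 4, 5, 6  # small illustrative shape
x5 = torch.randn(N, C, D, H, W, requires_grad=True)
x4 = x5.reshape(N, C, D * H, W)  # merge D and H

# Contiguous input: the merge is a zero-copy view over the same storage.
print("shares storage:", x4.data_ptr() == x5.data_ptr())
print("4D shape:", tuple(x4.shape))

# Backward through the view lands on the original 5D tensor.
x4.pow(2).sum().backward()
print("grad shape:", tuple(x5.grad.shape))
```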

Notes

The provided benchmark results and reproducer code suggest that the performance issue is specific to the MPSGraph implementation and can be addressed by flattening the input tensor. However, the underlying cause of the performance difference between BatchNorm3d and BatchNorm2d is not explicitly stated in the issue.

Recommendation

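Builds that include PR #180335 apply this reshape inside the MPS backend (batch_norm_mps_out and batch_norm_backward_mps), so user code needs no change. On earlier builds, flatten the 5D input tensor to 4D manually; the benchmark shows about a 2.4x speedup for BatchNorm3d fwd+bwd.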
