pytorch - ✅(Solved) Fix [DTensor] aten.native_batch_norm_backward.default hangs in sharding propagation [1 pull request, 1 participant]

GitHub stats
pytorch/pytorch#182739 · Fetched 2026-05-07 03:30:27
Comments: 0
Participants: 1
Timeline: 57
Reactions: 0
Timeline (top): mentioned ×24, subscribed ×24, labeled ×6, assigned ×1

Fix Action

Fixed

PR fix notes

PR #182743: [DTensor] add DTensor sharding strategy for batch norm backward

Description (problem / solution / changelog)

Fixes #182739.

Register a single-dim strategy for aten.native_batch_norm_backward.default that enables channel-dim (Shard(1)) sharding. Batch norm computes its statistics and gradients separately for each channel, so each channel shard computes its own gradients independently and no cross-shard communication is needed.
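
The independence claim can be checked on a single process without DTensor. The sketch below is an illustration, not the test added in the PR: it runs plain F.batch_norm in training mode on a full tensor and on two channel chunks, then compares the gradients. The helper name bn_grads, the shapes, and the use of float64 on CPU are assumptions chosen for the example.

```python
import torch
import torch.nn.functional as F


def bn_grads(x, w, b, grad_out):
    """Return (grad_x, grad_w, grad_b) for training-mode batch norm."""
    # Fresh leaf tensors so each call has its own autograd graph.
    x = x.detach().clone().requires_grad_(True)
    w = w.detach().clone().requires_grad_(True)
    b = b.detach().clone().requires_grad_(True)
    # No running stats: batch statistics are used because training=True.
    out = F.batch_norm(x, None, None, w, b, training=True)
    out.backward(grad_out)
    return x.grad, w.grad, b.grad


torch.manual_seed(0)
N, C, H, W = 2, 4, 3, 3
x = torch.randn(N, C, H, W, dtype=torch.float64)
w = torch.randn(C, dtype=torch.float64)
b = torch.randn(C, dtype=torch.float64)
grad_out = torch.randn(N, C, H, W, dtype=torch.float64)

# Reference: gradients computed on the full tensor.
full = bn_grads(x, w, b, grad_out)

# "Sharded": split every operand along the channel dim and process each
# chunk in isolation, mimicking Shard(1) inputs with Shard(0) parameters.
num_shards = 2
pieces = [
    bn_grads(xs, ws, bs, gs)
    for xs, ws, bs, gs in zip(
        x.chunk(num_shards, dim=1),
        w.chunk(num_shards),
        b.chunk(num_shards),
        grad_out.chunk(num_shards, dim=1),
    )
]
sharded = (
    torch.cat([p[0] for p in pieces], dim=1),  # grad_x, concatenated on channels
    torch.cat([p[1] for p in pieces]),         # grad_w
    torch.cat([p[2] for p in pieces]),         # grad_b
)

matches = all(torch.allclose(f, s) for f, s in zip(full, sharded))
print(matches)
```

If the printed value is True, concatenating the per-chunk gradients reproduces the full gradients, which is the property that lets a Shard(1) strategy skip communication.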

Without this, DTensor's decomposition sharding propagator enters an infinite loop when resolving cudnn_batch_norm_backward with Shard(1) inputs, causing the backward pass to hang indefinitely.

The dispatch path is: cudnn_batch_norm_backward (no strategy, decomp fallback) → native_batch_norm_backward (no strategy, decomp fallback again) → infinite loop in to_meta(). This strategy intercepts at the native_batch_norm_backward level, short-circuiting the decomposition.

Authored with Claude.

cc @Skylion007 @ezyang

Changed files

  • test/distributed/tensor/test_math_ops.py (modified, +60/-1)
  • torch/distributed/tensor/_ops/_math_ops.py (modified, +26/-2)

Code Example

"""Minimal repro: DTensor batch norm backward hangs with channel-dim sharding.

Without a registered strategy for aten.native_batch_norm_backward.default,
the decomposition sharding propagator enters an infinite loop in to_meta()
when inputs are Shard(1).

Usage: torchrun --nproc_per_node=2 repro_batchnorm_hang.py
"""
import torch
import torch.nn.functional as F
from torch.distributed.tensor import Shard, distribute_tensor, init_device_mesh
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

C = dist.get_world_size() * 2
x = distribute_tensor(
    torch.randn(2, C, 3, 3, device="cuda").requires_grad_(True), mesh, [Shard(1)]
)
w = distribute_tensor(torch.randn(C, device="cuda").requires_grad_(True), mesh, [Shard(0)])
b = distribute_tensor(torch.randn(C, device="cuda").requires_grad_(True), mesh, [Shard(0)])
rm = distribute_tensor(torch.zeros(C, device="cuda"), mesh, [Shard(0)])
rv = distribute_tensor(torch.ones(C, device="cuda"), mesh, [Shard(0)])

out = F.batch_norm(x, rm, rv, w, b, training=True)  # forward works
out.sum().backward()  # hangs here
print(f"[rank {dist.get_rank()}] backward done", flush=True)

dist.destroy_process_group()
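
Because the failure mode is a hang rather than an exception, it can help to run the repro under a watchdog so that a stuck backward pass shows up as a timeout. The following is a minimal sketch, assuming a POSIX system; the helper name run_with_timeout and the timeout values are hypothetical and are not part of the issue or the PR.

```python
import os
import signal
import subprocess
import sys


def run_with_timeout(cmd, timeout_s):
    """Run cmd and return "ok", "failed", or "timeout"."""
    # start_new_session=True puts the child in its own process group, so
    # any worker processes it spawns can be killed together (POSIX only).
    proc = subprocess.Popen(cmd, start_new_session=True)
    try:
        returncode = proc.wait(timeout=timeout_s)
    except subprocess.TimeoutExpired:
        try:
            os.killpg(proc.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # the group exited between the timeout and the kill
        proc.wait()
        return "timeout"
    return "ok" if returncode == 0 else "failed"


# Example invocation for the repro above (needs GPUs and torchrun installed):
#   run_with_timeout(
#       ["torchrun", "--nproc_per_node=2", "repro_batchnorm_hang.py"],
#       timeout_s=120,
#   )

# Harmless demonstration that runs anywhere Python does.
print(run_with_timeout([sys.executable, "-c", "pass"], timeout_s=30))
```

A "timeout" result on an unpatched build and "ok" after the fix would be consistent with the behavior described in the issue, though the right timeout depends on the machine.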
I will put up a PR with the fix.

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @weifengpy @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx
