pytorch - 💡(How to fix) Fix `torch.distributed.nn.functional.all_reduce` returns incorrect values for graph-connected 0-dim bf16 scalars [2 comments, 3 participants]

GitHub stats
pytorch/pytorch#178865 (fetched 2026-04-08 01:57:28)
Comments: 2 · Participants: 3 · Timeline events: 39 · Reactions: 0
Timeline (top): mentioned ×15, subscribed ×15, labeled ×5, commented ×2


🐛 Describe the bug

Summary

We found that torch.distributed.nn.functional.all_reduce(..., op=SUM) can return an incorrect forward value when the input is a graph-connected 0-dim bf16 scalar, especially when one rank contributes a graph-connected 0.0 and other ranks contribute non-zero scalars.

This is not just normal bf16 rounding. In our reproduction:

  • dist_nn.all_reduce(loss) is wrong on rank0
  • dist_nn.all_reduce(loss.float().reshape(1)) is correct on all ranks
  • dist.all_reduce(loss.detach().float()) is also correct on all ranks
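For scale, here is what ordinary bf16 rounding can do. bf16 keeps 8 significant bits, so round-to-nearest moves a value by at most 2^-8 (about 0.4%) relative. The stdlib-only sketch below emulates that rounding by cutting a float32 bit pattern down to its upper 16 bits with ties-to-even. It is an illustration, not PyTorch's own conversion routine, and it skips NaN/Inf handling. Applied to the per-rank loss and the expected global sum from the reported output, the rounding error stays in the third significant digit, nowhere near a value like 1.45e29.

```python
import struct

def round_to_bf16(x: float) -> float:
    """Round x to the nearest bf16 value (ties to even), returned as a Python float.

    x is first packed as float32; bf16 then keeps the sign, the 8 exponent bits
    and the top 7 mantissa bits, so rounding happens at bit 16 of the pattern.
    NaN/Inf and overflow handling are deliberately left out of this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                       # lowest kept bit, for ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000    # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# Per-rank loss on ranks 1-3 and the expected global sum, taken from the report.
for value in (36.40619659423828, 109.21858978271484):
    rounded = round_to_bf16(value)
    rel_err = abs(rounded - value) / abs(value)
    print(f"{value} -> {rounded} (relative error {rel_err:.1e}, bound {2**-8:.1e})")
```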

Environment

  • PyTorch: 2.9.1+cu129 (full details under Versions below)
  • Backend: NCCL
  • GPUs: 4 used by the reproduction (the machine has 8× NVIDIA A100-SXM4-80GB)
  • dtype: torch.bfloat16
  • API: torch.distributed.nn.functional.all_reduce

Reliable reproduction

This reproduction does not depend on any external project code.

import os
import socket
import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn
import torch.multiprocessing as mp

WORLD_SIZE = 4
DTYPE = torch.bfloat16
DIM = 32


def find_free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def worker(rank: int, world_size: int, port: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_RANK"] = str(rank)

    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    try:
        torch.manual_seed(1234)
        h = torch.randn(8, DIM, device=f"cuda:{rank}", dtype=DTYPE, requires_grad=True)
        w = torch.randn(DIM, DIM, device=f"cuda:{rank}", dtype=DTYPE, requires_grad=True)
        dist.broadcast(h.detach(), src=0)
        dist.broadcast(w.detach(), src=0)

        # All ranks feed 0-dim scalars into dist_nn.all_reduce.
        # rank0 contributes a graph-connected zero scalar.
        if rank == 0:
            loss = h.sum() * 0.0
        else:
            y = h @ w
            loss = y.float().square().mean()

        bad = dist_nn.all_reduce(loss, op=dist.ReduceOp.SUM)
        ok = dist_nn.all_reduce(loss.float().reshape(1), op=dist.ReduceOp.SUM).reshape(())

        ref = loss.detach().float().clone()
        dist.all_reduce(ref, op=dist.ReduceOp.SUM)

        gathered = [None] * world_size if rank == 0 else None
        dist.gather_object(
            {
                "rank": rank,
                "loss": float(loss.detach().float().cpu()),
                "dist_nn_bad": float(bad.detach().float().cpu()),
                "dist_nn_ok": float(ok.detach().float().cpu()),
                "dist_ref": float(ref.detach().float().cpu()),
            },
            gathered,
            dst=0,
        )
        if rank == 0:
            print(gathered)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    port = find_free_port()
    mp.spawn(worker, args=(WORLD_SIZE, port), nprocs=WORLD_SIZE, join=True)

Run with:

python test.py

Observed output

A representative output is:

[
  {
    'rank': 0,
    'loss': 0.0,
    'dist_nn_bad': 1.4545795461603218e+29,
    'dist_nn_ok': 109.21858978271484,
    'dist_ref': 109.21858978271484,
  },
  {
    'rank': 1,
    'loss': 36.40619659423828,
    'dist_nn_bad': 109.21858978271484,
    'dist_nn_ok': 109.21858978271484,
    'dist_ref': 109.21858978271484,
  },
  {
    'rank': 2,
    'loss': 36.40619659423828,
    'dist_nn_bad': 109.21858978271484,
    'dist_nn_ok': 109.21858978271484,
    'dist_ref': 109.21858978271484,
  },
  {
    'rank': 3,
    'loss': 36.40619659423828,
    'dist_nn_bad': 109.21858978271484,
    'dist_nn_ok': 109.21858978271484,
    'dist_ref': 109.21858978271484,
  },
]

We also reran the same script multiple times locally. The bad value on rank0 was reproduced consistently, while dist_nn_ok and dist_ref stayed correct.
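The corrupted rank0 value does not look random. Assume, as a hypothesis to check rather than a confirmed mechanism, that rank0's 2-byte bf16 result buffer received the first two bytes of the 4-byte float32 sum that the other ranks computed. On a little-endian machine those are the low 16 bits of the float32 bit pattern, and reading them as bf16 reproduces dist_nn_bad exactly. The stdlib-only sketch below performs that decoding; the helper names are illustrative.

```python
import struct

def float32_bits(x: float) -> int:
    """IEEE-754 float32 bit pattern of x as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bf16_bits_to_float(bits16: int) -> float:
    """Read a 16-bit pattern as bf16 (bf16 is the upper half of a float32)."""
    return struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))[0]

correct_sum = 109.21858978271484       # dist_ref / dist_nn_ok on every rank
bits = float32_bits(correct_sum)
low_half = bits & 0xFFFF                # first two bytes in little-endian memory
high_half = bits >> 16                  # last two bytes: the sum truncated to bf16

print(hex(bits))                        # 0x42da6feb
print(bf16_bits_to_float(low_half))     # 1.4545795461603218e+29, rank0's dist_nn_bad
print(bf16_bits_to_float(high_half))    # 109.0
```

This is consistent with the reduction mixing a bf16 input on rank0 with float32 inputs on the other ranks (in the script, only rank0's loss skips the .float() call), which would also explain why every .float() variant comes out correct. It does not by itself account for the reported loss.reshape(1) result, so treat it as a lead to confirm rather than a diagnosis.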

Expected behavior:

  • dist_nn_bad, dist_nn_ok, and dist_ref should all be equal on all ranks.

Actual behavior:

  • dist_nn_bad is corrupted on rank0.
  • dist_nn_ok and dist_ref are correct on all ranks.

Trigger conditions

The bug appears to require this combination:

  • input to dist_nn.all_reduce is a 0-dim scalar
  • scalar is graph-connected
  • dtype is bf16
  • at least one rank contributes a graph-connected zero scalar
  • other ranks contribute non-zero scalars

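All ranks in the reproduction above pass 0-dim scalars into dist_nn.all_reduce, and the only special value is rank0's graph-connected 0.0. The scalars do not share a dtype, however: rank0's loss (h.sum() * 0.0) stays bf16, while the other ranks' loss (y.float().square().mean()) is float32.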
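Collectives such as all_reduce assume that every rank passes a tensor with the same dtype and shape, so a cheap check before blaming the collective itself is to compare that metadata across ranks. The sketch below is a hypothetical helper that works on metadata already collected into one process; it is plain Python so it can be tried without GPUs, and its sample input mirrors the metadata of loss in the reproduction (rank0 bf16, ranks 1-3 float32).

```python
from typing import List, Sequence, Tuple

# (dtype, shape) as reported by one rank, e.g. (str(t.dtype), tuple(t.shape)).
RankMeta = Tuple[str, Tuple[int, ...]]

def mismatched_ranks(metas: Sequence[RankMeta]) -> List[int]:
    """Return the ranks whose (dtype, shape) differs from rank 0's."""
    reference = metas[0]
    return [rank for rank, meta in enumerate(metas) if meta != reference]

# Metadata of `loss` in the reproduction: rank0 computes h.sum() * 0.0 (bf16),
# ranks 1-3 compute y.float().square().mean() (float32); all are 0-dim.
repro = [("torch.bfloat16", ())] + [("torch.float32", ())] * 3
print(mismatched_ranks(repro))  # [1, 2, 3]
```

In a live job the list could be filled with metas = [None] * dist.get_world_size() followed by dist.all_gather_object(metas, (str(loss.dtype), tuple(loss.shape))) just before the collective. That adds an extra object collective per step, so it is better suited to debugging than to a production training loop.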

Additional narrowing

In the same context:

  • dist_nn.all_reduce(loss) -> wrong
  • dist_nn.all_reduce(loss.float()) -> correct
  • dist_nn.all_reduce(loss.reshape(1)) -> correct
  • dist_nn.all_reduce(loss.float().reshape(1)) -> correct
  • dist.all_reduce(loss.detach().float()) -> correct

This suggests a bug related to 0-dim scalar handling in dist_nn.all_reduce, likely combined with bf16 and autograd metadata.

Workaround

The following workaround fixes the issue for us without detaching the graph:

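import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn

# Cast to float32 and give the scalar a 1-element shape for the collective,
# then restore the 0-dim shape and the original dtype; every step is differentiable.
global_loss = dist_nn.all_reduce(
    loss.float().reshape(1),
    op=dist.ReduceOp.SUM,
).reshape(())
global_loss = global_loss.to(loss.dtype)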

Why this matters

This breaks a legitimate distributed training pattern:

  • each rank computes a scalar loss contribution
  • one rank may legitimately contribute a graph-connected zero scalar
  • the global loss is formed by dist_nn.all_reduce(..., SUM) before backward

When the all-reduced scalar is wrong on one rank, the forward loss already diverges before backward.

Versions

PyTorch version: 2.9.1+cu129
Is debug build: False
CUDA used to build PyTorch: 12.9
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 4.2.1
Libc version: glibc-2.39

Python version: 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-4.18.0-193.6.3.el8_2.v1.4.x86_64-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.10.2
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8352Y CPU @ 2.20GHz
CPU family: 6
Model: 106
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 6
CPU(s) scaling MHz: 65%
CPU max MHz: 3400.0000
CPU min MHz: 800.0000
BogoMIPS: 4400.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid md_clear pconfig flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 3 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 80 MiB (64 instances)
L3 cache: 96 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy_extensions==1.1.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.9.1.4
[pip3] nvidia-cuda-cupti-cu12==12.9.79
[pip3] nvidia-cuda-nvrtc-cu12==12.9.86
[pip3] nvidia-cuda-runtime-cu12==12.9.79
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.16.0
[pip3] nvidia-cufft-cu12==11.4.1.4
[pip3] nvidia-curand-cu12==10.3.10.19
[pip3] nvidia-cusolver-cu12==11.7.5.82
[pip3] nvidia-cusparse-cu12==12.5.10.65
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.9.86
[pip3] nvidia-nvtx-cu12==12.9.79
[pip3] onnx==1.20.0
[pip3] onnx-ir==0.1.13
[pip3] onnxscript==0.5.7
[pip3] torch==2.9.1+cu129
[pip3] torch_memory_saver==0.0.9
[pip3] torch-tb-profiler==0.4.3
[pip3] torchao==0.9.0
[pip3] torchaudio==2.9.1+cu129
[pip3] torchcodec==0.7.0+cu129
[pip3] torchprofile==0.0.4
[pip3] torchvision==0.24.1+cu129
[pip3] transformer_engine_torch==2.10.0
[pip3] triton==3.5.1
[conda] Could not collect

cc @ezyang @gchanan @kadeng @msaroufim @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @dcci @aditvenk @xmfan

extent analysis

TL;DR

The reporter's workaround avoids the problem: convert the loss to float32, reshape it to a 1-element tensor, run the all-reduce, then reshape the result back to a 0-dim scalar and cast it back to the original dtype.

Guidance

  • The bug appears to be related to the handling of 0-dim scalars in dist_nn.all_reduce when using bf16 and autograd metadata.
  • The provided workaround fixes the issue without detaching the graph: global_loss = dist_nn.all_reduce(loss.float().reshape(1), op=dist.ReduceOp.SUM).reshape(()).to(loss.dtype).
  • To verify the workaround, run the reproduction script and check that dist_nn_ok (which already applies the workaround) equals dist_ref on all ranks; dist_nn_bad still exercises the failing path and is expected to differ on rank0.
  • The trigger conditions for the bug are:
    • Input to dist_nn.all_reduce is a 0-dim scalar.
    • Scalar is graph-connected.
    • Dtype is bf16.
    • At least one rank contributes a graph-connected zero scalar.
    • Other ranks contribute non-zero scalars.
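The verification check in the guidance can be scripted against the gathered list that rank0 prints at the end of the reproduction. A minimal sketch, assuming that list is available as Python dicts; ranks_disagreeing is a hypothetical helper, and the 1e-6 relative tolerance is an arbitrary choice that is loose enough for float32 round-off yet far tighter than the corruption being looked for.

```python
import math
from typing import Any, Dict, List, Sequence

def ranks_disagreeing(
    gathered: Sequence[Dict[str, Any]],
    key: str,
    ref_key: str = "dist_ref",
    rel_tol: float = 1e-6,
) -> List[int]:
    """Return the ranks whose gathered[key] differs from gathered[ref_key]."""
    return [
        row["rank"]
        for row in gathered
        if not math.isclose(row[key], row[ref_key], rel_tol=rel_tol)
    ]

# Values copied from the representative output in the report ("loss" omitted).
expected = 109.21858978271484
observed = [
    {"rank": 0, "dist_nn_bad": 1.4545795461603218e29,
     "dist_nn_ok": expected, "dist_ref": expected},
] + [
    {"rank": r, "dist_nn_bad": expected, "dist_nn_ok": expected, "dist_ref": expected}
    for r in (1, 2, 3)
]

print(ranks_disagreeing(observed, "dist_nn_bad"))  # [0]
print(ranks_disagreeing(observed, "dist_nn_ok"))   # []
```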

Example

global_loss = dist_nn.all_reduce(
    loss.float().reshape(1),
    op=dist.ReduceOp.SUM,
).reshape(()).to(loss.dtype)

Notes

  • The bug was observed on PyTorch 2.9.1+cu129 with the NCCL backend; the report does not say whether other versions are affected.
  • The workaround adds a dtype cast and two reshapes around the collective; for a single scalar loss this overhead should be negligible.

Recommendation

Apply the provided workaround. The reporter confirms that it produces the correct results in the reproduction script.
