pytorch - DTensor: `F.nll_loss(weight=..., reduction='mean')` over a sharded batch returns wrong result



Summary

torch.nn.functional.nll_loss produces a numerically wrong loss when input is DTensor sharded on the batch (first) dimension, target and weight (class weight) are DTensor replicated, reduction='mean', and weight[target] summed per local shard is not equal across ranks (the normal case in production class-imbalance scenarios).

The drift can be several percent in relative terms (about 3.7% in the reproducer below: 0.0878 on an expected 2.4004), well past floating-point tolerance. It vanishes when weight=None, when reduction is 'sum' or 'none', or when weight[target] happens to sum to the same value on every rank.
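Since the report says reduction='sum' is unaffected, one possible interim workaround is to request the sum and divide by the weight total explicitly. The sketch below only checks the identity on plain tensors; `weighted_mean_nll` is a hypothetical helper name, it assumes every target is a valid class index (no ignore_index entries), and it has not been verified here on DTensor inputs.

```python
import torch
import torch.nn.functional as F


def weighted_mean_nll(log_probs, target, weight):
    """Weighted-mean NLL built from reduction='sum' plus an explicit denominator.

    Hypothetical helper; assumes no target equals ignore_index.
    """
    loss_sum = F.nll_loss(log_probs, target, weight=weight, reduction="sum")
    weight_sum = weight.index_select(0, target).sum()
    return loss_sum / weight_sum


torch.manual_seed(0)
log_probs = F.log_softmax(torch.randn(8, 4), dim=1)
target = torch.tensor([0, 3, 3, 3, 0, 1, 3, 2])
class_weight = torch.tensor([0.62, 2.0, 1.6, 1.39])

# On plain tensors the two forms agree up to float32 rounding.
reference = F.nll_loss(log_probs, target, weight=class_weight, reduction="mean")
workaround = weighted_mean_nll(log_probs, target, class_weight)
print(f"reference={reference.item():.6f} workaround={workaround.item():.6f}")
```

This mirrors the manual decomposition used as the negative control in the reproducer, so it relies on the same primitive reductions that the report found to be correct.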

Minimal reproducer

Single self-contained script that runs in one process via a fake process group plus LocalTensorMode. No cross_entropy, log_softmax, or linear runs on DTensors (log_softmax is applied only to the plain tensor that builds the input) -- just nll_loss plus a negative control built from primitive DTensor reductions.

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._local_tensor import LocalTensorMode
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

WORLD = 4
dist.init_process_group("fake", rank=0, world_size=WORLD)

with LocalTensorMode(frozenset(range(WORLD))):
    mesh = init_device_mesh("cpu", (WORLD,))
    torch.manual_seed(0)

    log_probs = F.log_softmax(torch.randn(8, 4), dim=1)
    target = torch.tensor([0, 3, 3, 3, 0, 1, 3, 2])
    # Non-uniform class weights so per-rank ``sum(weight[target])`` varies:
    #   rank 0: weight[target[0:2]] sums to 0.62 + 1.39 = 2.01
    #   rank 1:                            1.39 + 1.39 = 2.78
    #   rank 2:                            0.62 + 2.00 = 2.62
    #   rank 3:                            1.39 + 1.60 = 2.99
    class_weight = torch.tensor([0.62, 2.0, 1.6, 1.39])

    d_lp  = distribute_tensor(log_probs,    mesh, [Shard(dim=0)])
    d_tgt = distribute_tensor(target,       mesh, [Replicate()])
    d_w   = distribute_tensor(class_weight, mesh, [Replicate()])

    # --- The buggy path: F.nll_loss ------------------------------
    expected = F.nll_loss(
        log_probs, target, weight=class_weight, reduction="mean",
    )
    actual = F.nll_loss(d_lp, d_tgt, weight=d_w, reduction="mean")
    nll_diff = (actual.full_tensor() - expected).abs().item()
    print(f"[F.nll_loss]           expected={expected.item():.6f} "
          f"actual={actual.full_tensor().item():.6f} diff={nll_diff:.6e}")

    # --- Negative control: manual decomposition into DTensor primitives
    ref = -(class_weight.index_select(0, target)
            * log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
           ).sum() / class_weight.index_select(0, target).sum()
    act = -(d_w.index_select(0, d_tgt)
            * d_lp.gather(1, d_tgt.unsqueeze(1)).squeeze(1)
           ).sum() / d_w.index_select(0, d_tgt).sum()
    decomp_diff = (act.full_tensor() - ref).abs().item()
    print(f"[manual decomposition] expected={ref.item():.6f} "
          f"actual={act.full_tensor().item():.6f} diff={decomp_diff:.2e}")

    # The bug: F.nll_loss drifts ~8.8e-2, manual decomposition is ~1 ULP.
    assert decomp_diff < 1e-5, (
        f"primitive reductions also drift: {decomp_diff:.6e}"
    )
    assert nll_diff < 1e-5, f"DTensor F.nll_loss drift: {nll_diff:.6f}"

Output on main:

[F.nll_loss]           expected=2.400442 actual=2.488206 diff=8.776426e-02
[manual decomposition] expected=2.400442 actual=2.400442 diff=2.38e-07
AssertionError: DTensor F.nll_loss drift: 0.087764

Both expected values are identical (same log_probs / target / class_weight). The math diverges only on the DTensor side and only inside F.nll_loss.

Why this isolates the bug to F.nll_loss's strategy

The two actual computations evaluate the same mathematical formula:

-sum(w[t[n]] * log_probs[n, t[n]]) / sum(w[t[n]])

The manual decomposition gets there through DTensor's strategies for gather, index_select, mul, sum, and scalar division -- and it matches the single-tensor reference to within 1 ULP (2.4e-7). So none of those primitive strategies is responsible.

F.nll_loss lowers to aten::nll_loss_forward with reduction='mean' and a non-None weight. That op has its own DTensor strategy, and that strategy is the only difference between the matching and non-matching paths. It appears to divide by the local sum(weight[target]) on each rank, instead of all-reducing the numerator and the denominator separately and dividing once.
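To see why the order of division matters, here is a small pure-Python illustration. The targets and class weights are taken from the reproducer; the per-sample log-probabilities are made-up values, and the two "suspected" computations are only the hypotheses raised in this report, not a trace of what the DTensor strategy actually does.

```python
# Targets and class weights come from the reproducer.
target = [0, 3, 3, 3, 0, 1, 3, 2]
class_weight = [0.62, 2.0, 1.6, 1.39]
# Hypothetical values standing in for log_probs[n, t[n]].
picked_log_prob = [-1.0, -2.0, -0.5, -3.0, -1.5, -0.25, -2.5, -4.0]

WORLD = 4
PER_RANK = len(target) // WORLD
w = [class_weight[t] for t in target]  # w[t[n]] for each sample n


def weighted_loss_sum(lo, hi):
    """-sum(w[t[n]] * log_probs[n, t[n]]) over samples lo..hi-1."""
    return -sum(wi * lp for wi, lp in zip(w[lo:hi], picked_log_prob[lo:hi]))


# Correct: one global numerator divided by one global denominator.
global_mean = weighted_loss_sum(0, len(target)) / sum(w)

# Suspected behaviour A: each rank divides by its local weight sum,
# then the per-rank means are averaged across ranks.
local_means = [
    weighted_loss_sum(r * PER_RANK, (r + 1) * PER_RANK)
    / sum(w[r * PER_RANK:(r + 1) * PER_RANK])
    for r in range(WORLD)
]
avg_of_local_means = sum(local_means) / WORLD

# Suspected behaviour B: normalise by the global sample count.
sample_count_mean = weighted_loss_sum(0, len(target)) / len(target)

print(global_mean, avg_of_local_means, sample_count_mean)
```

With unequal per-rank weight sums, neither suspected computation reproduces the global weighted mean, which is consistent with the drift disappearing when the per-rank sums happen to be equal.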

Important: stay inside the with LocalTensorMode block

If the snippet above is split into two pastes in an interactive session, the AssertionError from the first paste exits the with block, and the second paste runs outside LocalTensorMode. Without LocalTensorMode active, DTensor dispatches differently and the "manual decomposition" path produces a wrong result too (e.g. 1.7089 instead of 2.4004) -- this is a property of running DTensor outside the LocalTensorMode simulation, not an additional bug. The combined-script form above avoids this trap.

Affected configurations

Reproduced indirectly (via F.cross_entropy -> F.linear_cross_entropy) in:

  • TestDTensorOpsCPU::test_dtensor_op_db_nn_functional_linear_cross_entropy_cpu_float32
  • TestLocalDTensorOpsCPU::test_dtensor_op_db_nn_functional_linear_cross_entropy_cpu_float32
  • TestMultiThreadedDTensorOpsCPU::test_dtensor_op_db_nn_functional_linear_cross_entropy_cpu_float32

TestCompiledDTensorOps passes -- the compiled lowering apparently decomposes the op into the primitive reductions, which work correctly.

The OpInfo for nn.functional.linear_cross_entropy carries an xfail in dtensor_fails and dtensor_numeric_only_fails referencing this issue (PR #184596). Removing those xfails will serve as the regression test once this is fixed.

Suggested fix direction

The DTensor strategy for aten::nll_loss_forward with reduction=Reduction.Mean and a non-None weight needs to keep the loss sum and the total weight sum as separate sharded accumulators, all-reduce both, then divide once. The current strategy either computes the per-rank mean and averages those means across ranks, or normalizes by the global number of samples instead of the global sum(weight[target]).

The same fix likely applies to aten::nll_loss_backward.
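The two-accumulator reduction described above can be sketched in plain Python. This is not DTensor strategy code: `local_partials`, `all_reduce_sum`, and `sharded_weighted_mean` are hypothetical names, ranks are simulated as list slices, and the log-probabilities are made-up values.

```python
def local_partials(weights, picked_log_probs):
    """Per-rank partial sums: (weighted loss sum, weight sum)."""
    loss_sum = -sum(w * lp for w, lp in zip(weights, picked_log_probs))
    return loss_sum, sum(weights)


def all_reduce_sum(per_rank_values):
    """Stand-in for a SUM all-reduce across ranks (simulated locally)."""
    return sum(per_rank_values)


def sharded_weighted_mean(shards):
    """Reduce numerator and denominator separately, then divide once."""
    partials = [local_partials(w, lp) for w, lp in shards]
    total_loss = all_reduce_sum(p[0] for p in partials)
    total_weight = all_reduce_sum(p[1] for p in partials)
    return total_loss / total_weight


# weight[target] from the reproducer; log-probabilities are hypothetical.
weights = [0.62, 1.39, 1.39, 1.39, 0.62, 2.0, 1.39, 1.6]
picked = [-1.0, -2.0, -0.5, -3.0, -1.5, -0.25, -2.5, -4.0]

reference = sharded_weighted_mean([(weights, picked)])  # one "rank"
even = sharded_weighted_mean(
    [(weights[i:i + 2], picked[i:i + 2]) for i in range(0, 8, 2)]
)
uneven = sharded_weighted_mean(
    [(weights[:3], picked[:3]), (weights[3:], picked[3:])]
)
print(reference, even, uneven)
```

Because the division happens only after both sums are global, the result does not depend on how the batch is split across ranks, up to floating-point summation order.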

Drafted with assistance from Claude.

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @weifengpy @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx
