pytorch - ✅ (Solved) Fix [DTensor] clip_grad_norm_ fails when per-parameter norms come from different DeviceMesh objects [1 pull request, 1 comment, 2 participants]

GitHub stats
pytorch/pytorch#180346 (fetched 2026-04-16 06:35:10)
Comments: 1
Participants: 2
Timeline: 30
Reactions: 0
Timeline (top): mentioned ×13, subscribed ×13, labeled ×2, commented ×1

Root Cause

The generic torch.stack path in _get_total_norm() fails because the per-parameter norm DTensors carry different mesh metadata.

PR fix notes

PR #180501: Fix clip_grad_norm_ failure with mixed DeviceMesh DTensor norms

Description (problem / solution / changelog)

Summary

Fixes #180346

_get_total_norm() computes per-parameter gradient norms and stacks them via torch.stack(). When gradients are DTensors, the resulting norms are also DTensors that retain mesh metadata. If those norms come from different DeviceMesh objects, torch.stack() fails because it cannot reconcile the different mesh metadata, even though the underlying tensors share the same device and dtype.

This PR materialises DTensor norms to plain tensors via full_tensor() before stacking, but only when the norms actually come from different meshes (or are a mix of DTensor and plain tensors). The common single-mesh path (FSDP, TP) is untouched and pays no extra cost.
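A minimal sketch of the conditional check described above, assuming the per-parameter norms are already computed and live on one device. `needs_materialization` and `stack_norms_mesh_safe` are hypothetical names, this is not the code from PR #180501, and deciding "different meshes" via DeviceMesh equality is an assumption made for illustration.

```python
import torch

try:  # public import path in recent PyTorch releases
    from torch.distributed.tensor import DTensor
except ImportError:  # older releases
    from torch.distributed._tensor import DTensor


def needs_materialization(norms: list[torch.Tensor]) -> bool:
    """Hypothetical helper: True if the norms cannot be stacked as-is."""
    meshes = [n.device_mesh for n in norms if isinstance(n, DTensor)]
    if not meshes:
        return False  # plain tensors only: nothing to reconcile
    if len(meshes) != len(norms):
        return True  # mix of DTensor and plain tensors
    first = meshes[0]
    # DTensors whose meshes do not compare equal cannot share one torch.stack
    return any(mesh != first for mesh in meshes[1:])


def stack_norms_mesh_safe(norms: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper: stack norms, materializing DTensors only when needed."""
    if needs_materialization(norms):
        # full_tensor() is a collective on each DTensor's mesh; every rank must call it
        norms = [n.full_tensor() if isinstance(n, DTensor) else n for n in norms]
    return torch.stack(norms)
```

When all norms are plain tensors, or all are DTensors on one mesh, the list reaches torch.stack unchanged, which is the property the PR description claims for the FSDP/TP path.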

Benchmark (256 gradients, single rank, CPU)

Case                    Before    After
Plain tensors           75 µs     85 µs
Single-mesh DTensors    250 µs    269 µs
Mixed-mesh DTensors     CRASH     261 µs

Test plan

  • New test test_clip_grad_norm_mixed_mesh in test/distributed/tensor/test_dtensor.py — creates parameters on two different DeviceMesh objects, manufactures gradients, and verifies clip_grad_norm_ returns the correct total norm
  • Verified plain tensor path is unaffected
  • Verified single-mesh DTensor path is unaffected
  • Benchmarked against naive fix (always full_tensor())

Changed files

  • test/distributed/tensor/test_dtensor.py (modified, +36/-0)
  • torch/nn/utils/clip_grad.py (modified, +34/-4)


🚀 The feature, motivation and pitch

torch.nn.utils.clip_grad._get_total_norm() groups tensors by (device, dtype) only, computes per-tensor norms, and then stacks those norms together. For DTensor inputs, the norms themselves remain DTensors and retain mesh metadata, even though they are ultimately scalars. If the norms come from different DeviceMesh objects, the final torch.stack(...) fails even though the tensors share the same device and dtype.
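To make the failing step concrete, here is a simplified plain-tensor sketch of the computation described above. `total_norm_sketch` is a hypothetical name and the library version differs in detail (for example, it can use fused per-group kernels); the point is that every per-tensor norm ends up in a single torch.stack call, which is where DTensor norms from different meshes would meet.

```python
from collections import defaultdict

import torch


def total_norm_sketch(grads: list[torch.Tensor], norm_type: float = 2.0) -> torch.Tensor:
    # Group by (device, dtype) only; a DTensor's mesh plays no part in the key.
    groups: dict[tuple[torch.device, torch.dtype], list[torch.Tensor]] = defaultdict(list)
    for g in grads:
        groups[(g.device, g.dtype)].append(g)

    # One norm per tensor (for DTensor inputs these norms would still be DTensors).
    norms = []
    for tensors in groups.values():
        norms.extend(torch.linalg.vector_norm(t, norm_type) for t in tensors)

    # Single stack over all per-tensor norms, then the norm of the norms.
    first_device = grads[0].device
    return torch.linalg.vector_norm(
        torch.stack([n.to(first_device) for n in norms]), norm_type
    )
```

For p-norms with finite p > 0 and for p = inf, the norm of the per-tensor norms equals the norm over all elements, so stacking is mathematically sound; the failure is purely about DTensor mesh metadata.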

Expected: clip_grad_norm_ handles mixed meshes correctly. I considered extracting DTensor scalar norms via item(), but that only seems safe when each norm is known to be a replicated scalar. For general DTensor layouts, item() reads the local value and can be numerically wrong, so it does not look like a sound generic fix.
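A single-process simulation of why reading only the local value can be wrong for a non-replicated layout. No DeviceMesh is involved: torch.chunk stands in for an even Shard(0) split across two ranks (an assumption made for illustration), and each shard's norm is what a rank would see if it called item() on its local piece.

```python
import torch

grad = torch.tensor([3.0, 4.0, 12.0, 0.0])

# Pretend rank 0 holds [3, 4] and rank 1 holds [12, 0], as under an even Shard(0) split.
shards = torch.chunk(grad, 2)
local_norms = [torch.linalg.vector_norm(s).item() for s in shards]

# The true norm needs information from every rank.
global_norm = torch.linalg.vector_norm(grad).item()

# Combining the per-rank norms (a cross-rank reduction) recovers it.
combined = torch.linalg.vector_norm(torch.tensor(local_norms)).item()

print(local_norms, global_norm, combined)
```

Rank 0 would report 5 and rank 1 would report 12, while the true norm is 13; only a reduction across ranks (which full_tensor() performs for such layouts) gives the right value.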

Actual: the generic torch.stack path fails because the norm DTensors carry different mesh metadata. This follows directly from the current _get_total_norm() implementation.

Alternatives

No response

Additional context

This seems distinct from the already reported performance issue in the same code path (https://github.com/pytorch/pytorch/issues/169445): this report is about a runtime failure with mixed meshes, not about stack complexity.

Related: https://github.com/pytorch/pytorch/issues/121020.

Reproduction

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate
from torch.distributed.device_mesh import DeviceMesh
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms

WORLD_SIZE = 2

# Runs the function on WORLD_SIZE threads, each initialized with its own rank.
@spawn_threads_and_init_comms(world_size=WORLD_SIZE)
def repro_mixed_mesh_clip_grad_norm_failure(world_size: int) -> None:
    rank = dist.get_rank()

    # Same ranks in a different order, so these are two distinct DeviceMesh objects.
    mesh_a = DeviceMesh("cpu", torch.arange(world_size))
    mesh_b = DeviceMesh("cpu", torch.arange(world_size - 1, -1, -1))

    placements = [Replicate()]

    # Stand-in gradients, one per mesh.
    local_a = torch.tensor([3.0 + rank], dtype=torch.float32)
    local_b = torch.tensor([4.0 + rank], dtype=torch.float32)

    dt_a = DTensor.from_local(local_a, mesh_a, placements)
    dt_b = DTensor.from_local(local_b, mesh_b, placements)

    # Mirrors _get_total_norm(): per-tensor norms (still DTensors), then one stack.
    norm_type = 2.0
    norms = [torch.linalg.vector_norm(g, norm_type) for g in [dt_a, dt_b]]

    first_device = dt_a.device
    # Fails here: torch.stack cannot combine DTensors from different meshes.
    total_norm = torch.linalg.vector_norm(
        torch.stack([norm.to(first_device) for norm in norms]),
        norm_type,
    )
    print(rank, total_norm)

if __name__ == "__main__":
    repro_mixed_mesh_clip_grad_norm_failure(WORLD_SIZE)

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @weifengpy

Extent analysis

TL;DR

The fix is to change _get_total_norm() so that DTensor norms coming from different meshes are materialized into plain tensors (via full_tensor()) before stacking, rather than extracted with item(), which is only numerically safe for replicated scalars.

Guidance

  • Modify _get_total_norm() to handle DTensor norms with different mesh metadata by extracting scalar norms without relying on item(), for example by materializing them with full_tensor() as PR #180501 does.
  • Alternatively, compute the total norm with a collective, for example torch.distributed.all_gather, to collect norms from all ranks and combine them.
  • Verify that the modified function correctly handles mixed meshes by testing it with the provided reproduction code.
  • Review related issues, such as https://github.com/pytorch/pytorch/issues/121020, to ensure that the fix does not introduce any regressions.
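The collective-based idea in the guidance can be sketched without DTensor. This swaps all_gather for a single all_reduce of p-th powers (MAX for the infinity norm), and it assumes every gradient element is owned by exactly one rank, as with fully sharded gradients; replicated gradients would be counted once per rank. `sharded_total_norm` is a hypothetical helper that falls back to the local value when no process group is initialized.

```python
import torch
import torch.distributed as dist


def sharded_total_norm(
    local_grads: list[torch.Tensor],
    norm_type: float = 2.0,
    empty_device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Hypothetical helper: total norm when each rank holds a disjoint slice of gradients.

    Assumes no element is replicated across ranks and all local grads share a device.
    `empty_device` is only used when this rank holds no gradients.
    """
    norm_type = float(norm_type)
    is_inf = norm_type == float("inf")
    if local_grads:
        local_norms = torch.stack(
            [torch.linalg.vector_norm(g, norm_type) for g in local_grads]
        )
    else:
        local_norms = torch.zeros(1, device=empty_device)

    # Combine locally: max for the inf-norm, otherwise the sum of p-th powers.
    value = local_norms.max() if is_inf else local_norms.pow(norm_type).sum()

    # Combine across ranks only when a process group is available.
    if dist.is_available() and dist.is_initialized():
        op = dist.ReduceOp.MAX if is_inf else dist.ReduceOp.SUM
        dist.all_reduce(value, op=op)

    return value if is_inf else value.pow(1.0 / norm_type)
```

With p = 2 this sums squared norms across ranks before the final square root, the same identity that makes stacking per-tensor norms valid in the single-process case.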

Example

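# Naive variant: materialize every DTensor norm with full_tensor() before stacking.
# PR #180501 benchmarked against this and only materializes when meshes differ.
# Reuses dt_a, dt_b, first_device and norm_type from the reproduction above.
from torch.distributed._tensor import DTensor

norms = [torch.linalg.vector_norm(g, norm_type) for g in [dt_a, dt_b]]
# full_tensor() is a collective on each norm's mesh, so every rank must reach it.
norms = [n.full_tensor() if isinstance(n, DTensor) else n for n in norms]
total_norm = torch.linalg.vector_norm(
    torch.stack([norm.to(first_device) for norm in norms]),
    norm_type,
)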

Notes

The reproduction code demonstrates the failure with mixed meshes. The fix in PR #180501 changes torch/nn/utils/clip_grad.py and adds test_clip_grad_norm_mixed_mesh, so the complete fix lives in PyTorch itself rather than in user code.

Recommendation

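Use a PyTorch build that includes PR #180501. On builds without it, a targeted workaround is to change (or reimplement) the total-norm computation so that DTensor norms from different meshes are materialized to plain tensors, for example with full_tensor(), before they are stacked.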
