pytorch - ✅ (Solved) Fix bad interaction with autograd layout invariant and DTensor leaf parameters when grad is transposed [1 pull request, 1 participant]

GitHub stats
pytorch/pytorch#180486 (fetched 2026-04-17 08:22:12)
Comments: 0
Participants: 1
Timeline: 47
Reactions: 0
Timeline (top): mentioned ×18, subscribed ×18, labeled ×4, referenced ×4

Root Cause

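The proximal cause of the problem is a bad interaction between autograd and DTensor. Concretely, PyTorch enforces that the memory layout of a gradient matches that of its parameter. Here the parameter is contiguous but the gradient is non-contiguous (transposed), so the layout invariant fails and autograd tries to lay the gradient out again. It does so by allocating a fresh tensor with new_empty_strided_symint, using the parameter's sizes and strides, and then calling copy_ from the gradient into it. For DTensor the fresh tensor is Replicate by default, so copying a Partial gradient into it triggers a redistribute.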
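To make the layout mismatch concrete, here is a minimal pure-Python sketch of the stride arithmetic for the (4, 8, 8) weight used in the reproduction script. The helper names contiguous_strides and transpose_view are hypothetical, and the plain equality check is a simplification: PyTorch's actual layout contract has more cases than this.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides, in elements, for a given shape."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def transpose_view(shape, strides, dim0, dim1):
    """Shape and strides of a transposed view: swap the entries, move no data."""
    shape, strides = list(shape), list(strides)
    shape[dim0], shape[dim1] = shape[dim1], shape[dim0]
    strides[dim0], strides[dim1] = strides[dim1], strides[dim0]
    return tuple(shape), tuple(strides)


# The parameter is a freshly allocated, contiguous (4, 8, 8) tensor.
param_shape = (4, 8, 8)
param_strides = contiguous_strides(param_shape)

# The backward of .transpose(1, 2) hands back a transposed view of a
# contiguous gradient, so the sizes agree but the strides do not.
grad_shape, grad_strides = transpose_view(param_shape, param_strides, 1, 2)

# Simplified stand-in for the layout check (assumption: plain equality).
needs_relayout = grad_strides != param_strides

print("param strides:", param_strides)
print("grad strides: ", grad_strides)
print("needs relayout:", needs_relayout)
```

When the strides differ like this, autograd takes the allocate-and-copy path, and that path is where the DTensor placement gets lost.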

Fix Action

Fixed

PR fix notes

PR #180511: Fix DTensor Partial placement lost during autograd layout invariant (issue #180486)

Description (problem / solution / changelog)

Summary
Addresses https://github.com/pytorch/pytorch/issues/180486. When a Replicate DTensor parameter produces a non-contiguous Partial gradient (e.g. via transpose), autograd's clone_obey_contract calls new_empty_strided followed by copy_ to fix the strides. new_empty_strided defaulted to Replicate placement, so copy_ triggered an unwanted all-reduce from Partial to Replicate.

The fix allows new_empty/new_empty_strided to inherit Partial placement from the input when shapes match, since the uninitialized memory is immediately overwritten. Other factories (new_zeros, new_ones, new_full) keep the existing Replicate default to avoid incorrect values after reduction.
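The rule described in the PR summary can be sketched as a toy model in plain Python. Everything here is an assumption made for illustration: the Replicate and Partial classes are stand-ins for the real placement types, output_placements is a hypothetical helper and not the code in _tensor_ops.py, and Shard placements are ignored entirely.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Replicate:
    """Toy stand-in: every rank holds the full value."""


@dataclass(frozen=True)
class Partial:
    """Toy stand-in: the true value is the reduction over all ranks."""
    reduce_op: str = "sum"


# Factories whose output is uninitialized and will be overwritten right away.
UNINITIALIZED_FACTORIES = {"new_empty", "new_empty_strided"}


def output_placements(factory, input_placements, input_shape, output_shape):
    """Hypothetical placement rule: keep Partial only when it is harmless."""
    may_inherit = (
        factory in UNINITIALIZED_FACTORIES
        and tuple(input_shape) == tuple(output_shape)
    )
    return tuple(
        p if (may_inherit and isinstance(p, Partial)) else Replicate()
        for p in input_placements
    )


shape = (4, 8, 8)
print(output_placements("new_empty_strided", (Partial(),), shape, shape))
print(output_placements("new_ones", (Partial(),), shape, shape))

# Why value-filling factories cannot inherit Partial: if each of 4 ranks
# held a local 1.0 marked as Partial(sum), the reduced value would be 4.0.
world_size = 4
reduced_ones = sum(1.0 for _ in range(world_size))
print("new_ones as Partial would reduce to:", reduced_ones)
```

The shape check matters because a Partial placement on a tensor of a different shape would say nothing meaningful about the pending reduction of the input.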

Test Cases

  1. pytest /data/users/anshulsi/pytorch/test/distributed/tensor/test_tensor_ops.py -k test_backward_partial_grad_with_transpose

Stack from ghstack (oldest at bottom):

  • -> #180511

Changed files

  • test/distributed/tensor/test_tensor_ops.py (modified, +52/-0)
  • torch/distributed/tensor/_ops/_tensor_ops.py (modified, +11/-1)

This example is distilled from an internal codebase that has a more DTensor-centric version of FSDP2.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Partial, Replicate


class CustomMatmul(torch.autograd.Function):
    """Custom autograd.Function that does the same thing as torch.mm."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return torch.mm(x, w)

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        return grad_output @ w.T, x.T @ grad_output


class TestModel(nn.Module):
    def __init__(self, use_transpose: bool, use_custom_fn: bool):
        super().__init__()
        # 3D weight to allow .transpose(1, 2)
        self.weight = nn.Parameter(torch.randn(4, 8, 8))
        self.use_transpose = use_transpose
        self.use_custom_fn = use_custom_fn
        self._grad_placement = None

    def forward(self, x):
        w = self.weight

        if isinstance(w, DTensor):
            grad_placements = [
                Partial() if isinstance(p, Replicate) else p for p in w.placements
            ]
            w = w.to_local(grad_placements=grad_placements)

            def _capture(param, model=self):
                if param.grad is not None and isinstance(param.grad, DTensor):
                    model._grad_placement = param.grad.placements

            self.weight.register_post_accumulate_grad_hook(_capture)

        if self.use_transpose:
            # .transpose() backward produces non-contiguous grad
            w = w.transpose(1, 2).contiguous()

        # Use first "expert" slice for matmul
        if self.use_custom_fn:
            return CustomMatmul.apply(x, w[0])
        else:
            return torch.mm(x, w[0])


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
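    # Select the GPU by local rank (torchrun sets LOCAL_RANK); fall back to the global rank
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank)))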

    # 1D mesh — parameters will be Replicate on this mesh
    mesh = init_device_mesh("cuda", (world_size,))

    torch.manual_seed(0)

    results = []
    for use_transpose in [False, True]:
        for use_custom_fn in [False, True]:
            t_label = "+transpose" if use_transpose else "           "
            fn_label = "custom_fn" if use_custom_fn else "native_mm"
            label = f"{fn_label} {t_label}"

            model = TestModel(
                use_transpose=use_transpose, use_custom_fn=use_custom_fn
            ).cuda()

            # Directly make weight a Replicate DTensor (no FSDP2)
            with torch.no_grad():
                model.weight = nn.Parameter(
                    DTensor.from_local(
                        model.weight.data,
                        device_mesh=mesh,
                        placements=[Replicate()],
                    )
                )

            x = torch.randn(2, 8, device="cuda", requires_grad=True)
            out = model(x)
            try:
                out.sum().backward()
                failed = False
            except AssertionError as e:
                failed = True
                if rank == 0:
                    print(f"  {label}: FAILED (assertion: {e})")

            p = model._grad_placement
            is_partial = (
                p is not None
                and all(isinstance(pl, Partial) for pl in p)
            )

            if rank == 0:
                if failed:
                    status = "FAILED (assertion)"
                elif is_partial:
                    status = f"ok   grad={p}"
                else:
                    status = f"BUG  grad={p}"
                print(f"  {label}: {status}")
                results.append((label, is_partial and not failed))

    if rank == 0:
        print()
        all_pass = all(ok for _, ok in results)
        if all_pass:
            print("All passed — bug not reproduced with upstream DTensor.")
        else:
            print("Bug reproduced with upstream DTensor:")
            print(".transpose() causes non-contiguous grad in backward,")
            print("which triggers Partial→Replicate redistribution in AccumulateGrad.")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

The proximal cause of the problem is a bad interaction of autograd and DTensor. Concretely, PyTorch enforces that the memory layout of a parameter matches that of its gradient. Because you have a contiguous parameter but a non-contiguous (transposed) gradient, the layout invariant fails, causing us to attempt to make the tensor contiguous. However, the autograd engine does this in an interesting way:

  return std::move(new_grad
      .new_empty_strided_symint(
          variable.sym_sizes(),
          variable.sym_strides(),
          ...)
      .copy_(new_grad));

This is awkward for DTensor: when new_empty_strided_symint is called we end up with a Replicate tensor by default, but the copy_ must then somehow transmute it into a Partial tensor. In the presence of aliasing, this isn't correct, and because DTensor doesn't know that it's safe here, it triggers a redistribute.
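A small simulation may help show why the destination's placement decides whether communication happens. This is a sketch under stated assumptions: ranks are modeled as plain Python lists, all_reduce_sum is a hypothetical helper, and no real collective or DTensor code is involved.

```python
from typing import List


def all_reduce_sum(per_rank: List[List[float]]) -> List[List[float]]:
    """Simulated all-reduce: every rank ends up with the elementwise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


# A Partial(sum) gradient on 3 ranks: each rank holds only its contribution,
# and the logical gradient is the sum across ranks.
partial_grad = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]
logical_value = [sum(vals) for vals in zip(*partial_grad)]

# Copying into a Replicate destination: every rank must hold the full value,
# so the copy has to perform the reduction (the unwanted redistribute).
replicated_dst = all_reduce_sum(partial_grad)

# Copying into a fresh Partial destination: a purely local copy, with the
# reduction still pending and free to happen later.
partial_dst = [list(local) for local in partial_grad]
pending_value = [sum(vals) for vals in zip(*partial_dst)]

print("replicated per rank:", replicated_dst)
print("partial per rank:   ", partial_dst)
print("same logical value: ", pending_value == logical_value)
```

Both destinations represent the same logical gradient. The difference is only when the reduction runs, which is why a fresh, unaliased destination can safely stay Partial.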

Stock FSDP2 doesn't seem to have this problem because the all-gather is into a plain Tensor, so the FSDP2 hooks get to take care of the grads before they reach a leaf tensor. The layout invariant therefore won't trigger a redistribute, just a contiguous() call (which is a slight inefficiency, but not a stinker like this one is).

@albanD's opinion was that DTensor should know that the fresh tensor has no aliases and then we can transmute it to Partial in this situation. Uhh... sure. I think any localized fix is probably OK here.

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @weifengpy @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx

extent analysis

TL;DR

The fix in PR #180511 is on the DTensor side rather than in the autograd engine: new_empty and new_empty_strided inherit Partial placement from the input when shapes match, so the copy_ that autograd performs while enforcing the layout invariant no longer triggers an all-reduce to Replicate.

Guidance

  • The redistribute originates where the autograd engine allocates a fresh gradient tensor with new_empty_strided_symint and then calls copy_ into it.
  • The change belongs in the DTensor implementation: a freshly allocated, uninitialized tensor has no aliases, so transmuting it to Partial is safe in this case.
  • Stock FSDP2 avoids the problem because it all-gathers into a plain Tensor and its hooks handle the grads before they reach a leaf tensor.
  • Test any change with the reproduction script and with test_backward_partial_grad_with_transpose in test/distributed/tensor/test_tensor_ops.py to guard against regressions.

Example

The reproduction script above exercises the bug end to end. The fix itself is internal to DTensor: per PR #180511 it changes how new_empty/new_empty_strided choose their output placement in torch/distributed/tensor/_ops/_tensor_ops.py, and the changed-files list does not touch the autograd engine.

Notes

The issue is specific to the interaction between the autograd engine and DTensor, and any fix would need to take into account the nuances of both components. The fact that FSDP2 does not exhibit the same issue suggests that there may be a way to modify the DTensor implementation to handle this case correctly.

Recommendation

Pick up the fix from PR #180511 rather than patching the autograd engine. It lets new_empty/new_empty_strided inherit Partial placement when shapes match, which is in line with @albanD's suggestion that DTensor should know the fresh tensor has no aliases and can be transmuted to Partial.
