pytorch - ✅(Solved) Fix [DTensor] sharded view incorrectly passes when redistribution is needed [1 pull requests, 1 comments, 2 participants]

GitHub stats
pytorch/pytorch#179502 (fetched 2026-04-08 03:00:43)
Comments: 1
Participants: 2
Timeline events: 65
Reactions: 1
Timeline (top): mentioned ×27, subscribed ×27, labeled ×7, commented ×1

Fix Action

Fixed

PR fix notes

PR #179509: [DTensor] Fix Split(Flatten) sharding propagation for non-first flatten dims

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #179509

fix: https://github.com/pytorch/pytorch/issues/179502

Split(Flatten(...)) rules previously produced silently wrong results. After the fix:

  • dt[4, 4, 4] → dt.view(8, 8) with [Shard(1), Replicate()] on mesh (2, 3): returns (_StridedShard(dim=0, sf=4), Replicate()) with zero communication
  • dt[12, 8] → dt.view(16, 6) with Shard(1) on mesh (6,): raises RuntimeError("output dimension 0 (size 16) is not evenly divisible by mesh dimension 0 (size 6)")
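
The first case can be checked on a single process with plain torch, without DTensor. The sketch below assumes that _StridedShard(dim=0, sf=4) over a mesh dimension of size 2 means "split dim 0 into 4 blocks, then shard each block across the 2 ranks", so rank r holds rows r, r+2, r+4, r+6 of the (8, 8) view. That reading and the helper name local_view_of_dim1_shard are illustrative assumptions, not taken from the PR.

```python
import torch

MESH_SIZE = 2  # size of the mesh dimension that carries Shard(1)


def local_view_of_dim1_shard(full: torch.Tensor, rank: int) -> torch.Tensor:
    """Hypothetical helper: a rank's Shard(1) chunk of a [4, 4, 4] tensor, viewed locally as (4, 8)."""
    local = full.chunk(MESH_SIZE, dim=1)[rank].contiguous()  # shape (4, 2, 4)
    return local.view(4, 8)  # purely local, no data moves between ranks


full = torch.arange(64).view(4, 4, 4)
global_view = full.view(8, 8)

# Each rank's local view should equal every second row of the global view.
strided_ok = all(
    torch.equal(local_view_of_dim1_shard(full, r), global_view[r::MESH_SIZE])
    for r in range(MESH_SIZE)
)
# A plain Shard(0) label would instead claim contiguous row blocks.
plain_shard0_ok = all(
    torch.equal(
        local_view_of_dim1_shard(full, r),
        global_view.chunk(MESH_SIZE, dim=0)[r],
    )
    for r in range(MESH_SIZE)
)
print(strided_ok, plain_shard0_ok)  # True False
```

Under that assumption, the strided placement describes the local data exactly, which is consistent with the PR's claim that this case needs zero communication.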

fix

  • _rewrite_plain_shard: previously handled sharding only on the first dim of a flatten group. It now computes the strided shard factor from the sharded dim's position within the flatten, then finds which output dim can carry that stride. It also adds k >= mesh_size and k % mesh_size == 0 checks, so it never produces a _StridedShard whose chunk count doesn't evenly divide the dim size.
  • _analyze_split / analyze(): Split(Flatten(...)) has multiple input dims, but we were only tracking the first one. Now we track all of them, so sharding on any dim in the flatten group works.

Changed files

  • test/distributed/tensor/test_view_ops.py (modified, +78/-0)
  • torch/distributed/tensor/_ops/_view_ops.py (modified, +138/-19)

Code Example

"""
Repro: DTensor view on a dim-1 sharded tensor gives wrong numerics without redistribution.

Run with: torchrun --nproc_per_node=4 agent_space/repro_view_shard.py
"""

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Shard, distribute_tensor, init_device_mesh


def main():
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    device_mesh = init_device_mesh("cpu", (4,))

    torch.manual_seed(0)
    tensor = torch.randn(12, 8)

    dtensor = distribute_tensor(tensor, device_mesh, [Shard(1)])

    # Reference: view on the full (global) tensor
    ref = tensor.view(-1, 6)

    # DTensor path: view on the sharded tensor, then collect
    out = dtensor.view(-1, 6).full_tensor()

    if rank == 0:
        match = torch.allclose(ref, out)
        print(f"Shapes: ref={ref.shape}, out={out.shape}")
        print(f"Match: {match}")
        if not match:
            print(f"Max diff: {(ref - out).abs().max().item()}")
            print(f"\nRef:\n{ref}")
            print(f"\nOut:\n{out}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

🐛 Describe the bug

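The repro script above performs a view across 4 ranks (gloo backend on CPU). For x: dt[12, 8] with placement S(1), x.view(16, 6) returns an S(0) result without redistributing, although a correct result requires redistribution. As a result, the numerics of the full tensor are incorrect.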
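
Why the S(0) result is wrong can be seen on a single process. The sketch below emulates the placements with torch.chunk, so no process group is needed. emulate_view is a hypothetical helper that views each rank's contiguous local shard as (-1, 6) and stacks the pieces along dim 0, which is what an S(0) output without communication amounts to.

```python
import torch

WORLD_SIZE = 4


def emulate_view(full: torch.Tensor, shard_dim: int) -> torch.Tensor:
    """Hypothetical helper: view every local shard as (-1, 6), then reassemble as if the result were S(0)."""
    shards = full.chunk(WORLD_SIZE, dim=shard_dim)
    local_views = [s.contiguous().view(-1, 6) for s in shards]
    return torch.cat(local_views, dim=0)


full = torch.arange(96.0).view(12, 8)
ref = full.view(16, 6)

from_dim1 = emulate_view(full, shard_dim=1)  # the case in this issue
from_dim0 = emulate_view(full, shard_dim=0)  # row sharding, for contrast

print(from_dim1.shape == ref.shape)  # True: shapes agree, so nothing fails loudly
print(torch.equal(from_dim1, ref))   # False: the values are permuted
print(torch.equal(from_dim0, ref))   # True: each rank holds a contiguous run of rows
```

With column sharding, each rank holds two columns from every row, so its local (4, 6) view mixes elements that are far apart in the global row-major order. With row sharding, each rank's 24 elements are already a contiguous run of that order.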


Versions

nightly 4/6

cc @ezyang @gchanan @kadeng @msaroufim @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @dcci @aditvenk @xmfan @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx

extent analysis

TL;DR

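Fixed upstream by PR #179509, which corrects the Split(Flatten(...)) sharding propagation so that the view either returns a correct placement or raises an error. On a build without the fix, redistribute the tensor before the view, not after it: once the view has returned, the result is already mislabeled.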

Guidance

  • The problem arises when a view is applied to a tensor sharded on a non-first dim of a flatten group: sharding propagation reports a plain Shard output and skips the redistribution needed for correct numerics.
  • To verify the issue, compare the reference view on the global tensor with the DTensor view using torch.allclose(ref, out).
  • Redistributing after the view cannot repair the result, because dtensor.view(-1, 6) has already paired the local data with the wrong placement. Any workaround has to change the placement before the view.
  • Calling dtensor.redistribute with Replicate() (or with a sharding on dim 0) before calling view is one option to investigate on builds without the fix.

Example

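# Possible workaround on builds without the fix (an untested sketch): gather before the view
from torch.distributed.tensor import Replicate

dtensor = distribute_tensor(tensor, device_mesh, [Shard(1)])
replicated = dtensor.redistribute(placements=[Replicate()])  # all-gather the dim-1 shards
out = replicated.view(-1, 6).full_tensor()

Note: This workaround is a sketch and was not run against the affected nightly. It trades extra communication and memory for correctness. The actual fix is PR #179509, which changes sharding propagation rather than user code.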

Notes

According to the PR description, the root cause is in the sharding propagation for Split(Flatten(...)) in torch/distributed/tensor/_ops/_view_ops.py, not in distribute_tensor. After the fix, the PR reports that dt[12, 8] with Shard(1) on a mesh of size 6 raises a RuntimeError for view(16, 6) instead of returning wrong values.

Recommendation

Upgrade to a build that includes PR #179509. Until then, apply the workaround: redistribute the DTensor (for example to Replicate()) before calling view.
