pytorch - ✅(Solved) Fix [DTensor] reshape is dispatched to aten.view instead of aten.reshape [1 pull requests, 9 comments, 6 participants]

GitHub stats
pytorch/pytorch#178616 (fetched 2026-04-08 01:40:30)
Comments: 9 · Participants: 6 · Timeline events: 130 · Reactions: 0
Timeline (top): mentioned ×58, subscribed ×58, commented ×9, labeled ×3

Root Cause

DTensor.reshape() dispatches to aten.view.default for sharded DTensors because is_contiguous() returns True based on the global stride metadata, without considering the sharding placement.
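The same view-or-copy decision is visible on plain tensors: reshape returns a view when the existing strides allow it and a copy otherwise. The sketch below uses reshape_shares_storage, an illustrative helper (not a PyTorch API), which compares data pointers to tell the two cases apart. Per the issue, DTensor makes this decision from the global strides, so a sharded DTensor built from a contiguous global tensor looks viewable.

```python
import torch

def reshape_shares_storage(t: torch.Tensor, shape) -> bool:
    """Illustrative helper: True if t.reshape(shape) returned a view of t
    (same data pointer), False if reshape had to copy."""
    out = t.reshape(shape)
    return out.data_ptr() == t.data_ptr()

a = torch.arange(24).view(4, 6)  # contiguous, strides (6, 1)
t = a.t()                        # transposed view, strides (1, 6), not contiguous

print(a.is_contiguous(), reshape_shares_storage(a, (24,)))  # True True
print(t.is_contiguous(), reshape_shares_storage(t, (24,)))  # False False
```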

Fix Action


PR fix notes

PR #178573: [DTensor] support reshape with _StridedShard

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #178573

Changed files

  • test/distributed/tensor/test_view_ops.py (modified, +180/-6)
  • torch/distributed/tensor/_ops/_view_ops.py (modified, +10/-12)

Issue description

DTensor.reshape() dispatches to aten.view.default for sharded DTensors because is_contiguous() returns True based on global stride metadata, without considering the sharding placement

This matters if we want to trigger redistribution using strict_view=False (https://github.com/pytorch/pytorch/pull/178573).

I added an assertion to the following repro; it fails with: Expected aten.reshape.default but got: ['aten.view.default']

repro

"""torchrun --nproc_per_node=4 demo_reshape_dispatch.py"""
import torch
import torch.distributed as dist
from torch.distributed.tensor import Shard, distribute_tensor, init_device_mesh
from torch.utils._python_dispatch import TorchDispatchMode

class DispatchTracer(TorchDispatchMode):
    def __init__(self):
        self.ops = []
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.ops.append(str(func))
        return func(*args, **(kwargs or {}))

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (4,))
dt = distribute_tensor(torch.arange(24).view(4, 6), mesh, [Shard(0)])
tracer = DispatchTracer()
with tracer:
    dt.reshape(24)
if dist.get_rank() == 0:
    assert any("reshape" in op.lower() for op in tracer.ops), (
        f"Expected aten.reshape.default but got: {[op for op in tracer.ops if 'view' in op.lower() or 'reshape' in op.lower()]}"
    )
dist.destroy_process_group()

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx
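Whether a reshape of a sharded tensor can stay local depends on the placement, which is why reshape (unlike view) may need to redistribute. The single-process sketch below simulates Shard(dim) with torch.chunk. reshape_is_local is a hypothetical helper for illustration only: it assumes even shards and ignores _StridedShard and DTensor's actual sharding-propagation rules.

```python
import torch

def reshape_is_local(global_t, shard_dim, world_size, new_shape, out_shard_dim=0):
    """Hypothetical helper: simulate Shard(shard_dim) on `world_size` ranks in one
    process and return True if reshaping every local shard yields exactly the
    shards of the reshaped global tensor (split along out_shard_dim), i.e. the
    reshape could be done locally without redistribution."""
    local_shards = torch.chunk(global_t, world_size, dim=shard_dim)
    expected_shards = torch.chunk(global_t.reshape(new_shape), world_size, dim=out_shard_dim)
    if len(local_shards) != len(expected_shards):
        return False
    for local, expected in zip(local_shards, expected_shards):
        # Each rank must hold the same number of elements before and after
        if local.numel() != expected.numel():
            return False
        # ...and the locally reshaped data must match the expected output shard
        if not torch.equal(local.reshape(expected.shape), expected):
            return False
    return True

g = torch.arange(24).view(4, 6)
print(reshape_is_local(g, shard_dim=0, world_size=4, new_shape=(24,)))  # True
print(reshape_is_local(g, shard_dim=1, world_size=2, new_shape=(24,)))  # False
```

In this simulation the repro's case (Shard(0), flattening (4, 6) to (24,)) stays local, while flattening a tensor sharded on dim 1 would require moving data between ranks.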

Extended analysis

Fix Plan

One proposed fix for DTensor.reshape() dispatching to aten.view.default on sharded DTensors is to make the contiguity check take the sharding placement into account, so that a sharded DTensor no longer takes the view path.

Code Changes

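The idea can be expressed as a placement check in front of the contiguity test. This is a conceptual sketch, not the upstream fix: DTensor has no is_sharded or sharding_placement attributes (it exposes placements, a tuple of Placement objects with is_shard()), and comparing self.stride() with self.size() is not a contiguity check. Also note that the C++ composite reshape reads the tensor's stored size/stride metadata rather than calling a Python method, so overriding is_contiguous() in Python alone would not change the dispatch; a real fix would have to act on that metadata or on DTensor's dispatch.

import torch

def is_contiguous(self, memory_format=torch.contiguous_format):
    # Conceptual sketch: report a DTensor with any Shard placement as non-contiguous
    if any(p.is_shard() for p in self.placements):
        return False
    # Otherwise defer to the regular check on the (global) stride metadata
    return torch.Tensor.is_contiguous(self, memory_format=memory_format)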

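Per the issue, the goal is for a sharded reshape to reach DTensor's reshape handling, where strict_view=False can allow a redistribution, instead of the aten.view.default path, which cannot. strict_view is not an argument of Tensor.view(); it belongs to DTensor's sharding propagation for view-like ops. In the sketch below, _dispatch_reshape_non_strict is a hypothetical helper standing in for that path:

def reshape(self, *shape):
    if any(p.is_shard() for p in self.placements):
        # Hypothetical helper: run DTensor's reshape sharding propagation with
        # strict_view=False (redistribution allowed) instead of decomposing to aten.view.default
        return _dispatch_reshape_non_strict(self, shape)
    # No Shard placement: keep the existing behavior
    return torch.Tensor.reshape(self, *shape)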

Verification

To verify a fix, run the repro script above against the patched code. If the fix works, the assertion passes: the tracer records aten.reshape.default instead of aten.view.default for the sharded DTensor.
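The repro's check accepts any op name containing "reshape", which would also match a view-like op such as aten._reshape_alias.default. A stricter check over the traced op list might look like the sketch below; classify_reshape_dispatch is a hypothetical helper and assumes op names as produced by str(func), e.g. aten.view.default as in the reported output.

```python
def classify_reshape_dispatch(ops):
    """Hypothetical helper: classify the op names recorded by the repro's
    DispatchTracer (a list of str(func) values such as "aten.view.default")."""
    if "aten.reshape.default" in ops:
        return "reshape"  # the reshape op itself reached the tensor's dispatch
    if "aten.view.default" in ops or "aten._reshape_alias.default" in ops:
        return "view"     # reshape was decomposed into a view-like op
    if "aten._unsafe_view.default" in ops:
        return "copy"     # reshape typically decomposes to clone + _unsafe_view when it must copy
    return "unknown"

print(classify_reshape_dispatch(["aten.view.default"]))     # view
print(classify_reshape_dispatch(["aten.reshape.default"]))  # reshape
```

In the repro, the final assertion could then require classify_reshape_dispatch(tracer.ops) == "reshape".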

Example Use Case

dt = distribute_tensor(torch.arange(24).view(4, 6), mesh, [Shard(0)])
dt.reshape(24)  # Expected to dispatch to aten.reshape.default; with strict_view=False, DTensor may redistribute if needed
