pytorch - 💡(How to fix) Fix `_LocalDeviceMesh` coordinate cache unreliable leading to failures [1 pull requests]

🐛 Describe the bug

I noticed the issue while running python distributed/tensor/test_utils.py TestStridedShardingWithLocalTensor.test_2d_mesh_2d_tensor_strided_sharding

The issue is specifically at: https://github.com/pytorch/pytorch/blob/bec164b201e9dd53069ad437e4e5d985140a2454/test/distributed/tensor/test_utils.py#L949-L950

The value of mesh_dim1_local_rank is sometimes wrongly the same as mesh_dim0_local_rank. I put a print there to verify:

mesh_dim0_local_rank=LocalIntNode({0: 0, 1: 0, 2: 1, 3: 1}) mesh_dim1_local_rank=LocalIntNode({0: 0, 1: 0, 2: 1, 3: 1})  # Wrong
mesh_dim0_local_rank=LocalIntNode({0: 0, 1: 0, 2: 1, 3: 1}) mesh_dim1_local_rank=LocalIntNode({0: 0, 1: 1, 2: 0, 3: 1})  # Expected

After some debugging I traced the issue to https://github.com/pytorch/pytorch/blob/bec164b201e9dd53069ad437e4e5d985140a2454/torch/distributed/_local_tensor/__init__.py#L1598-L1600

What happens is that the temporary object returned by mesh_2d["dim0"] gets destroyed before mesh_2d["dim1"] creates its result object. The newly created DeviceMesh then reuses the memory slot of the just-destroyed object, so it has the same id. The cache is keyed on that id and wrongly returns the coordinates of the old mesh instead of calculating them for the new dimension.
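The failure mode can be sketched without PyTorch. The snippet below is a minimal illustration, assuming a module-level cache keyed by id(); FakeMesh, cached_coords and lookup are hypothetical stand-ins and not the real torch.distributed code. Whether the id is actually reused depends on the interpreter and allocator, so the script reports which case occurred instead of assuming one.

```python
# Hypothetical stand-ins for illustration only; not the real DeviceMesh code.
_coord_cache = {}  # keyed by id(mesh), which is only unique while the mesh is alive


class FakeMesh:
    def __init__(self, dim_name):
        self.dim_name = dim_name


def cached_coords(mesh):
    key = id(mesh)
    if key not in _coord_cache:
        # Stand-in for the real coordinate computation.
        _coord_cache[key] = f"coords for {mesh.dim_name}"
    return _coord_cache[key]


def lookup(dim_name):
    # The mesh is a temporary: it is released when this function returns,
    # but its cache entry stays behind under the old id.
    mesh = FakeMesh(dim_name)
    return id(mesh), cached_coords(mesh)


id0, coords0 = lookup("dim0")
id1, coords1 = lookup("dim1")

if id0 == id1:
    print("id was reused, cache returned stale value:", coords1)
else:
    print("id was not reused this time, value is correct:", coords1)
```

If the two ids match, the second lookup returns the coordinates computed for "dim0", which is the same pattern as the "Wrong" output above.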

This is much more likely to happen in Python 3.13, which is aggressively optimized for reuse.

But the code has this issue in general: unless every mesh stays alive for the entire lifetime of the program, or only a single mesh is ever used, it will eventually run into this problem. The array-like syntax mesh_2d["dim0"] makes it very easy to run into this, because it is not clear what happens in the background of this access, i.e. that a new object gets created. If the parent mesh cached the returned sub-meshes, the issue would be less obvious but still possible.
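One possible way to avoid stale entries is to tie each cache entry to the lifetime of the object rather than to its id. The sketch below uses weakref.WeakKeyDictionary from the standard library with a hypothetical FakeMesh stand-in. It is not necessarily what the linked PR does, and it assumes the key object is weak-referenceable and uses default identity-based hashing, which may not hold for the real DeviceMesh.

```python
import weakref


class FakeMesh:
    """Hypothetical stand-in for a sub-mesh; not the real DeviceMesh."""

    def __init__(self, dim_name):
        self.dim_name = dim_name


# Entries are dropped automatically once the key object is garbage collected,
# so a new object that happens to reuse the old id never sees an old entry.
_coord_cache = weakref.WeakKeyDictionary()


def cached_coords(mesh):
    try:
        return _coord_cache[mesh]
    except KeyError:
        coords = f"coords for {mesh.dim_name}"  # stand-in for the real computation
        _coord_cache[mesh] = coords
        return coords


# Every mesh here is a temporary, as with mesh_2d["dim0"] and mesh_2d["dim1"].
results = [cached_coords(FakeMesh(f"dim{i}")) for i in range(20)]
print(results[:2])
```

Storing the computed coordinates as an attribute on the mesh object itself would be another option with the same lifetime guarantee.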

For demonstration, this code shows the reuse of object ids in general:

class Foo:
    def get_id(self):
        return id(self)

def make(i):  # similar to DeviceMesh.__getitem__
    x = Foo()
    print("created", id(x))
    return x

prev_id = None
for i in range(20):
    id_x = make(i).get_id()
    if id_x == prev_id:
        raise RuntimeError(f"Dup at {i}: {id_x}")
    prev_id = id_x
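For contrast, here is a variation of the loop above that keeps every object alive. Python guarantees that id() is unique among objects that exist at the same time, so no duplicate can occur here. The alive and ids lists are additions for this illustration.

```python
class Foo:
    def get_id(self):
        return id(self)


def make(i):  # similar to DeviceMesh.__getitem__
    return Foo()


alive = []  # holding a reference prevents the slot from being reused
ids = []
for i in range(20):
    x = make(i)
    alive.append(x)
    ids.append(x.get_id())

# All 20 objects exist simultaneously, so their ids must differ.
print(len(set(ids)) == len(ids))  # prints True
```

This is why keeping the sub-mesh in a variable hides the problem, while the temporary created by mesh_2d["dim0"] exposes it.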

Versions

PyTorch version: 2.12.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Rocky Linux 9.6 (Blue Onyx) (x86_64)
GCC version: (GCC) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.34

Python version: 3.12.3 (main, May 13 2025, 17:56:01) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.14.0-570.49.1.el9_6.x86_64-x86_64-with-glibc2.34

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @weifengpy
