pytorch - 💡(How to fix) Fix [distributed][Dynamo] Add get_process_group_all_ranks() for SPMD compilation backends [2 comments, 3 participants]

GitHub stats
pytorch/pytorch#177815 · Fetched 2026-04-08 01:01:44
Comments: 2 · Participants: 3 · Timeline: 75 · Reactions: 0
Timeline (top): mentioned ×29, subscribed ×29, unsubscribed ×12, commented ×2


🚀 The feature, motivation and pitch

I'm working on torch.compile for AWS Neuron (Trainium) and need all replica groups for a process group at compile time. In our backend IR, collectives require explicit replica groups, for instance:

replica_groups = [[0,1,2,3], [4,5,6,7]]

Today, get_process_group_ranks(pg) returns only the calling rank's group — rank 0 gets [0,1,2,3], rank 4 gets [4,5,6,7]. There's no public API to get all sibling groups.
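To make the gap concrete, here is a minimal pure-Python sketch of the two views. It assumes a row-major 2-D mesh; `all_groups` and `own_group` are hypothetical helper names for illustration and are not PyTorch APIs.

```python
def all_groups(mesh_shape, dim):
    """Enumerate every rank group along `dim` of a row-major 2-D mesh."""
    rows, cols = mesh_shape
    if dim == 1:
        # Groups along the inner dimension: one group per row
        return [[r * cols + c for c in range(cols)] for r in range(rows)]
    # Groups along the outer dimension: one group per column
    return [[r * cols + c for r in range(rows)] for c in range(cols)]


def own_group(mesh_shape, dim, rank):
    """Per-rank view: only the group that contains `rank`."""
    return next(g for g in all_groups(mesh_shape, dim) if rank in g)


print(all_groups((2, 4), dim=1))         # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(own_group((2, 4), dim=1, rank=0))  # [0, 1, 2, 3]
print(own_group((2, 4), dim=1, rank=4))  # [4, 5, 6, 7]
```

The first call is the rank-independent answer the backend IR needs; the last two are what each rank can see through a per-rank query.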

This causes each rank to produce different IR, breaking SPMD compilation. On TP4+FSDP16 (64 devices), this means 64 separate compilations instead of 1 shared one.
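The compilation count can be illustrated with a toy cache-key calculation. This sketch assumes a row-major 16x4 layout with TP as the inner dimension and a made-up IR string; the real Neuron IR and cache keying are not shown in the issue.

```python
import hashlib

FSDP, TP = 16, 4  # 64 devices; assumed layout: rank = fsdp_idx * TP + tp_idx

# Every group along each mesh dimension
tp_groups = [[f * TP + t for t in range(TP)] for f in range(FSDP)]
fsdp_groups = [[f * TP + t for f in range(FSDP)] for t in range(TP)]


def ir_key(tp_part, fsdp_part):
    """Hash a toy IR string that embeds replica groups for two collectives."""
    text = f"all_reduce(groups={tp_part}); all_gather(groups={fsdp_part})"
    return hashlib.sha256(text.encode()).hexdigest()


def per_rank_key(rank):
    # Each rank embeds only the groups it belongs to
    return ir_key([tp_groups[rank // TP]], [fsdp_groups[rank % TP]])


def shared_key(rank):
    # Each rank embeds all groups, so the IR does not depend on `rank`
    return ir_key(tp_groups, fsdp_groups)


world = range(FSDP * TP)
print(len({per_rank_key(r) for r in world}))  # 64 distinct IRs
print(len({shared_key(r) for r in world}))    # 1 shared IR
```

Each rank has a unique (TP group, FSDP group) pair, so per-rank group lists yield 64 different keys, while full group lists yield one.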

The data already exists — DeviceMesh._init_one_process_group computes all groups via pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map) but discards them after PG creation.

Proposed API:

dist.get_process_group_all_ranks(pg) -> list[list[int]]
# Returns ALL replica groups, same result on every rank

Implementation: Add _pg_all_group_ranks dict to _World (mirrors existing _pg_group_ranks), populated in _init_one_process_group, ~24 lines across 2 files. I have a working implementation and can submit a PR.

This benefits any torch.compile backend that lowers collectives to static IR: when every rank emits identical IR, one compiled artifact can be cached and shared across ranks.

Related: #158793 (DeviceMesh iterations RFC), #159017 (PG bookkeeping)

Alternatives

  1. DeviceMesh.get_all_replica_groups(mesh_dim) public method — computes the right answer from _layout and _rank_map, but requires a DeviceMesh reference. At FX pass time, only the group name string (e.g., "33") is available in the graph — the DeviceMesh is gone. There is no reverse mapping from group name back to DeviceMesh.
  2. Attach topology to FX node metadata during Dynamo tracing — DTensor's redistribute path calls torch.ops._c10d_functional.all_reduce directly as an ATen op, bypassing the Python-level funcol.all_reduce wrapper. By the time the op is recorded, the group has already been resolved to a string. Attaching metadata would require changes to how the ATen op is recorded.
  3. Monkey-patch DeviceMesh._init_process_groups — this is our current workaround. Works but fragile across PyTorch versions.

The _World dict approach is simplest, mirrors existing patterns (_pg_group_ranks), and requires no changes to Dynamo, C++, or the ProcessGroup class.
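For contrast, the monkey-patch workaround from alternative 3 follows the general pattern sketched below. `MeshSketch` is a hypothetical stand-in class, not the real DeviceMesh, and its method body is invented for illustration; the point is only that the patch depends on a private method's name and behavior.

```python
import functools


class MeshSketch:
    """Hypothetical stand-in for a mesh class; not the real DeviceMesh."""

    def __init__(self, rows, cols):
        self.rows, self.cols = rows, cols
        # The return value is ignored, so the groups are lost unless intercepted
        self._init_process_groups()

    def _init_process_groups(self):
        return [
            [r * self.cols + c for c in range(self.cols)]
            for r in range(self.rows)
        ]


captured = {}  # id(mesh) -> all groups seen during initialization


def install_patch():
    """Wrap the private method so its groups are recorded as a side effect."""
    original = MeshSketch._init_process_groups

    @functools.wraps(original)
    def wrapper(self):
        groups = original(self)
        captured[id(self)] = groups
        return groups

    MeshSketch._init_process_groups = wrapper
    return original  # returned so the caller can undo the patch


original = install_patch()
mesh = MeshSketch(2, 4)
print(captured[id(mesh)])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
MeshSketch._init_process_groups = original  # restore the unpatched method
```

If the private method is renamed or restructured in a later release, the wrapper stops capturing anything, which is the fragility the issue describes.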

Additional context

No response

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan

extent analysis

Fix Plan

To address the issue, we need to implement the proposed dist.get_process_group_all_ranks(pg) API. This involves adding a _pg_all_group_ranks dictionary to the _World class and populating it in the _init_one_process_group method.

Here are the concrete steps:

  • Add a _pg_all_group_ranks dictionary to the _World class to store all replica groups for each process group.
  • Modify the _init_one_process_group method to populate the _pg_all_group_ranks dictionary.
  • Implement the dist.get_process_group_all_ranks(pg) function to return the list of all replica groups for a given process group.

Example code:

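# Sketch only: names follow the proposal; the real PyTorch signatures may differ.
# `_world` below is the module-level _World instance in distributed_c10d.

# 1. Add the dict to _World, next to the existing _pg_group_ranks
class _World:
    def __init__(self):
        # ...
        self._pg_all_group_ranks = {}  # pg -> list of all sibling rank groups

# 2. In DeviceMesh._init_one_process_group, record the groups instead of discarding them
def _init_one_process_group(self, pg, sub_layout, rank_map):
    # ...
    pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map)
    # Store on the shared _world object, not on self (self is the DeviceMesh).
    # Assumes pg_ranks_by_dim is a 2-D tensor with one row per group.
    _world._pg_all_group_ranks[pg] = pg_ranks_by_dim.tolist()
    # ...

# 3. Public accessor: returns the same result on every rank
def get_process_group_all_ranks(pg):
    if pg not in _world._pg_all_group_ranks:
        raise ValueError("no replica groups recorded for this process group")
    return _world._pg_all_group_ranks[pg]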

Verification

To verify that the fix worked, you can test the dist.get_process_group_all_ranks(pg) function with different process groups and ranks. The function should return the same list of all replica groups for a given process group, regardless of the calling rank.

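Example test code (run with 8 ranks, for example under torchrun; get_process_group_all_ranks is the proposed API and does not exist yet):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
mesh = init_device_mesh("cpu", (2, 4), mesh_dim_names=("dp", "tp"))
pg = mesh.get_group("tp")
ranks = dist.get_process_group_all_ranks(pg)
print(ranks)  # Should print the same list of replica groups on all ranks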

Extra Tips

  • Document the new dist.get_process_group_all_ranks(pg) function, including that it returns the same result on every rank.
  • Consider adding tests to ensure that the _pg_all_group_ranks dictionary is correctly populated and that the dist.get_process_group_all_ranks(pg) function returns the expected results.
