pytorch - ✅(Solved) Fix Dynamo regressed support for functional collective (all_gather_tensor) [1 pull requests, 1 comments, 1 participants]

pytorch/pytorch#179722 · Fetched 2026-04-09 07:50:19

PR fix notes

PR #180389: [dynamo] Allow tracing into _maybe_view_chunk_cat for functional collectives

Description (problem / solution / changelog)

Summary

torch.compile(fullgraph=True) fails when all_gather_tensor is called with gather_dim != 0, because Dynamo skips tracing into torch._utils._maybe_view_chunk_cat.

Root Cause

torch._utils is on Dynamo's MOD_SKIPLIST, so functions in that module are not traced by default. The function _chunk_or_narrow_cat was already explicitly allowed in manual_torch_name_rule_map in trace_rules.py, but _maybe_view_chunk_cat — which was added later and calls _chunk_or_narrow_cat — was never added to the allowlist.

When all_gather_tensor is called with gather_dim != 0, it invokes _maybe_view_chunk_cat, which Dynamo refuses to trace, producing:

Unsupported: Attempted to call function marked as skipped
  module: torch._utils, qualname: _maybe_view_chunk_cat, skip reason: file matches MOD_SKIPLIST
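
The interaction between a module-level skiplist and a per-function allowlist can be illustrated with a toy model. This is a simplified sketch, not Dynamo's actual lookup code: the names TOY_SKIPLIST, TOY_NAME_RULES, and decide are hypothetical, and the "inline"/"skip" strings stand in for Dynamo's real rule objects.

```python
# Toy model only: NOT Dynamo's real implementation. All names are hypothetical.
TOY_SKIPLIST = {"torch._utils"}  # modules that are skipped by default
TOY_NAME_RULES = {"torch._utils._chunk_or_narrow_cat": "inline"}  # per-function overrides


def decide(qualified_name, rules=TOY_NAME_RULES, skiplist=TOY_SKIPLIST):
    """Return "inline" if the function would be traced, else "skip"."""
    # An explicit per-function rule wins over the module-level skiplist.
    if qualified_name in rules:
        return rules[qualified_name]
    # Otherwise the function is skipped when its module, or any parent
    # package of that module, appears on the skiplist.
    parts = qualified_name.rpartition(".")[0].split(".")
    for i in range(len(parts), 0, -1):
        if ".".join(parts[:i]) in skiplist:
            return "skip"
    return "inline"


target = "torch._utils._maybe_view_chunk_cat"
before = decide(target)
after = decide(target, rules={**TOY_NAME_RULES, target: "inline"})
print(before, after)  # prints "skip inline"
```

In this model the already-allowed helper is inlined, while its caller in the same module is skipped until it gets its own entry, which mirrors the failure described above.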

Fix

One-line addition in torch/_dynamo/trace_rules.py to add _maybe_view_chunk_cat to manual_torch_name_rule_map as a UserFunctionVariable, right after the existing _chunk_or_narrow_cat entry. This tells Dynamo to inline the function (trace into it as user code) rather than skipping it.

     "torch._utils._chunk_or_narrow_cat": UserFunctionVariable,
+    "torch._utils._maybe_view_chunk_cat": UserFunctionVariable,
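
To check whether an installed PyTorch build already carries this entry, one option is to look the name up in the rule map directly. The sketch below assumes manual_torch_name_rule_map is a plain dict keyed by qualified-name strings, as the diff above suggests; the helper name has_manual_trace_rule is hypothetical, and it returns None when torch cannot be imported.

```python
def has_manual_trace_rule(qualified_name):
    """Return True/False when torch is usable, or None when it is not."""
    try:
        from torch._dynamo import trace_rules
    except Exception:  # torch is missing or cannot be loaded in this environment
        return None
    # Assumption: the map is a dict keyed by strings such as "torch._utils.<name>".
    rule_map = getattr(trace_rules, "manual_torch_name_rule_map", {})
    return qualified_name in rule_map


result = has_manual_trace_rule("torch._utils._maybe_view_chunk_cat")
print(result)
```

A result of False would suggest the build predates the fix, although this relies on a private module whose layout may change between releases.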

Tests Added

Added 3 tests to test/dynamo/test_fake_distributed.py in the TestFakeDistributed class, using the FakeProcessGroup (backend="fake", world_size=2) to test all_gather_tensor under torch.compile(fullgraph=True) without requiring GPUs:

  1. test_all_gather_tensor_gather_dim_0: gather_dim=0 (no _maybe_view_chunk_cat call, baseline). Passes with and without the fix.
  2. test_all_gather_tensor_gather_dim_1_view_path: gather_dim=1 with shape [1, 4], triggers the view optimization path in _maybe_view_chunk_cat. Fails without the fix, passes with it.
  3. test_all_gather_tensor_gather_dim_2_chunk_cat_path: gather_dim=2 with shape [1, 3, 4], triggers the chunk+cat fallback path in _maybe_view_chunk_cat (dim 1 has size 3, preventing view optimization). Fails without the fix, passes with it.

Each test compares the compiled output against eager execution to verify correctness.
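
The choice of shapes in tests 2 and 3 can be sketched without PyTorch. The predicate below is an assumption inferred from the test descriptions (the view path seems to apply when every dimension before gather_dim has size 1); the real condition inside torch._utils._maybe_view_chunk_cat may differ, and both function names are hypothetical.

```python
def gathered_shape(shape, gather_dim, world_size):
    """Per-rank input shape -> shape after gathering along gather_dim."""
    out = list(shape)
    out[gather_dim] *= world_size
    return out


def leading_dims_are_one(shape, gather_dim):
    """Assumed view-path condition: all dims before gather_dim have size 1."""
    return all(size == 1 for size in shape[:gather_dim])


world_size = 2  # matches the FakeProcessGroup setup used by the tests
for shape, gather_dim in [([1, 4], 1), ([1, 3, 4], 2)]:
    path = "view" if leading_dims_are_one(shape, gather_dim) else "chunk+cat"
    print(shape, gather_dim, gathered_shape(shape, gather_dim, world_size), path)
```

Under that assumption, shape [1, 4] with gather_dim=1 maps to the view path and shape [1, 3, 4] with gather_dim=2 maps to the chunk+cat path, matching the two non-baseline tests.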

Repro and Test Results

<details> <summary>Repro Script</summary>
import os
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_gather_tensor

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

dist.init_process_group(backend="gloo")

@torch.compile(fullgraph=True)
def my_fn(t):
    # gather_dim != 0 triggers _maybe_view_chunk_cat
    return all_gather_tensor(t, gather_dim=1, group=dist.group.WORLD)

t = torch.randn(1, 4)
try:
    result = my_fn(t)
    print(f"SUCCESS: result shape={result.shape}")
except Exception as e:
    print(f"FAILED: {type(e).__name__}: {e}")
finally:
    dist.destroy_process_group()

Before fix:

FAILED: Unsupported: Attempted to call function marked as skipped
  ...module: torch._utils, qualname: _maybe_view_chunk_cat, skip reason: file matches MOD_SKIPLIST

After fix:

SUCCESS: result shape=torch.Size([1, 4])
</details> <details> <summary>Full Test Suite Results</summary>
test/dynamo/test_fake_distributed.py (13 tests, all passing):

test_all_gather_tensor_gather_dim_0                    PASSED
test_all_gather_tensor_gather_dim_1_view_path          PASSED
test_all_gather_tensor_gather_dim_2_chunk_cat_path     PASSED
test_all_to_all_single_autograd                        PASSED
test_device_mesh_flatten                               PASSED
test_device_mesh_get_local_rank                        PASSED
test_device_mesh_init_skip_after_graph_break           PASSED
test_compiled_batch_isend_irecv_mixed_graph            PASSED
test_compiled_fire_and_forget_isend_graph              PASSED
test_compiled_irecv_graph                              PASSED
test_compiled_isend_graph                              PASSED
test_compiled_p2p_interleave_graph                     PASSED
test_mutating_p2p_op_graph_breaks                      PASSED
</details>

Fixes #179722

<details> <summary>Agent Worklog</summary>

Run 1

Reproduced the issue

Ran the repro script which uses all_gather_tensor with gather_dim=1 under torch.compile(fullgraph=True). Confirmed failure:

FAILED: Unsupported: Attempted to call function marked as skipped
  module: torch._utils, qualname: _maybe_view_chunk_cat, skip reason: file matches MOD_SKIPLIST

Root cause identified

torch._utils is on Dynamo's MOD_SKIPLIST. The function _chunk_or_narrow_cat was already allowed in manual_torch_name_rule_map at trace_rules.py:182, but _maybe_view_chunk_cat (which calls _chunk_or_narrow_cat) was not added when it was introduced.

Fix applied

Added "torch._utils._maybe_view_chunk_cat": UserFunctionVariable to manual_torch_name_rule_map in trace_rules.py:183, right after the existing _chunk_or_narrow_cat entry.

Verification

  • Repro script: SUCCESS (result shape=torch.Size([1, 4]))
  • Edge cases: 3/3 passed (gather_dim=0, gather_dim=1, gather_dim=2)
  • Lint: clean (spin fixlint passed)
  • Diff: 1 line added in torch/_dynamo/trace_rules.py

Run 2

User: Add dynamo tests for this fix. Add test cases to the existing distributed dynamo test suite that verify all_gather_tensor works under torch.compile(fullgraph=True) with gather_dim=0 and gather_dim!=0.

Added dynamo tests

Added 3 tests to test/dynamo/test_fake_distributed.py::TestFakeDistributed:

  • test_all_gather_tensor_gather_dim_0 — gather_dim=0 (no _maybe_view_chunk_cat call)
  • test_all_gather_tensor_gather_dim_1_view_path — gather_dim=1, view optimization path
  • test_all_gather_tensor_gather_dim_2_chunk_cat_path — gather_dim=2, chunk+cat fallback path

Used FakeProcessGroup (backend="fake", world_size=2) so tests run single-process without GPUs. Each test compiles with fullgraph=True, backend="eager" and compares against eager execution.

Verified tests catch the bug

  • With fix reverted: gather_dim=0 passes (doesn't hit _maybe_view_chunk_cat), gather_dim=1 and gather_dim=2 both fail with Unsupported: Attempted to call function marked as skipped.
  • With fix applied: all 3 pass.

Full suite

All 13 tests in test_fake_distributed.py pass. Lint clean.

</details>

This PR was generated by ptq with human review.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @azahed98 @mlazos

Changed files

  • test/distributed/tensor/parallel/test_micro_pipeline_tp.py (modified, +0/-6)
  • test/dynamo/test_fake_distributed.py (modified, +45/-0)
  • torch/_dynamo/trace_rules.py (modified, +1/-0)
RAW_BUFFER

See https://github.com/pytorch/pytorch/pull/169404#discussion_r3047094669

Filing this issue for visibility.

This issue needs to be triaged

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @Lucaskabela @jataylo @azahed98

extent analysis

TL;DR

  • The issue requires triage to determine the root cause and appropriate fix.

Guidance

  • Review the discussion at https://github.com/pytorch/pytorch/pull/169404#discussion_r3047094669 to understand the context of the issue.
  • Assign the issue to a team member for triage and further investigation.
  • Check if the issue is related to a specific PyTorch version or functionality.
  • Consider creating a minimal reproducible example to help with debugging.

Notes

  • The issue lacks specific technical details, making it difficult to provide a precise fix or workaround.
  • Further investigation and triage are necessary to determine the root cause and appropriate solution.

Recommendation

  • Apply workaround: Assign the issue to a team member for triage and further investigation, as the issue lacks enough information to suggest a specific fix or upgrade.
