pytorch - 💡(How to fix) Fix make_fx tracing over dist.all_reduce [1 comment, 1 participant]

pytorch/pytorch#179922 (fetched 2026-04-11 06:11:50)

Error Message

NCCL version 2.29.3+cuda12.9
=== make_fx graph ===

def forward(self, t_1):
    clone = torch.ops.aten.clone.default(t_1);  t_1 = None
    _torchbind_obj0 = self._torchbind_obj0
    _torchbind_obj1 = self._torchbind_obj1
    allreduce_ = torch.ops.c10d.allreduce_.default([clone], _torchbind_obj0, _torchbind_obj1, None, False);  clone = _torchbind_obj0 = _torchbind_obj1 = None
    getitem = allreduce_[0]
    getitem_1 = getitem[0];  getitem = None
    getitem_2 = allreduce_[1];  allreduce_ = getitem_2 = None
    add = torch.ops.aten.add.Tensor(getitem_1, 1);  getitem_1 = None
    return add

=== regional_inductor ===
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/users/tmanlaibaatar/rigi/tmp_makefx_allreduce_repro.py", line 45, in <module>
[rank0]:     regional_inductor(gm)
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/passes/regional_inductor.py", line 263, in regional_inductor
[rank0]:     gm = copy.deepcopy(gm)
[rank0]:          ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/copy.py", line 143, in deepcopy
[rank0]:     y = copier(memo)
[rank0]:         ^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/graph_module.py", line 1018, in __deepcopy__
[rank0]:     fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
[rank0]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/copy.py", line 136, in deepcopy
[rank0]:     y = copier(x, memo)
[rank0]:         ^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/copy.py", line 221, in _deepcopy_dict
[rank0]:     y[deepcopy(key, memo)] = deepcopy(value, memo)
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/copy.py", line 143, in deepcopy
[rank0]:     y = copier(memo)
[rank0]:         ^^^^^^^^^^^^
[rank0]: RuntimeError: Tried to deepcopy object __torch__.torch.classes.c10d.ProcessGroup which does not have a __getstate__ method defined!
[rank0]:[W410 08:25:11.508379182 ProcessGroupNCCL.cpp:1648] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

Root Cause

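I think this happens because regional_inductor deep-copies the graph module before compiling it (gm = copy.deepcopy(gm)), and the traced module holds a c10d.ProcessGroup torchbind object that does not define __getstate__, so it cannot be deep-copied. When I comment that deepcopy out, I hit the next error: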

^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 251, in aot_stage1_graph_capture
[rank0]:     aot_dispatch_base_graph(
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py", line 296, in aot_dispatch_base_graph
[rank0]:     fw_module = _create_graph(
[rank0]:                 ^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py", line 130, in _create_graph
[rank0]:     fx_g = make_fx(
[rank0]:            ^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2948, in wrapped
[rank0]:     return make_fx_tracer.trace(f, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2853, in trace
[rank0]:     return self._trace_inner(f, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2814, in _trace_inner
[rank0]:     t = dispatch_trace(
[rank0]:         ^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1277, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1659, in dispatch_trace
[rank0]:     graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 890, in trace
[rank0]:     (self.create_arg(fn(*args)),),
[rank0]:                      ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1729, in wrapped
[rank0]:     out = f(*tensors)  # type:ignore[call-arg]
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "<string>", line 1, in <lambda>
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py", line 113, in inner_f
[rank0]:     out, out_descs = call_and_expect_output_descs(f, args)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 773, in call_and_expect_output_descs
[rank0]:     outs_pair = fn(*args)
[rank0]:                 ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1230, in inner_fn
[rank0]:     outs, outs_descs = call_and_expect_output_descs(fn, args)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 773, in call_and_expect_output_descs
[rank0]:     outs_pair = fn(*args)
[rank0]:                 ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 902, in _functionalized_f_helper
[rank0]:     f_outs, f_outs_descs = call_and_expect_output_descs(fn, f_args)
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 773, in call_and_expect_output_descs
[rank0]:     outs_pair = fn(*args)
[rank0]:                 ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 115, in inner_fn
[rank0]:     outs, outs_descs = call_and_expect_output_descs(fn, args)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 773, in call_and_expect_output_descs
[rank0]:     outs_pair = fn(*args)
[rank0]:                 ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 205, in orig_flat_fn2
[rank0]:     out = orig_flat_fn(*args)
[rank0]:           ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1534, in functional_call
[rank0]:     out = PropagateUnbackedSymInts(mod).run(*args)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run
[rank0]:     self.env[node] = self.run_node(node)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 8526, in run_node
[rank0]:     result = super().run_node(n)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/interpreter.py", line 297, in run_node
[rank0]:     return getattr(self, n.op)(n.target, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/interpreter.py", line 380, in call_function
[rank0]:     return target(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 1091, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1784, in __torch_function__
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 1091, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 959, in handler
[rank0]:     return torch._library.utils.handle_dispatch_mode(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_library/utils.py", line 325, in handle_dispatch_mode
[rank0]:     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1277, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 526, in __torch_dispatch__
[rank0]:     return handle_effects(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 301, in handle_effects
[rank0]:     (new_token, *unwrapped_outs) = with_effects(
[rank0]:                                    ^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 103, in __call__
[rank0]:     return super().__call__(token, op, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 539, in __call__
[rank0]:     return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 422, in dispatch
[rank0]:     result = handler(mode, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 177, in with_effects_proxy
[rank0]:     out = with_effects(token, op, *args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 103, in __call__
[rank0]:     return super().__call__(token, op, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 539, in __call__
[rank0]:     return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 422, in dispatch
[rank0]:     result = handler(mode, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 164, in with_effects_fake
[rank0]:     result = with_effects_dense(token, op, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 142, in with_effects_dense
[rank0]:     out = op(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 1091, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 959, in handler
[rank0]:     return torch._library.utils.handle_dispatch_mode(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_library/utils.py", line 325, in handle_dispatch_mode
[rank0]:     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1277, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/utils/_stats.py", line 29, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1463, in __torch_dispatch__
[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2240, in dispatch
[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1588, in _cached_dispatch_impl
[rank0]:     entry = cache.get(key, None)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1194, in __eq__
[rank0]:     return isinstance(other, _DispatchCacheKey) and self.key == other.key
[rank0]:                                                     ^^^^^^^^^^^^^^^^^^^^^
[rank0]: NotImplementedError: '__eq__' is not implemented for __torch__.torch.classes.c10d.ProcessGroup
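
The final frames above show a cache-key comparison (self.key == other.key) reaching an object whose `__eq__` raises. The following is a plain-Python analogy of that failure mode, not PyTorch code: `OpaqueHandle`, `IdentityRef`, and `make_key` are hypothetical names, and treating such handles as equal only by identity is an assumption about what a cache key would want, not something the issue confirms.

```python
class OpaqueHandle:
    """Hypothetical stand-in for a script object with no usable equality."""

    def __eq__(self, other):
        raise NotImplementedError("'__eq__' is not implemented for OpaqueHandle")

    # Defining __eq__ disables the default hash, so restore identity hashing.
    __hash__ = object.__hash__


class IdentityRef:
    """Wraps an object so hashing and equality depend on identity only."""

    def __init__(self, obj):
        self.obj = obj

    def __hash__(self):
        return id(self.obj)

    def __eq__(self, other):
        return isinstance(other, IdentityRef) and self.obj is other.obj


def make_key(op_name, args):
    # Swap opaque handles for identity wrappers; keep every other arg as-is.
    return (op_name,) + tuple(
        IdentityRef(a) if isinstance(a, OpaqueHandle) else a for a in args
    )


pg_a, pg_b = OpaqueHandle(), OpaqueHandle()

# Comparing raw keys that hold two distinct handles calls OpaqueHandle.__eq__.
try:
    ("allreduce_", pg_a) == ("allreduce_", pg_b)
    raw_compare_failed = False
except NotImplementedError:
    raw_compare_failed = True

# Keys built with identity wrappers can be stored and looked up safely.
cache = {make_key("allreduce_", [pg_a]): "cached"}
hit = cache.get(make_key("allreduce_", [pg_a]))
miss = cache.get(make_key("allreduce_", [pg_b]))
```

Here the lookup for `pg_a` finds the entry and the lookup for `pg_b` does not, without ever calling the handle's own `__eq__`. Whether identity is the right notion of equality for a process group inside the real dispatch cache is a design question for the PyTorch maintainers.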

Fix Action

Fix / Workaround

Commenting out the copy.deepcopy(gm) call in regional_inductor gets past the __getstate__ error, but it is not a complete workaround. The next failure is the NotImplementedError shown in full under Root Cause above: '__eq__' is not implemented for __torch__.torch.classes.c10d.ProcessGroup, raised in fake_tensor.py while _cached_dispatch_impl compares dispatch cache keys.


🐛 Describe the bug

import os

import torch
import torch.distributed as dist
from torch._guards import tracing, TracingContext
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.regional_inductor import regional_inductor


os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("NCCL_SOCKET_IFNAME", "lo")

torch.cuda.set_device("cuda:0")
dist.init_process_group(
    "cpu:gloo,cuda:nccl",
    rank=0,
    world_size=1,
    device_id=torch.device("cuda:0"),
)


def f(t):
    t = t.clone()
    dist.all_reduce(t)
    return t + 1


gm = make_fx(f)(torch.ones(4, device="cuda"))
print("=== make_fx graph ===")
print(gm.code)

for node in gm.graph.nodes:
    if node.op not in ("placeholder", "output"):
        node.meta.setdefault("custom", {})["compile_with_inductor"] = {"inductor_configs": {}}

fake_mode = next(
    node.meta["val"].fake_mode
    for node in gm.graph.nodes
    if node.op == "placeholder" and isinstance(node.meta.get("val"), torch.Tensor)
)

print("=== regional_inductor ===")
with tracing(TracingContext(fake_mode)):
    regional_inductor(gm)

This crashes with:

NCCL version 2.29.3+cuda12.9
=== make_fx graph ===



def forward(self, t_1):
    clone = torch.ops.aten.clone.default(t_1);  t_1 = None
    _torchbind_obj0 = self._torchbind_obj0
    _torchbind_obj1 = self._torchbind_obj1
    allreduce_ = torch.ops.c10d.allreduce_.default([clone], _torchbind_obj0, _torchbind_obj1, None, False);  clone = _torchbind_obj0 = _torchbind_obj1 = None
    getitem = allreduce_[0]
    getitem_1 = getitem[0];  getitem = None
    getitem_2 = allreduce_[1];  allreduce_ = getitem_2 = None
    add = torch.ops.aten.add.Tensor(getitem_1, 1);  getitem_1 = None
    return add
    
=== regional_inductor ===
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/users/tmanlaibaatar/rigi/tmp_makefx_allreduce_repro.py", line 45, in <module>
[rank0]:     regional_inductor(gm)
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/passes/regional_inductor.py", line 263, in regional_inductor
[rank0]:     gm = copy.deepcopy(gm)
[rank0]:          ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/copy.py", line 143, in deepcopy
[rank0]:     y = copier(memo)
[rank0]:         ^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/graph_module.py", line 1018, in __deepcopy__
[rank0]:     fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo))
[rank0]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/copy.py", line 136, in deepcopy
[rank0]:     y = copier(x, memo)
[rank0]:         ^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/copy.py", line 221, in _deepcopy_dict
[rank0]:     y[deepcopy(key, memo)] = deepcopy(value, memo)
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/copy.py", line 143, in deepcopy
[rank0]:     y = copier(memo)
[rank0]:         ^^^^^^^^^^^^
[rank0]: RuntimeError: Tried to deepcopy object __torch__.torch.classes.c10d.ProcessGroup which does not have a __getstate__ method defined!
[rank0]:[W410 08:25:11.508379182 ProcessGroupNCCL.cpp:1648] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

I think this is because regional_inductor deep-copies the graph module before compiling it, and ProcessGroup does not implement the serialization hook (__getstate__) that the copy relies on. When I comment out the deepcopy, I hit the next error:

^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 251, in aot_stage1_graph_capture
[rank0]:     aot_dispatch_base_graph(
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py", line 296, in aot_dispatch_base_graph
[rank0]:     fw_module = _create_graph(
[rank0]:                 ^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py", line 130, in _create_graph
[rank0]:     fx_g = make_fx(
[rank0]:            ^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2948, in wrapped
[rank0]:     return make_fx_tracer.trace(f, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2853, in trace
[rank0]:     return self._trace_inner(f, *args)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2814, in _trace_inner
[rank0]:     t = dispatch_trace(
[rank0]:         ^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1277, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1659, in dispatch_trace
[rank0]:     graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 890, in trace
[rank0]:     (self.create_arg(fn(*args)),),
[rank0]:                      ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1729, in wrapped
[rank0]:     out = f(*tensors)  # type:ignore[call-arg]
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "<string>", line 1, in <lambda>
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture.py", line 113, in inner_f
[rank0]:     out, out_descs = call_and_expect_output_descs(f, args)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 773, in call_and_expect_output_descs
[rank0]:     outs_pair = fn(*args)
[rank0]:                 ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1230, in inner_fn
[rank0]:     outs, outs_descs = call_and_expect_output_descs(fn, args)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 773, in call_and_expect_output_descs
[rank0]:     outs_pair = fn(*args)
[rank0]:                 ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 902, in _functionalized_f_helper
[rank0]:     f_outs, f_outs_descs = call_and_expect_output_descs(fn, f_args)
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 773, in call_and_expect_output_descs
[rank0]:     outs_pair = fn(*args)
[rank0]:                 ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 115, in inner_fn
[rank0]:     outs, outs_descs = call_and_expect_output_descs(fn, args)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 773, in call_and_expect_output_descs
[rank0]:     outs_pair = fn(*args)
[rank0]:                 ^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_compile.py", line 205, in orig_flat_fn2
[rank0]:     out = orig_flat_fn(*args)
[rank0]:           ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1534, in functional_call
[rank0]:     out = PropagateUnbackedSymInts(mod).run(*args)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run
[rank0]:     self.env[node] = self.run_node(node)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 8526, in run_node
[rank0]:     result = super().run_node(n)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/interpreter.py", line 297, in run_node
[rank0]:     return getattr(self, n.op)(n.target, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/interpreter.py", line 380, in call_function
[rank0]:     return target(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 1091, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1784, in __torch_function__
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 1091, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 959, in handler
[rank0]:     return torch._library.utils.handle_dispatch_mode(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_library/utils.py", line 325, in handle_dispatch_mode
[rank0]:     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1277, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 526, in __torch_dispatch__
[rank0]:     return handle_effects(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 301, in handle_effects
[rank0]:     (new_token, *unwrapped_outs) = with_effects(
[rank0]:                                    ^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 103, in __call__
[rank0]:     return super().__call__(token, op, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 539, in __call__
[rank0]:     return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 422, in dispatch
[rank0]:     result = handler(mode, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 177, in with_effects_proxy
[rank0]:     out = with_effects(token, op, *args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 103, in __call__
[rank0]:     return super().__call__(token, op, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 539, in __call__
[rank0]:     return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 422, in dispatch
[rank0]:     result = handler(mode, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 164, in with_effects_fake
[rank0]:     result = with_effects_dense(token, op, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_higher_order_ops/effects.py", line 142, in with_effects_dense
[rank0]:     out = op(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 1091, in __call__
[rank0]:     return self._op(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_ops.py", line 959, in handler
[rank0]:     return torch._library.utils.handle_dispatch_mode(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_library/utils.py", line 325, in handle_dispatch_mode
[rank0]:     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1277, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/utils/_stats.py", line 29, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1463, in __torch_dispatch__
[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2240, in dispatch
[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1588, in _cached_dispatch_impl
[rank0]:     entry = cache.get(key, None)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tmanlaibaatar/.conda/envs/rigi/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1194, in __eq__
[rank0]:     return isinstance(other, _DispatchCacheKey) and self.key == other.key
[rank0]:                                                     ^^^^^^^^^^^^^^^^^^^^^
[rank0]: NotImplementedError: '__eq__' is not implemented for __torch__.torch.classes.c10d.ProcessGroup

I don't know whether it is a good idea to implement these methods. In Dynamo, the all_reduce call in the function below gets rewritten into a functional collective:

  def f(t):
      t = t.clone()
      dist.all_reduce(t)
      return t + 1

  # DYNAMO VERSION

  def forward(self, L_t_ : torch.Tensor):
      l_t_ = L_t_
      t = l_t_.clone();  l_t_ = None
      tensor = torch.ops._c10d_functional.all_reduce(t, 'sum', '0')
      wait_tensor = torch.ops._c10d_functional.wait_tensor(tensor);  tensor = None
      copy_ = t.copy_(wait_tensor);  wait_tensor = copy_ = None
      add = t + 1;  t = None
      return (add,)

Versions

main

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @ezyang @EikanWang @jgong5 @wenzhe-nrv @chauhang @penguinwu @bdhirsh @bobrenjc93 @aorenste

extent analysis

TL;DR

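Two separate failures appear to share one cause: make_fx keeps the c10d ProcessGroup in the graph as an opaque TorchBind object. The copy.deepcopy(gm) call in regional_inductor raises a RuntimeError because that object defines no __getstate__, and the FakeTensor dispatch cache lookup raises a NotImplementedError because it defines no __eq__.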

Guidance

  • The deepcopy error occurs because ProcessGroup does not define __getstate__, which copy.deepcopy needs in order to copy the object.
  • One fix is to implement __getstate__ and __setstate__ for ProcessGroup, but this might not be feasible or desirable (the issue author raises the same doubt).
  • Alternatively, modify regional_inductor so that it does not deepcopy the graph.
  • Switching to torch.save and torch.load is unlikely to help on its own: both rely on pickle, so they would likely hit the same missing __getstate__.
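
To make the "avoid deep copying" option concrete, here is a minimal standard-library sketch of one way a copy can skip an object that refuses to be copied: pre-seeding the memo dictionary passed to copy.deepcopy so the object is shared by reference. OpaqueHandle and deepcopy_sharing are hypothetical names used only for illustration. Nothing here is PyTorch API, and whether regional_inductor could adopt this approach is an assumption, not something the issue confirms.

```python
import copy


class OpaqueHandle:
    """Hypothetical stand-in for an object that refuses to be deep-copied."""

    def __deepcopy__(self, memo):
        raise RuntimeError("Tried to deepcopy an opaque handle")


def deepcopy_sharing(obj, shared):
    """Deep-copy obj, but keep every object in `shared` by reference."""
    # copy.deepcopy looks up id(x) in memo before copying anything, so
    # pre-seeded entries are returned as-is instead of being copied.
    memo = {id(s): s for s in shared}
    return copy.deepcopy(obj, memo)


handle = OpaqueHandle()
module_state = {"pg": handle, "weights": [1.0, 2.0]}

# A plain deepcopy reaches the handle and fails.
try:
    copy.deepcopy(module_state)
    plain_copy_failed = False
except RuntimeError:
    plain_copy_failed = True

# With the handle pre-seeded in memo, everything else is still copied.
copied = deepcopy_sharing(module_state, [handle])
```

The trade-off is that the copy and the original now share the handle, which is only acceptable when the handle is treated as immutable shared state.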

Example

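# Example of how to implement __getstate__ and __setstate__ methods.
# This is a pure-Python stand-in: the real c10d.ProcessGroup is a TorchBind
# class registered from C++, so the equivalent hooks would live there.
class ProcessGroup:
    def __init__(self, state=None):
        self.state = state

    def __getstate__(self):
        # Return a dictionary of the object's state
        return {'state': self.state}

    def __setstate__(self, state):
        # Restore the attribute from the dictionary built by __getstate__
        self.state = state['state']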

Note that this is just an example and the actual implementation would depend on the specifics of the ProcessGroup class.

Notes

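  • The NotImplementedError is raised because ProcessGroup does not implement __eq__; the FakeTensor dispatch cache (_cached_dispatch_impl in fake_tensor.py) calls it while comparing cache keys.
  • The RuntimeError is a separate failure: regional_inductor deepcopies the graph module, and ProcessGroup defines no __getstate__.
  • Dynamo rewrites dist.all_reduce into torch.ops._c10d_functional.all_reduce followed by wait_tensor. In that graph the group is passed as a string ('0') rather than as a ProcessGroup object, so tracing the functional collective might be a workaround.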
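
The cache-lookup failure in the first note can be reproduced with the standard library alone. The sketch below uses hypothetical stand-ins (OpaqueHandle, CacheKey) that mirror the shape of the `__eq__` frame in the traceback; they are not the actual PyTorch classes. A dictionary lookup whose hash matches an existing entry falls back to `__eq__`, and the tuple comparison then reaches the handle's unimplemented `__eq__`.

```python
class OpaqueHandle:
    """Hypothetical stand-in for an object that is hashable but not comparable."""

    def __hash__(self):
        return 0

    def __eq__(self, other):
        raise NotImplementedError("'__eq__' is not implemented for OpaqueHandle")


class CacheKey:
    """Hypothetical cache key that wraps a tuple and precomputes its hash."""

    def __init__(self, key):
        self.key = key
        self.hashvalue = hash(key)

    def __hash__(self):
        return self.hashvalue

    def __eq__(self, other):
        # Comparing the tuples compares their elements, including the handles.
        return isinstance(other, CacheKey) and self.key == other.key


cache = {CacheKey(("allreduce", OpaqueHandle())): "cached result"}

# Same hash, different handle instance: dict.get must call __eq__ to decide
# whether the keys match, and the handle's __eq__ raises.
try:
    cache.get(CacheKey(("allreduce", OpaqueHandle())), None)
    lookup_error = None
except NotImplementedError as exc:
    lookup_error = str(exc)
```

The exception surfaces from `cache.get` itself, which matches the traceback ending inside the key's `__eq__` rather than in the op being dispatched.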

Recommendation

Work around the problem by modifying regional_inductor so that it does not deepcopy the graph, or by tracing a functional collective (torch.ops._c10d_functional.all_reduce) as Dynamo does. Implementing __getstate__, __setstate__, and __eq__ for ProcessGroup would target both errors directly, but might not be feasible or desirable.
