pytorch - Why does the reinplace pass not remove the trailing copy_ for auto_functionalized_v2?


🐛 Describe the bug

🐛 Bug

In the reinplace_inplaceable_ops_core pass in torch/_inductor/fx_passes/reinplace.py, the code only handles torch.ops.higher_order.auto_functionalized but forgets auto_functionalized_v2 when unwrapping getitem nodes for aten.copy_.default.

This causes useless self-copy nodes aten.copy_.default(arg0_1, arg0_1) (generated by op decomposition/functionalization) to remain in the graph and not be eliminated.

Before decomposition:

class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[10240, 10240]"):
         # File: /home/w00949861/demo/test_auto_function.py/auto.py:31 in forward, code: return torch.ops.custom.sin(torch.ones_like(z,device='npu'),z)
        ones_like: "f32[10240, 10240]" = torch.ops.aten.ones_like.default(arg0_1, device = device(type='npu'), pin_memory = False)
        auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.custom.sin.default, x = ones_like, _result_base_index = 0, _all_bases = [arg0_1]);  ones_like = None
        getitem_1: "f32[10240, 10240]" = auto_functionalized_v2[1];  auto_functionalized_v2 = None
        copy_: "f32[10240, 10240]" = torch.ops.aten.copy_.default(arg0_1, getitem_1);  arg0_1 = getitem_1 = copy_ = None
        return ()

After decomposition:

class GraphModule(torch.nn.Module):
    def forward(self, arg0_1: "f32[10240, 10240]"):
         # File: /home/w00949861/demo/test_auto_function.py/auto.py:31 in forward, code: return torch.ops.custom.sin(torch.ones_like(z,device='npu'),z)
        ones_like: "f32[10240, 10240]" = torch.ops.aten.ones_like.default(arg0_1, device = device(type='npu'), pin_memory = False)
        
        # No stacktrace found for following nodes
        sin_default = torch.ops.custom.sin.default(ones_like, arg0_1);  ones_like = sin_default = None
        
         # File: /home/w00949861/demo/test_auto_function.py/auto.py:31 in forward, code: return torch.ops.custom.sin(torch.ones_like(z,device='npu'),z)
        copy_: "f32[10240, 10240]" = torch.ops.aten.copy_.default(arg0_1, arg0_1);  arg0_1 = copy_ = None
        return ()

As a result, the redundant self-copy node torch.ops.aten.copy_.default(arg0_1, arg0_1) is not eliminated and remains in the graph.
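A self-copy is easy to spot mechanically: it is a copy_ node whose destination and source are the same node. The sketch below is a minimal illustration using toy stand-in objects rather than real torch.fx nodes; the helper name find_self_copies and the string targets are hypothetical. On a real graph, the same check would presumably iterate over graph.nodes and compare against torch.ops.aten.copy_.default.

```python
from types import SimpleNamespace


def find_self_copies(nodes, copy_target):
    """Return nodes of the form copy_(x, x): destination and source are the same object."""
    return [
        n
        for n in nodes
        if n.target is copy_target and len(n.args) >= 2 and n.args[0] is n.args[1]
    ]


# Placeholder for torch.ops.aten.copy_.default (assumption: toy target, not the real op).
COPY = "aten.copy_.default"

# Toy version of the "after decomposition" graph from the issue.
arg0 = SimpleNamespace(name="arg0_1", target=None, args=())
ones = SimpleNamespace(name="ones_like", target="aten.ones_like.default", args=(arg0,))
sin = SimpleNamespace(name="sin_default", target="custom.sin.default", args=(ones, arg0))
copy = SimpleNamespace(name="copy_", target=COPY, args=(arg0, arg0))

self_copies = find_self_copies([arg0, ones, sin, copy], COPY)
print([n.name for n in self_copies])
```

A check like this could serve as a regression assertion once the pass is changed: the list should be empty for the graph in this report.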

Code Location

https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/reinplace.py

The Problem

The condition checks for auto_functionalized but does not include auto_functionalized_v2:

def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
   
    copy_args_to_copy_nodes = {}
    copy_nodes = {}
    mutated_inputs = OrderedSet[Any]()
    storage_to_nodes = defaultdict(list)
    node_order: dict[Any, int] = {}
    for i, node in enumerate(reversed(graph.nodes)):
        node_order[node] = len(graph.nodes) - i - 1
        storage_to_nodes[get_node_storage(node)].append(node)
        if node.target is aten.copy_.default and node.args[0].op in (
            "placeholder",
            "get_attr",
        ):
            dst = node.args[0]
            src = node.args[1]
            if src.target is operator.getitem and (
                (
                    src.args[0].target == triton_kernel_wrapper_functional
                    and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0]
                )
                or (src.args[0].target in inplaceable_foreach_ops)
                or (src.args[0].target is torch.ops.higher_order.auto_functionalized)
                # Missing: no branch for torch.ops.higher_order.auto_functionalized_v2
            ):
                src = src.args[0]

            copy_args_to_copy_nodes[(dst, src)] = node
            copy_nodes[dst] = node

            mutated_inputs.add(node.args[0])

During debugging, I found that the code checks for `or (src.args[0].target is torch.ops.higher_order.auto_functionalized)` but is missing the corresponding check for `or (src.args[0].target is torch.ops.higher_order.auto_functionalized_v2)`.

if should_attempt_reinplace and can_inplace(node, mutated_arg):
    copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))  # For auto_functionalized_v2, the key for aten.copy_ does not match here.
    if copy_node is not None:
        replace_dict[copy_node] = copy_node.args[0]
    if trigger != ReInplaceTrigger.AUTO_FUNC_V2:
        for user in node.users:
            if user.target is operator.getitem and user.args[1] == arg:
                replace_dict[user] = mutated_arg

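The logic here appears to intend for auto_functionalized_v2 to get a non-None value from copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) and then eliminate the final self-copy. However, the condition in the first loop does not include auto_functionalized_v2, so the copy_ node is stored under the key (dst, getitem node) while the lookup uses (mutated_arg, auto_functionalized_v2 node), and the lookup returns None. This is the source of my confusion.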
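The key mismatch can be reproduced without PyTorch. The following is a minimal sketch under stated assumptions: Node, build_copy_index, and the string targets are hypothetical stand-ins for torch.fx nodes and the higher-order ops, and the function only mimics the unwrapping step of the first loop, not the real pass.

```python
import operator
from dataclasses import dataclass

# Toy stand-ins for the higher-order op targets (assumption: not the real PyTorch objects).
AUTO_FUNC_V1 = "auto_functionalized"
AUTO_FUNC_V2 = "auto_functionalized_v2"


@dataclass(eq=False)  # eq=False keeps identity-based hashing, like fx nodes used as dict keys
class Node:
    name: str
    target: object = None
    args: tuple = ()


def build_copy_index(copy_nodes, unwrap_targets):
    """Mimic the first loop: map (dst, src) -> copy_ node, unwrapping getitem sources."""
    index = {}
    for copy in copy_nodes:
        dst, src = copy.args
        if src.target is operator.getitem and src.args[0].target in unwrap_targets:
            src = src.args[0]  # key on the producing higher-order node instead of the getitem
        index[(dst, src)] = copy
    return index


# Toy version of the "before decomposition" graph.
arg0 = Node("arg0_1")
hop = Node("auto_functionalized_v2", AUTO_FUNC_V2)
getitem = Node("getitem_1", operator.getitem, (hop, 1))
copy = Node("copy_", "aten.copy_.default", (arg0, getitem))

without_v2 = build_copy_index([copy], {AUTO_FUNC_V1})
with_v2 = build_copy_index([copy], {AUTO_FUNC_V1, AUTO_FUNC_V2})

# The later lookup uses (mutated_arg, node), where node is the higher-order op itself.
print(without_v2.get((arg0, hop)))
print(with_v2.get((arg0, hop)) is copy)
```

In this toy model the lookup only succeeds when the v2 target is part of the unwrap set. Whether the real pass should behave the same way is exactly the question this issue raises.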

Error logs

No response

Versions

asd

cc @bdhirsh @ezyang @chauhang @penguinwu @bobrenjc93 @aorenste

extent analysis

TL;DR

The most likely fix is to add a condition to check for torch.ops.higher_order.auto_functionalized_v2 in the reinplace_inplaceable_ops_core function.

Guidance

  • Add a check for torch.ops.higher_order.auto_functionalized_v2 in the if statement where torch.ops.higher_order.auto_functionalized is checked.
  • With that check in place, copy_args_to_copy_nodes would be keyed by (dst, auto_functionalized_v2 node) instead of (dst, getitem node), which is the key the later lookup uses.
  • Verify that the copy_node is not None for auto_functionalized_v2 nodes and that the self-copy node is eliminated.
  • Review the code logic to ensure that it correctly handles auto_functionalized_v2 nodes and eliminates redundant self-copy nodes.

Example

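# Keep the getitem guard from the original condition; placeholder sources have no args[0].
if src.target is operator.getitem and (
    src.args[0].target is torch.ops.higher_order.auto_functionalized
    or src.args[0].target is torch.ops.higher_order.auto_functionalized_v2
):
    src = src.args[0]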
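The fragment above relies on names from the surrounding pass. As a self-contained illustration, the sketch below wraps the same idea in a small helper and runs it on toy nodes. Everything here is an assumption for demonstration: unwrap_auto_functionalized is a hypothetical name, and AUTO_FUNC and AUTO_FUNC_V2 are placeholder objects standing in for the two torch.ops.higher_order targets. It also shows why the getitem guard matters: a placeholder source has an empty args tuple, so indexing args[0] without the guard would raise IndexError.

```python
import operator
from types import SimpleNamespace

# Placeholders for torch.ops.higher_order.auto_functionalized / auto_functionalized_v2
# (assumption: plain sentinel objects, not the real higher-order ops).
AUTO_FUNC = object()
AUTO_FUNC_V2 = object()


def unwrap_auto_functionalized(src):
    """Return the producing higher-order node if src is a getitem on one, else src itself."""
    if src.target is operator.getitem and any(
        src.args[0].target is hop for hop in (AUTO_FUNC, AUTO_FUNC_V2)
    ):
        return src.args[0]
    return src


placeholder = SimpleNamespace(name="arg0_1", target=None, args=())
hop_v2 = SimpleNamespace(name="auto_functionalized_v2", target=AUTO_FUNC_V2, args=())
getitem_1 = SimpleNamespace(name="getitem_1", target=operator.getitem, args=(hop_v2, 1))

# A getitem on the v2 op is unwrapped to the op node.
print(unwrap_auto_functionalized(getitem_1).name)
# A placeholder source is returned unchanged; args[0] is never touched.
print(unwrap_auto_functionalized(placeholder).name)
```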

Notes

The code change should be made in the reinplace_inplaceable_ops_core function in torch/_inductor/fx_passes/reinplace.py. The fix assumes that the logic for handling auto_functionalized_v2 nodes is similar to that of auto_functionalized nodes.

Recommendation

Consider adding the check for torch.ops.higher_order.auto_functionalized_v2 in the reinplace_inplaceable_ops_core function. If the v2 handling is indeed analogous to auto_functionalized, this should let the pass eliminate the redundant self-copy node; that assumption has not been confirmed in this report.
