pytorch - ✅(Solved) Fix torch._inductor.pattern_matcher.replace_with_graph leaks recompute tags past nested replacement args [1 pull request, 1 participant]

GitHub stats
pytorch/pytorch#178374 (fetched 2026-04-08 01:25:57)
Comments: 0 · Participants: 1 · Timeline events: 14 · Reactions: 0
Timeline (top): labeled ×7, cross-referenced ×3, referenced ×2, mentioned ×1

Error Message

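Error logs

Running the minimal reproducer raises an AssertionError at assert "recompute" not in node_s1.meta (the full traceback is under Error logs in the issue body).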

Root Cause

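Match.replace_with_graph() builds the stop-node set for tag percolation directly from the replacement args. When an arg is a nested container, such as the tuple (node_s1, node_s2) in the reproducer, the set holds the container rather than the nodes inside it, so percolation walks past those nodes. The args should be flattened into their pytree leaves before the stop-node set is constructed.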
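To make the root cause concrete, here is a minimal pure-Python sketch. FakeNode and leaves are hypothetical stand-ins for torch.fx.Node and a pytree flattening utility; the point is only that set membership sees the tuple, not the nodes inside it.

```python
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; compared by identity."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def leaves(tree):
    # Recursively flatten tuples and lists down to their non-container leaves.
    if isinstance(tree, (tuple, list)):
        return [leaf for item in tree for leaf in leaves(item)]
    return [tree]


s1, s2 = FakeNode("s1"), FakeNode("s2")
args = [(s1, s2)]  # one nested argument, as in the reproducer

buggy_stops = set(args)          # holds the tuple (s1, s2), not s1 or s2
fixed_stops = set(leaves(args))  # holds s1 and s2 individually

print(s1 in buggy_stops, s1 in fixed_stops)  # False True
```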

Fix Action

Fixed

PR fix notes

PR #178375: [inductor] Fix recompute tag propagation for nested replacement args

Description (problem / solution / changelog)

Summary

Fix ReplacementPatternEntry.replace_with_graph() so recompute-related tags do not leak past nested replacement arguments.

This change flattens nested replacement args before building the stop_nodes set used by tag percolation, and adds a regression test covering nested tuple inputs.

Fixes #178374

Test Plan

  • Added a regression test in test/inductor/test_pattern_matcher.py
  • Ran:
    • python test/inductor/test_pattern_matcher.py -k nested_replacement_args_do_not_percolate_tags

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • test/inductor/test_pattern_matcher.py (modified, +48/-0)
  • torch/_inductor/pattern_matcher.py (modified, +2/-1)
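The PR summary says the stop_nodes set feeds tag percolation. The sketch below models that walk in plain Python under stated assumptions: FakeNode, tree_leaves, and percolate_tags are hypothetical stand-ins, and the real Inductor code walks fx node inputs and may differ in detail. It rebuilds the reproducer's shape (x, y feed s1, s2, which feed the newly inserted add) and compares an unflattened boundary with a flattened one.

```python
from collections import deque


class FakeNode:
    """Hypothetical stand-in for torch.fx.Node: a name, input nodes, and meta."""

    def __init__(self, name, *inputs):
        self.name = name
        self.inputs = list(inputs)
        self.meta = {}


def tree_leaves(tree):
    # Recursively flatten tuples and lists down to their non-container leaves.
    if isinstance(tree, (tuple, list)):
        return [leaf for item in tree for leaf in tree_leaves(item)]
    return [tree]


def percolate_tags(start, tag, value, stops):
    # Walk backwards from a newly inserted node, tagging every node reached.
    # Nodes in `stops` are neither tagged nor traversed past.
    queue, seen = deque([start]), set()
    while queue:
        node = queue.popleft()
        if node in seen or node in stops:
            continue
        seen.add(node)
        node.meta[tag] = value
        queue.extend(node.inputs)


def run(flatten):
    # Reproducer shape: placeholders x, y -> s1, s2; replacement inserts add(s1, s2).
    x, y = FakeNode("x"), FakeNode("y")
    s1, s2 = FakeNode("s1", x), FakeNode("s2", y)
    add = FakeNode("add", s1, s2)
    replacement_args = [(s1, s2)]  # nested, as in the reproducer
    stops = set(tree_leaves(replacement_args)) if flatten else set(replacement_args)
    percolate_tags(add, "recompute", "1", stops)
    leaked = [n.name for n in (x, y, s1, s2) if "recompute" in n.meta]
    return add.meta, leaked


print(run(flatten=False))  # ({'recompute': '1'}, ['x', 'y', 's1', 's2'])
print(run(flatten=True))   # ({'recompute': '1'}, [])
```

With the raw args as the boundary, the tag reaches s1, s2, and both placeholders, which matches the reproducer's failing assertion on node_s1.meta; with flattened leaves, only the new add node is tagged.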

🐛 Describe the bug

torch._inductor.pattern_matcher.Match.replace_with_graph() appears to mishandle nested replacement args when computing the boundary where recompute-related tags should stop propagating.

Expected behavior:

  • recompute / ac_graph_id metadata should propagate into newly inserted replacement nodes
  • propagation should stop at the replacement inputs

Actual behavior:

  • when replacement args are nested, the stop boundary is computed incorrectly
  • recompute metadata leaks into input nodes outside the replacement graph

This appears to happen because nested replacement args are used directly instead of flattening pytree leaves first before constructing the stop-node set.

This seems related to pattern matcher replacement-boundary handling, but appears distinct from prior replacement-graph assumption / ac_graph_id issues.

Related context:

  • #134363
  • #129684
  • #162019

Minimal reproducer

import torch
import torch.fx
from torch._inductor.pattern_matcher import Arg, Match, MatchContext


def test_nested_boundary_stopping_logic():
    class M(torch.nn.Module):
        def forward(self, x, y):
            s1 = torch.sin(x)
            s2 = torch.sin(y)
            c = torch.cat([s1, s2], dim=1)
            return torch.cos(c)

    traced = torch.fx.symbolic_trace(M())
    graph = traced.graph
    nodes_before = set(graph.nodes)

    call_nodes = [n for n in graph.nodes if n.op == "call_function"]
    sin_nodes = [n for n in call_nodes if n.target == torch.sin]
    assert len(sin_nodes) == 2
    node_s1, node_s2 = sin_nodes
    node_cat = next(n for n in call_nodes if n.target == torch.cat)

    node_cat.meta["recompute"] = "1"
    node_cat.meta["ac_graph_id"] = 1

    def rep_fn(nested):
        x = nested[0]
        y = nested[1]
        return torch.add(x, y)

    replacement_gm = torch.fx.symbolic_trace(rep_fn)

    pattern = Arg()
    ctx = MatchContext([pattern], graph=graph)
    ctx.pattern_to_node[pattern] = node_cat
    match = Match(ctx, pattern)

    match.replace_with_graph(replacement_gm, args=[(node_s1, node_s2)])

    nodes_after = [n for n in graph.nodes if n not in nodes_before and n.op != "output"]
    new_add_nodes = [
        n for n in nodes_after if n.op == "call_function" and n.target == torch.add
    ]
    assert len(new_add_nodes) == 1
    new_add_node = new_add_nodes[0]

    assert new_add_node.meta.get("recompute") == "1"
    assert new_add_node.meta.get("ac_graph_id") == 1

    assert "recompute" not in node_s1.meta
    assert "recompute" not in node_s2.meta
    assert "ac_graph_id" not in node_s1.meta
    assert "ac_graph_id" not in node_s2.meta

    placeholders = [n for n in graph.nodes if n.op == "placeholder"]
    for n in placeholders:
        assert "recompute" not in n.meta
        assert "ac_graph_id" not in n.meta


if __name__ == "__main__":
    test_nested_boundary_stopping_logic()

Error logs

Running the minimal reproducer fails with:

Traceback (most recent call last):
  File "/home/tiger/pattern-matcher-issue-repro/repro.py", line 70, in <module>
    test_nested_boundary_stopping_logic()
  File "/home/tiger/pattern-matcher-issue-repro/repro.py", line 58, in test_nested_boundary_stopping_logic
    assert "recompute" not in node_s1.meta
AssertionError

Versions

PyTorch version: 2.7.1
Is debug build: False
CUDA used to build PyTorch: 12.6
Is CUDA available: False
Python version: 3.11.2 (main, Apr 28 2025, 14:11:48) [GCC 12.2.0]
Python platform: Linux-5.4.143.bsk.7-amd64-x86_64-with-glibc2.36

cc @soulitzer @ezyang @EikanWang @jgong5 @wenzhe-nrv @chauhang @penguinwu @voznesenskym @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

Extended analysis

Fix Plan

To fix how torch._inductor.pattern_matcher.Match.replace_with_graph() handles nested replacement args, the stop boundary for recompute-related tags must be computed from the individual nodes inside the args rather than from the top-level args themselves.

Here are the steps to fix the issue:

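  1. Flatten the replacement args: before constructing the stop-node set, flatten the replacement args into their pytree leaves so that nodes inside tuples, lists, or other containers become individual entries.
  2. Build the stop-node set from the leaves: include every leaf that is an fx Node, so tag percolation stops at each replacement input.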

Here is an illustrative sketch of the change (not necessarily the exact diff in PR #178375):

import torch.fx
import torch.utils._pytree as pytree


def replace_with_graph(self, replacement_graph, args):
    # Flatten arbitrarily nested args (tuples, lists, dicts) into pytree
    # leaves, so nodes hidden inside containers are seen individually.
    leaves = pytree.tree_leaves(args)

    # Build the stop-node set from the leaves. Only fx Nodes can be
    # boundaries for tag percolation; constants such as ints are skipped.
    stop_nodes = {leaf for leaf in leaves if isinstance(leaf, torch.fx.Node)}

    # ... (rest of the method remains the same; tag percolation now stops
    # at every node in stop_nodes)
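Flattening has to recurse through every container level and type; expanding only top-level tuples misses nodes nested deeper or passed inside lists. Below is a pure-Python sketch of the difference. FakeNode, flatten_one_level, flatten_all, and stop_nodes are hypothetical helpers; in real code a pytree utility such as torch.utils._pytree.tree_leaves plays the role of flatten_all.

```python
class FakeNode:
    """Hypothetical stand-in for torch.fx.Node; compared by identity."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def flatten_one_level(args):
    # Naive approach: expand top-level tuples only.
    out = []
    for arg in args:
        if isinstance(arg, tuple):
            out.extend(arg)
        else:
            out.append(arg)
    return out


def flatten_all(tree):
    # Recurse through tuples, lists, and dict values down to the leaves.
    if isinstance(tree, (tuple, list)):
        return [leaf for item in tree for leaf in flatten_all(item)]
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in flatten_all(value)]
    return [tree]


def stop_nodes(args):
    # Keep only node leaves; constants such as ints or None cannot be stops.
    return {leaf for leaf in flatten_all(args) if isinstance(leaf, FakeNode)}


a, b, c = FakeNode("a"), FakeNode("b"), FakeNode("c")

print(flatten_one_level([((a, b), c)]))  # [(a, b), c]  (a and b are missed)
print(flatten_all([((a, b), c)]))        # [a, b, c]
print(flatten_one_level([[a, b]]))       # [[a, b]]     (lists are not expanded)
print(stop_nodes([(a, 1), {"k": [b, None]}, c]) == {a, b, c})  # True
```

Filtering to node leaves matters because replacement args can mix nodes with plain Python constants.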

Verification

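To verify the fix, rerun the minimal reproducer and check that all of its assertions pass, or run the regression test added in PR #178375: python test/inductor/test_pattern_matcher.py -k nested_replacement_args_do_not_percolate_tags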

Extra Tips

  • Make sure to test the fix with different types of replacement args, including nested tuples and lists.
  • Consider adding additional tests to ensure that the fix does not introduce any regressions.
  • The bug was reported against PyTorch 2.7.1; builds that do not include PR #178375 are likely still affected, so check that your version contains the fix before relying on it.
