pytorch - 💡(How to fix) Fix `torch.compile` produces different output for model with `aten.select` parent + batch pointwise fusion pattern [2 comments, 2 participants]

GitHub stats: pytorch/pytorch#178879, fetched 2026-04-08 01:57:19
Comments: 2 · Participants: 2 · Timeline events: 51 · Reactions: 0
Timeline (top): mentioned ×22, subscribed ×22, labeled ×5, commented ×2


🐛 Describe the bug

torch.compile with the Inductor backend produces numerically different results for a model that takes multiple aten.select slices of a common parent tensor, expands them, and applies a batch of pointwise operations (add, mul) with other tensors. The pattern targets Inductor's BatchPointwiseMathOpsPostGradFusion optimization pass, which fuses pointwise operations that share the same parent tensor.
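To make the fusion idea concrete, here is a minimal CPU sketch that assumes a batch pointwise fusion amounts to stacking same-shape operands, running one kernel, and splitting the result back. This is a simplification: the real pass rewrites FX graph nodes. The helper batched_pointwise and the random stand-in tensors are hypothetical. The sketch shows that batching plain add and mul leaves every element bit-identical, because each element still passes through exactly one correctly rounded operation.

```python
import torch

def batched_pointwise(op, lhs_list, rhs_list):
    """Hypothetical helper: apply one pointwise op to several same-shape pairs at once.

    Stacks the operands, runs a single kernel, then splits the result back,
    which is roughly what batching pointwise ops amounts to.
    """
    stacked = op(torch.stack(lhs_list), torch.stack(rhs_list))
    return list(stacked.unbind(0))

# Random stand-ins for `base`, the linear outputs, and the expanded select slices.
torch.manual_seed(0)
base = torch.randn(2, 4, 8, 8)
y1 = torch.randn(2, 4, 8, 8)
y2 = torch.randn(2, 4, 8, 8)
slices = [base.select(1, i).unsqueeze(1).expand_as(base) for i in range(4)]

unfused = [slices[0] + y1, slices[2] + y1, slices[1] * y2, slices[3] * y2]
fused = (batched_pointwise(torch.add, [slices[0], slices[2]], [y1, y1])
         + batched_pointwise(torch.mul, [slices[1], slices[3]], [y2, y2]))

# Each element still goes through exactly one correctly rounded add or mul,
# so batching alone should leave every bit unchanged.
bit_identical = all(torch.equal(a, b) for a, b in zip(unfused, fused))
print(bit_identical)
```

This does not cover out5 = torch.add(slice1, y1, alpha=0.5): batching it together with the plain adds is only valid if the fused op also applies the 0.5 scale.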

The maximum absolute difference observed exceeds 21.0.

Minimal reproducer

import os
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn

class SelectParentPointwiseModel(nn.Module):
    def __init__(self, channels=64, height=32, width=32):
        super().__init__()
        self.channels = channels
        self.conv = nn.Conv2d(3, channels, kernel_size=3, padding=1)
        self.linear1 = nn.Linear(channels, channels)
        self.linear2 = nn.Linear(channels, channels)
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        batch_size = x.shape[0]
        base = self.conv(x)
        base = self.bn(base)

        base_flat = base.flatten(2)
        y1 = self.linear1(base_flat.transpose(1, 2))
        y2 = self.linear2(base_flat.transpose(1, 2))
        y1 = y1.transpose(1, 2).view_as(base)
        y2 = y2.transpose(1, 2).view_as(base)

        # Multiple aten.select from same parent → expand → pointwise
        slice1 = base.select(1, 0).unsqueeze(1).expand_as(base)
        slice2 = base.select(1, 1).unsqueeze(1).expand_as(base)
        slice3 = base.select(1, 2).unsqueeze(1).expand_as(base)
        slice4 = base.select(1, 3).unsqueeze(1).expand_as(base)

        out1 = slice1 + y1
        out2 = slice2 * y2
        out3 = slice3 + y1
        out4 = slice4 * y2
        out5 = torch.add(slice1, y1, alpha=0.5)
        out6 = torch.mul(slice2, y2)

        combined = out1 + out2 + out3 + out4 + out5 + out6
        result = combined * base + combined / (base + 1e-8)
        return result


device = "cuda"
torch.manual_seed(42)
model = SelectParentPointwiseModel().to(device).eval()
x = torch.randn(4, 3, 32, 32, device=device)

# Eager
with torch.no_grad():
    eager_out = model(x)

# Compiled
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
with torch.no_grad():
    comp_out = compiled(x)

diff = (eager_out.float() - comp_out.float()).abs().max().item()
print(f"Max diff: {diff}")
print(f"Match: {torch.allclose(eager_out, comp_out, atol=1e-5, rtol=1e-4)}")
# Expected: max_diff > 1.0 — significant mismatch
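A large difference against eager fp32 does not by itself show which side is wrong. One way to check, sketched here with a hypothetical helper error_vs_fp64, is to measure both the eager fp32 output and the compiled output against a float64 eager run of a copy of the model. If both errors are of similar size, the expression is probably just ill-conditioned in fp32; if only the compiled error is large, that points more strongly at a miscompile. The toy model at the end is only a CPU self-check.

```python
import copy
import torch
import torch.nn as nn

def error_vs_fp64(model: nn.Module, x: torch.Tensor, candidate) -> dict:
    """Hypothetical helper: max abs error of fp32 eager and of `candidate` vs. a float64 eager run."""
    # deepcopy keeps the train/eval mode of `model`; .double() converts params and buffers.
    ref_model = copy.deepcopy(model).double()
    with torch.no_grad():
        reference = ref_model(x.double())
        eager = model(x)
        cand = candidate(x)
    return {
        "eager_fp32": (eager.double() - reference).abs().max().item(),
        "candidate": (cand.double() - reference).abs().max().item(),
    }

# Tiny CPU self-check; for the reproducer, call error_vs_fp64(model, x, compiled).
toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1)).eval()
errors = error_vs_fp64(toy, torch.randn(16, 8), toy)
print(errors)
```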

Behavior summary

Mode | Output | Max diff
Eager | Reference | n/a
torch.compile | Differs | ~21.4

Root cause hypothesis

The BatchPointwiseMathOpsPostGradFusion (or fuse_nodes_with_same_parent) pass in Inductor's post-grad optimization reorders or fuses the batch of pointwise operations that share a common aten.select parent. This fusion changes the computation order, introducing numerical differences that accumulate through the combined * base + combined / (base + 1e-8) final expression.
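The hypothesis can be sanity-checked without Inductor. The sketch below assumes random normal tensors stand in for base and the six out* terms, and uses pairwise summation as one plausible alternative order (not necessarily what Inductor emits). It compares the two fp32 orders before and after the final expression. Because the derivative of combined / (base + 1e-8) with respect to combined is 1 / (base + 1e-8), rounding differences in combined are amplified wherever base is close to zero.

```python
import torch

def reorder_sensitivity(numel: int = 4 * 64 * 32 * 32, seed: int = 0):
    """Compare two fp32 summation orders of six terms, before and after the final expression.

    Random normal data stands in for `base` and out1..out6 (an assumption).
    Returns (max |diff| in combined, max |diff| in the final result).
    """
    g = torch.Generator().manual_seed(seed)
    base = torch.randn(numel, generator=g)
    t = [torch.randn(numel, generator=g) for _ in range(6)]

    # Order A: left to right, as written in forward().
    combined_a = t[0] + t[1] + t[2] + t[3] + t[4] + t[5]
    # Order B: pairwise; one plausible alternative, not necessarily Inductor's.
    combined_b = (t[0] + t[1]) + (t[2] + t[3]) + (t[4] + t[5])

    def final(combined):
        return combined * base + combined / (base + 1e-8)

    diff_combined = (combined_a - combined_b).abs().max().item()
    diff_result = (final(combined_a) - final(combined_b)).abs().max().item()
    return diff_combined, diff_result

diff_combined, diff_result = reorder_sensitivity()
print(f"combined: {diff_combined:.3e}  result: {diff_result:.3e}")
```

A tiny diff_combined paired with a much larger diff_result would mean that reordering alone can produce visible mismatches in this model, which weakens, but does not rule out, the case for a genuine fusion bug.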

Error logs

No error — outputs silently differ.

Versions

PyTorch version: 2.12.0.dev20260327+cu126
CUDA used to build PyTorch: 12.6
OS: Ubuntu 22.04.5 LTS (x86_64) — WSL2
Python version: 3.10.12
GPU: NVIDIA GeForce RTX 3060 Laptop GPU

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

The most likely workaround is to disable the BatchPointwiseMathOpsPostGradFusion optimization pass in Inductor's post-grad optimization.

Guidance

  • Confirm which optimization pass causes the mismatch; the issue's hypothesis points at BatchPointwiseMathOpsPostGradFusion.
  • Consider disabling this pass when compiling the model with the Inductor backend to avoid the numerical differences.
  • Verify the fix by comparing the outputs of the compiled and eager models with torch.allclose.
  • If disabling the pass is not feasible, explore other optimization settings or restructure the model to avoid the problematic pattern.
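Before disabling anything, it helps to find which compilation stage introduces the difference. The sketch below uses a hypothetical helper, max_diff_by_backend, that compiles the same module with the built-in "eager", "aot_eager" and "inductor" backends and reports each one's max difference from plain eager execution. If "aot_eager" matches but "inductor" does not, the problem most likely sits in Inductor's passes or code generation rather than in Dynamo tracing or AOTAutograd. Inductor config overrides, if given, are applied only to the "inductor" run.

```python
import torch
import torch._dynamo
import torch.nn as nn

def max_diff_by_backend(model: nn.Module, x: torch.Tensor,
                        backends=("eager", "aot_eager", "inductor"), options=None) -> dict:
    """Hypothetical helper: max |compiled - eager| per torch.compile backend."""
    with torch.no_grad():
        reference = model(x)
    results = {}
    for backend in backends:
        torch._dynamo.reset()  # discard cached graphs so each backend compiles fresh
        # Inductor config overrides only make sense for the inductor backend.
        kwargs = {"options": options} if backend == "inductor" and options else {}
        compiled = torch.compile(model, backend=backend, **kwargs)
        with torch.no_grad():
            out = compiled(x)
        results[backend] = (out.float() - reference.float()).abs().max().item()
    return results
```

For the reproducer this would be print(max_diff_by_backend(model, x)). Recent releases also register an "aot_eager_decomp_partition" debug backend that additionally applies Inductor's decompositions; where available, adding it to backends narrows the search further.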

Example

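# torch.compile has no disable_optimizations argument; Inductor config keys go through options.
# An empty post_grad_fusion_options dict disables Inductor's post-grad batch fusions.
compiled = torch.compile(model, backend="inductor", options={"post_grad_fusion_options": {}})

Note: The issue does not name an API for disabling individual passes, so this uses Inductor's config instead. If torch._inductor.config.post_grad_fusion_options is already empty in your build, the post-grad batch fusions are not running and the mismatch likely originates elsewhere in Inductor.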

Notes

The reproducer and analysis suggest that the issue is related to this optimization pass, but the exact fix may depend on the PyTorch and Inductor versions in use; the report was filed against the nightly build 2.12.0.dev20260327+cu126.

Recommendation

Apply a workaround: disable the BatchPointwiseMathOpsPostGradFusion pass when compiling the model with the Inductor backend to avoid the numerical differences. This is a temporary measure until the underlying issue is resolved.
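Because any workaround here is validated only by comparison, one option is to keep that check in code. The hypothetical helper compile_checked below compiles a module, compares it with eager on one example input using the same allclose tolerances as the reproducer, and falls back to the eager module on mismatch. A single matching input does not prove the compiled module is correct for every input.

```python
import torch
import torch.nn as nn

def compile_checked(model: nn.Module, example_input: torch.Tensor, *, backend="inductor",
                    options=None, atol=1e-5, rtol=1e-4):
    """Hypothetical helper: compile, compare with eager once, fall back to eager on mismatch.

    Returns (module_to_use, matched). Tolerances mirror the reproducer's allclose check.
    """
    with torch.no_grad():
        expected = model(example_input)
    # Only the inductor backend accepts Inductor config overrides via `options`.
    kwargs = {"options": options} if backend == "inductor" and options else {}
    compiled = torch.compile(model, backend=backend, **kwargs)
    with torch.no_grad():
        actual = compiled(example_input)
    if torch.allclose(expected, actual, atol=atol, rtol=rtol):
        return compiled, True
    return model, False
```

For the reproducer, a call might look like runner, ok = compile_checked(model, x, options={"post_grad_fusion_options": {}}), then use runner for inference and log ok.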
