pytorch - (How to fix) `torch.compile` produces different output for model with `amax`/`amin` + partial reuse reduction pattern [2 comments, 2 participants]

GitHub stats: pytorch/pytorch#178883 (fetched 2026-04-08 01:57:12)
Comments: 2 · Participants: 2 · Timeline events: 170 · Reactions: 0
Timeline (top): mentioned ×74, subscribed ×74, labeled ×12, unlabeled ×7


🐛 Describe the bug

torch.compile with the Inductor backend produces significantly different results (max_diff ≈ 896.0) for a model that performs multiple reduction operations (torch.amax, torch.amin, torch.min) on the same feature tensor and then combines them through normalization and centering arithmetic. The pattern is designed to exercise Inductor's reuse_partial optimization, which attempts to reuse partial reduction results.

The maximum absolute difference observed was 896.0, indicating a critical numerical divergence rather than mere floating-point precision drift.
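Whether an absolute difference of 896.0 is large depends on the magnitude of the outputs, which the issue does not report. Because features is a ReLU output, torch.min(features) is very likely exactly 0, so amin_scaled divides by 1e-8 and the combined tensor can reach magnitudes around 1e8 before conv3. The pure-Python sketch below computes the gap between adjacent float32 values at a few magnitudes and expresses 896.0 in those units. The helper float32_spacing is hypothetical, and the magnitudes are assumptions, not measured outputs.

```python
import struct

def float32_spacing(x: float) -> float:
    """Gap between float32(x) and the next larger float32 (one ulp), for x >= 0."""
    # Round x to float32 and read back its raw 32-bit pattern.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    here = struct.unpack("<f", struct.pack("<I", bits))[0]
    # For a positive float, incrementing the bit pattern gives the next representable value.
    above = struct.unpack("<f", struct.pack("<I", bits + 1))[0]
    return above - here

for magnitude in (1.0, 1e8, 1e9):
    ulp = float32_spacing(magnitude)
    print(f"|x| ~ {magnitude:.0e}: float32 spacing {ulp:g}, "
          f"896.0 = {896.0 / ulp:.3g} steps, relative {896.0 / magnitude:.1e}")
```

At 1e8 the float32 spacing is 8.0, so 896.0 is 112 representable steps (a relative difference of about 9e-6). At 1e9 it is only 14 steps. Differences of that size can arise from reordering float32 accumulations, so on their own they do not prove an algorithmic error.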

Minimal reproducer

import os
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiReductionModel(nn.Module):
    def __init__(self, in_channels=3, hidden_dim=64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        self.relu = nn.ReLU()
        self.conv3 = nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(hidden_dim // 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        features = self.relu(x)

        # Multiple reductions on the SAME tensor
        dim_amax = torch.amax(features, dim=[2, 3], keepdim=True)
        global_amax = torch.amax(features)
        dim_amin = torch.amin(features, dim=[1], keepdim=True)
        global_amin = torch.min(features)

        # Combine
        amax_normalized = features / (dim_amax + 1e-8)
        amax_scaled = amax_normalized * global_amax
        amin_centered = features - dim_amin
        amin_scaled = amin_centered / (torch.abs(global_amin) + 1e-8)
        combined = amax_scaled + amin_scaled

        output = self.conv3(combined)
        output = self.bn2(output)
        output = self.relu(output)
        return output


device = "cuda"
torch.manual_seed(42)
model = MultiReductionModel().to(device).eval()
x = torch.randn(2, 3, 32, 32, device=device)

with torch.no_grad():
    eager_out = model(x)

torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
with torch.no_grad():
    comp_out = compiled(x)

diff = (eager_out.float() - comp_out.float()).abs().max().item()
print(f"Max diff: {diff}")
print(f"Match: {torch.allclose(eager_out, comp_out, atol=1e-5, rtol=1e-4)}")
# Observed on the reporter's setup: max_diff ≈ 896.0 (very large mismatch)

Behavior summary

Mode            Output      Max Diff
Eager           Reference   (baseline)
torch.compile   Differs     ≈ 896.0

Root cause hypothesis

Inductor's reuse_partial optimization is suspected of combining several reduction operations on the same tensor (amax over dim=[2,3] plus a global amax, amin over dim=[1] plus a global min) into shared partial computations. The hypothesis is that it incorrectly carries forward an intermediate reduction result, which makes the expression features / (dim_amax + 1e-8) * global_amax produce wildly different values. The reporter considers the size of the difference (896.0) to indicate an algorithmic error rather than a floating-point precision issue.
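A more direct way to separate rounding from a miscompiled reduction is to compare each float32 result against a float64 reference. The sketch below assumes that a float64 eager run is an acceptable ground truth. relative_error and amin_term are hypothetical helper names, and amin_term reproduces only the amin branch of MultiReductionModel.forward, the branch that dominates the output when global_amin is 0. It runs on CPU with random ReLU features instead of the issue's model.

```python
import torch

def relative_error(candidate: torch.Tensor, reference64: torch.Tensor) -> tuple[float, float]:
    """Return (max abs error, max abs error / max |reference|) against a float64 reference."""
    abs_err = (candidate.double() - reference64).abs().max()
    scale = reference64.abs().max().clamp_min(1e-30)  # guard against an all-zero reference
    return abs_err.item(), (abs_err / scale).item()

def amin_term(features: torch.Tensor) -> torch.Tensor:
    # The amin branch of MultiReductionModel.forward from the issue.
    dim_amin = torch.amin(features, dim=[1], keepdim=True)
    global_amin = torch.min(features)
    return (features - dim_amin) / (torch.abs(global_amin) + 1e-8)

torch.manual_seed(0)
features = torch.relu(torch.randn(2, 8, 16, 16))  # ReLU output, so it contains exact zeros
reference = amin_term(features.double())
abs_err, rel_err = relative_error(amin_term(features), reference)
print(f"min(features) = {features.min().item()}")
print(f"max |reference| = {reference.abs().max().item():.3g}")
print(f"float32 vs float64: abs err = {abs_err:.3g}, relative err = {rel_err:.3g}")
```

For the actual reproducer, the reference would come from a float64 copy of the model (for example copy.deepcopy(model).double() applied to x.double(), since .double() converts a module in place), and both eager_out and comp_out would be passed to relative_error. If the compiled error is close to the float32 eager error, the 896.0 gap is rounding rather than a wrong reduction.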

Error logs

No error is raised; the outputs silently differ (max_diff ≈ 896.0).

Versions

PyTorch version: 2.12.0.dev20260327+cu126
CUDA used to build PyTorch: 12.6
OS: Ubuntu 22.04.5 LTS (x86_64) — WSL2
Python version: 3.10.12
GPU: NVIDIA GeForce RTX 3060 Laptop GPU

cc @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

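The reporter attributes the mismatch to Inductor's reuse_partial optimization. Until that is confirmed and fixed, keep the affected reduction block out of Inductor (for example with backend="aot_eager", or torch.compiler.disable on the reduction code) so that the suspect kernels are not used.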

Guidance

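  • Confirm that Inductor is responsible by rerunning with backend="aot_eager". If that output matches eager mode, the divergence comes from Inductor's lowering or code generation rather than from Dynamo tracing or AOTAutograd.
  • If Inductor's speedup matters for the rest of the model, exclude only the reduction block from compilation (for example with torch.compiler.disable) instead of switching the whole model to another backend.
  • Vary the inputs and remove the reductions one at a time (for example, drop the global_amin branch) to find the smallest pattern that still diverges.
  • Use the accuracy minifier described in the torch.compile troubleshooting documentation (TORCHDYNAMO_REPRO_AFTER="aot" with TORCHDYNAMO_REPRO_LEVEL=4) to shrink the failing graph into a standalone repro.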

Example

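# Workaround sketch: keep Dynamo tracing but skip Inductor code generation
compiled = torch.compile(model, backend="aot_eager")

Note: full_graph=True is not a workaround. The torch.compile parameter is spelled fullgraph, and fullgraph=True only makes Dynamo raise an error on graph breaks; it does not change how Inductor generates reduction kernels. As written, full_graph=True raises a TypeError for an unexpected keyword argument.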

Notes

The reproducer and the reporter's analysis point at Inductor's handling of several reductions over the same tensor, but the issue does not confirm which pass is responsible. It also does not report the output magnitude, which is needed to tell whether an absolute difference of 896.0 exceeds float32 rounding error.

Recommendation

Apply a workaround: run the reduction block outside Inductor (backend="aot_eager" for the whole model, or torch.compiler.disable on just the reduction code) until the root cause is confirmed. This avoids the suspect kernels but gives up Inductor's fusion for the excluded code, so expect some loss of performance.
