pytorch - 💡 (How to fix) `torch.compile` produces numerically different results from eager mode for the pre-grad batch linear fusion pattern (multiple `F.linear` calls with the same input) [1 comment, 1 participant]

pytorch/pytorch#179574 (fetched 2026-04-08 03:00:11): 1 comment, 1 participant, 0 reactions; 17 timeline events (mentioned ×7, subscribed ×7, labeled ×2, commented ×1).

🐛 Describe the bug

torch.compile with the Inductor backend produces numerically different results from eager mode when the model applies several F.linear calls to the same 2D input tensor with different weight matrices. This pattern is the target of Inductor's PreGradBatchLinearFusion optimization, which combines the independent linear operations into a single batched GEMM. The batched kernel accumulates in a different floating-point order, so its results differ systematically from eager.
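The effect of accumulation order can be reproduced without a GPU. The sketch below is only an illustration and is not Inductor's or cuBLAS's actual kernel. It emulates float32 arithmetic in pure Python by rounding through `struct`. It then compares a sequential dot product, standing in for one reduction order, with a blocked one that sums fixed-size partial sums, standing in for a different tiling. The helpers `f32`, `dot_sequential` and `dot_blocked` are hypothetical names. Each function is deterministic, yet the two generally disagree in the last bits.

```python
import math
import random
import struct


def f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def dot_sequential(a, b):
    """Float32 dot product accumulated strictly left to right."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = f32(acc + f32(x * y))
    return acc


def dot_blocked(a, b, block=32):
    """Float32 dot product: per-block partial sums, then a sum of the partials.

    This only mimics a tiled reduction; it is not a real GEMM kernel.
    """
    partials = []
    for start in range(0, len(a), block):
        acc = 0.0
        for x, y in zip(a[start:start + block], b[start:start + block]):
            acc = f32(acc + f32(x * y))
        partials.append(acc)
    total = 0.0
    for p in partials:
        total = f32(total + p)
    return total


random.seed(0)
n = 1024  # same reduction length as the 1024-wide input x in the reproducer
a = [f32(random.gauss(0.0, 1.0)) for _ in range(n)]
b = [f32(random.gauss(0.0, 1.0)) for _ in range(n)]

seq = dot_sequential(a, b)
blk = dot_blocked(a, b)
# Products of float32 values are exact in float64; fsum gives a correctly rounded sum.
exact = math.fsum(x * y for x, y in zip(a, b))

print(f"sequential    = {seq!r}")
print(f"blocked       = {blk!r}")
print(f"|seq - blk|   = {abs(seq - blk):.3e}")
print(f"|seq - exact| = {abs(seq - exact):.3e}")
```

Running the script twice prints identical numbers. This mirrors the observation that eager mode is deterministic: a fixed order always reproduces itself, but that says nothing about agreement with a different order.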

The model combines Conv2d feature extraction, an Embedding path (lookup, mean over the sequence dimension, then a linear projection), and four parallel F.linear calls that share the same flattened input, which is the canonical pattern for pre-grad batch linear fusion.

Eager mode is deterministic: two eager runs match bit for bit (maximum absolute difference 0), which rules out run-to-run GPU non-determinism as the cause.

Minimal reproducer

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Conv feature extraction
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(4)
        # Embedding + mean-pooling path
        self.embedding = nn.Embedding(1000, 128)
        self.emb_proj = nn.Linear(128, 1024)
        # Multiple F.linear with different weights (batch linear fusion target)
        self.w1 = nn.Parameter(torch.randn(256, 1024))
        self.b1 = nn.Parameter(torch.randn(256))
        self.w2 = nn.Parameter(torch.randn(256, 1024))
        self.b2 = nn.Parameter(torch.randn(256))
        self.w3 = nn.Parameter(torch.randn(256, 1024))
        self.b3 = nn.Parameter(torch.randn(256))
        self.w4 = nn.Parameter(torch.randn(256, 1024))
        self.b4 = nn.Parameter(torch.randn(256))
        self.out_proj = nn.Linear(256, 10)

    def forward(self, conv_input, embedding_input):
        # Conv path
        cx = self.pool(F.relu(self.conv(conv_input)))   # [B,64,4,4]
        cx = cx.flatten(1)                               # [B,1024]

        # Embedding path
        ex = self.embedding(embedding_input)             # [B,16,128]
        ex = ex.mean(dim=1)                              # [B,128]
        ex = self.emb_proj(ex)                           # [B,1024]

        # Fuse both paths
        x = cx + ex                                      # [B,1024]

        # Multiple F.linear with same input → PreGradBatchLinearFusion target
        h1 = F.linear(x, self.w1, self.b1)
        h1 = F.relu(h1)
        h2 = F.linear(x, self.w2, self.b2)
        h2 = F.relu(h2)
        h3 = F.linear(x, self.w3, self.b3)
        h3 = F.relu(h3)
        h4 = F.linear(x, self.w4, self.b4)
        h4 = F.relu(h4)

        combined = h1 + h2 + h3 + h4
        return self.out_proj(combined)

device = "cuda"
torch.manual_seed(42)
model = Model().to(device).eval()
conv_input = torch.randn(4, 3, 32, 32, device=device)
embedding_input = torch.randint(0, 1000, (4, 16), device=device)

# Eager: deterministic
with torch.no_grad():
    ref1 = model(conv_input, embedding_input)
    ref2 = model(conv_input, embedding_input)
print(f"Eager deterministic: {(ref1 - ref2).abs().max().item():.6e}")

# Compiled: different result
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
with torch.no_grad():
    out = compiled(conv_input, embedding_input)

diff = (ref1.float() - out.float()).abs()
print(f"max_diff={diff.max().item():.6e}")
print(f"Output range: [{ref1.min().item():.2f}, {ref1.max().item():.2f}]")
print(f"Relative error: {diff.max().item() / ref1.abs().max().item():.6e}")

Behavior summary

| Mode | Result | Notes |
|---|---|---|
| Eager | Reference output | Deterministic across runs (maximum absolute difference between runs = 0) |
| torch.compile(backend="inductor") | Different output | Values differ from eager beyond float32 tolerance |
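"Float32 tolerance" can be made concrete. For float32 tensors, torch.testing.assert_close by default uses rtol = 1.3e-6 and atol = 1e-5. It treats two values as close when |actual - expected| ≤ atol + rtol * |expected|.

The pure-Python sketch below applies that rule element-wise. It also computes the two metrics the reproducer prints: the max absolute difference, and that difference divided by max |reference|. The helper name `compare` is hypothetical, and plain lists stand in for tensors.

```python
def compare(actual, expected, rtol=1.3e-6, atol=1e-5):
    """Return (max_abs_diff, relative_error, all_close) for two equal-length lists.

    all_close applies |a - e| <= atol + rtol * |e| element-wise, the rule used by
    torch.testing.assert_close; the defaults are its float32 tolerances.
    """
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    diffs = [abs(a - e) for a, e in zip(actual, expected)]
    max_abs_diff = max(diffs)
    # Same normalisation as the reproducer: max |diff| / max |reference|
    relative_error = max_abs_diff / max(abs(e) for e in expected)
    all_close = all(d <= atol + rtol * abs(e) for d, e in zip(diffs, expected))
    return max_abs_diff, relative_error, all_close


# A 1e-3 absolute difference on a value of magnitude 1000 is within tolerance...
print(compare([1000.001, -2.0], [1000.0, -2.0]))
# ...but the same absolute difference on a value of magnitude 1 is not.
print(compare([1.001, -2.0], [1.0, -2.0]))
```

The first case shows that a large absolute max_diff can still be within tolerance when the outputs themselves are large. For that reason, the relative error line of the reproducer is the more informative one.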

Notes

  • Because eager mode is deterministic across runs (maximum absolute difference 0), the mismatch is a systematic difference between the eager and compiled computations, not run-to-run noise.
  • The PreGradBatchLinearFusion pass batches the 4 independent F.linear(x, w_i, b_i) calls into a single batched GEMM operation, changing the accumulation order.
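  • With unnormalized random weights (randn scale), each linear layer produces large-magnitude outputs, so small relative rounding differences show up as large absolute differences. ReLU may flip activations whose pre-activation values are near zero, but because ReLU is continuous (|relu(a) - relu(b)| ≤ |a - b|), such flips cannot enlarge a difference.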
  • The Embedding path (lookup followed by a mean reduction) can add further variation through the mean reduction, whose accumulation order may also differ between eager and compiled code.
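The note about large output magnitudes can be checked with the same kind of float32 emulation, redefined here so the sketch stands alone. All helper names are hypothetical, and the sketch assumes no float32 overflow or underflow.

Multiplying one operand by a power of two scales every rounded intermediate exactly. As a result, the absolute gap between two reduction orders grows by exactly that factor, while the relative gap is unchanged. Large absolute differences in the reproducer therefore reflect the unnormalized randn weights, not a larger rounding error in the fused kernel.

```python
import random
import struct


def f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def dot(a, b, block):
    """Float32 dot product with per-block partial sums.

    block == len(a) gives plain sequential accumulation.
    """
    total = 0.0
    for start in range(0, len(a), block):
        acc = 0.0
        for x, y in zip(a[start:start + block], b[start:start + block]):
            acc = f32(acc + f32(x * y))
        total = f32(total + acc)
    return total


random.seed(1)
n = 1024
a = [f32(random.gauss(0.0, 1.0)) for _ in range(n)]
b = [f32(random.gauss(0.0, 1.0)) for _ in range(n)]
scale = 2.0 ** 10  # power of two, so the scaling itself introduces no rounding

for factor in (1.0, scale):
    a_scaled = [factor * x for x in a]
    seq = dot(a_scaled, b, block=n)   # one order
    blk = dot(a_scaled, b, block=32)  # another order
    gap = abs(seq - blk)
    rel = gap / abs(seq) if seq != 0.0 else float("nan")
    print(f"factor={factor:>6}: |seq - blk| = {gap:.3e}, relative = {rel:.3e}")
```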

Error logs

No error is raised; the compiled model produces a result, but it differs numerically from eager:

Eager deterministic: 0.000000e+00
max_diff > 0 (systematic mismatch from batch linear fusion)

Versions

PyTorch version: 2.12.0.dev20260327+cu126
Python: 3.10.12
OS: Ubuntu 22.04.5 LTS (WSL2)
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
CUDA: 12.6

cc @chauhang @penguinwu @ezyang @msaroufim @bdhirsh @anijain2305

topic: fuzzer

extent analysis

TL;DR

The most likely workaround is to disable the PreGradBatchLinearFusion optimization in the Inductor backend, or to compile with a backend that does not run Inductor's fusion passes.

Guidance

  • Confirm the suspected cause: PreGradBatchLinearFusion batches independent linear operations, which changes the accumulation order and therefore the numerical values.
  • Consider disabling this optimization or using a different backend to avoid the batched GEMM operation.
  • Verify the fix by comparing the compiled output with the eager output and checking that they match within float32 tolerance (for example, the default tolerances of torch.testing.assert_close).
  • If disabling the optimization is not feasible, explore other workarounds, such as computing in higher precision or initializing the weights at a smaller scale. The latter shrinks absolute differences but not relative ones.

Example

The fix itself belongs in the optimization pass of the Inductor backend; on the user side, the available option is to disable the optimization.

Notes

The provided information suggests that the issue is specific to the inductor backend and the PreGradBatchLinearFusion optimization. Disabling this optimization or using a different backend may resolve the issue, but it may also impact performance.

Recommendation

Apply a workaround, such as disabling the PreGradBatchLinearFusion optimization, as the issue is specific to this optimization and the inductor backend. This workaround may have performance implications, but it should resolve the numerical mismatch issue.
