pytorch - How to fix: `torch.compile` produces different output for a Transformer using `torch.einsum` projections compared to eager mode

GitHub issue: pytorch/pytorch#179575 (fetched 2026-04-08 03:00:10; 0 comments, 1 participant, 0 reactions)


🐛 Describe the bug

torch.compile with the Inductor backend produces numerically different results from eager mode for a Transformer block that uses torch.einsum("b i d, b d h -> b i h", ...) for the Q/K/V and output projections with element-wise bias addition, followed by multi-head scaled dot-product attention, LayerNorm, and an FFN with residual connections.

Inductor's replace_einsum_to_pointwise pass rewrites the batched einsum into optimized bmm/matmul operations. This replacement changes the floating-point accumulation order, and the numerical differences compound through the attention mechanism and residual connections, producing a significant output mismatch.

Eager mode is perfectly deterministic (max_var=0 across runs), confirming this is a systematic computation difference introduced by the Inductor optimization, not GPU non-determinism.
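The accumulation-order argument above can be illustrated without Inductor at all. A minimal sketch in plain eager PyTorch (CPU, float32) showing that adding the same three numbers in two different orders gives two different results, each of which is perfectly reproducible:

```python
import torch

# Three float32 values. 1e8 is exactly representable, but 1e8 + 1 is not:
# the spacing between adjacent float32 values near 1e8 is 8.
a = torch.tensor([1e8, 1.0, -1e8], dtype=torch.float32)

# Order 1: (1e8 + 1) rounds back to 1e8, so the 1.0 is lost.
left_to_right = (a[0] + a[1]) + a[2]

# Order 2: the large terms cancel first, so the 1.0 survives.
reordered = (a[0] + a[2]) + a[1]

print(left_to_right.item(), reordered.item())  # 0.0 1.0

# Each order is deterministic: repeating it gives the identical result.
assert torch.equal((a[0] + a[1]) + a[2], left_to_right)
```

A matmul kernel performs hundreds of such additions per output element (512 here, the contracted d dimension), so two kernels that order or tile the reduction differently can disagree in the low-order bits while each stays bit-for-bit stable across runs.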

Note: The model uses nn.Parameter weight tensors with the batch dimension baked in (e.g., torch.randn(batch_size, d_model, d_model)). The fuzzer generated the model this way, and this layout exercises the batched einsum pattern that triggers replace_einsum_to_pointwise.

Minimal reproducer

import os
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8, batch_size=4, seq_len=16):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Batch-specific projection weights (triggers batched einsum pattern)
        self.q_proj = nn.Parameter(torch.randn(batch_size, d_model, d_model) * 0.02)
        self.k_proj = nn.Parameter(torch.randn(batch_size, d_model, d_model) * 0.02)
        self.v_proj = nn.Parameter(torch.randn(batch_size, d_model, d_model) * 0.02)
        self.out_proj = nn.Parameter(torch.randn(batch_size, d_model, d_model) * 0.02)

        # Batch-specific biases
        self.q_bias = nn.Parameter(torch.randn(batch_size, seq_len, d_model) * 0.02)
        self.k_bias = nn.Parameter(torch.randn(batch_size, seq_len, d_model) * 0.02)
        self.v_bias = nn.Parameter(torch.randn(batch_size, seq_len, d_model) * 0.02)
        self.out_bias = nn.Parameter(torch.randn(batch_size, seq_len, d_model) * 0.02)

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        # einsum projections + element-wise bias add (target pattern)
        q = torch.einsum("b i d, b d h -> b i h", x, self.q_proj) + self.q_bias[:batch_size, :seq_len, :]
        k = torch.einsum("b i d, b d h -> b i h", x, self.k_proj) + self.k_bias[:batch_size, :seq_len, :]
        v = torch.einsum("b i d, b d h -> b i h", x, self.v_proj) + self.v_bias[:batch_size, :seq_len, :]

        # Reshape to multi-head
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        scale = self.head_dim ** 0.5
        attn = torch.matmul(q, k.transpose(-2, -1)) / scale
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)

        # Merge heads and output projection
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        out = torch.einsum("b i d, b d h -> b i h", out, self.out_proj) + self.out_bias[:batch_size, :seq_len, :]
        return out


class TransformerBlock(nn.Module):
    def __init__(self, d_model=512, num_heads=8, batch_size=4, seq_len=16):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, batch_size, seq_len)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        # Pre-norm residual attention
        x_norm = self.norm1(x)
        attn_out = self.attn(x_norm)
        x = x + attn_out
        # Pre-norm residual FFN
        x = x + self.ffn(self.norm2(x))
        return x


device = "cuda"
torch.manual_seed(42)
model = TransformerBlock(d_model=512, num_heads=8, batch_size=4, seq_len=16).to(device).eval()
x = torch.randn(4, 16, 512, device=device)

# Eager: deterministic
with torch.no_grad():
    ref = model(x)
    ref2 = model(x)
print(f"Eager deterministic: {(ref - ref2).abs().max().item():.6e}")

# Compiled
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
with torch.no_grad():
    comp = compiled(x)

diff = (ref.float() - comp.float()).abs()
print(f"Max diff: {diff.max().item():.6e}")
print(f"Mean diff: {diff.mean().item():.6e}")
print(f"Match (atol=1e-5, rtol=1e-4): {torch.allclose(ref, comp, atol=1e-5, rtol=1e-4)}")

Behavior summary

| Mode | Result | Notes |
|---|---|---|
| Eager | Reference output | Perfectly deterministic across runs (max_var=0) |
| torch.compile(backend="inductor") | Different output | Numerical differences from the einsum→bmm replacement compound through attention and residual connections |

Notes

  • Eager mode is perfectly deterministic (max_var=0), ruling out GPU non-determinism.
  • Inductor's replace_einsum_to_pointwise pass rewrites einsum('b i d, b d h -> b i h', A, B) as torch.bmm(A, B) or equivalent matmul. The replacement changes accumulation order.
  • Numerical differences from the Q/K/V projections compound through the softmax attention mechanism and propagate through two residual connections, amplifying the final output difference.
  • The batch-specific nn.Parameter weight tensors (shape [batch_size, d_model, d_model]) are the precise pattern that triggers the batched einsum optimization path.
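The rewrite described in these notes can be checked in isolation. The sketch below (CPU, float32, shapes taken from the reproducer) computes the same projection three ways: torch.einsum, an explicit torch.bmm, and a broadcast multiply followed by a sum over d. The third form is only one plausible reading of a "pointwise" rewrite and is an assumption; the exact graph Inductor emits is not shown in the issue.

```python
import torch

torch.manual_seed(0)
b, i, d, h = 4, 16, 512, 512  # batch, seq_len, d_model (in), d_model (out)
x = torch.randn(b, i, d)
w = torch.randn(b, d, h) * 0.02  # batch-specific weight, as in the reproducer

# 1. The einsum used by the model.
out_einsum = torch.einsum("bid,bdh->bih", x, w)

# 2. The same contraction written as a batched matrix multiply.
out_bmm = torch.bmm(x, w)

# 3. Hypothetical "pointwise" form: broadcast multiply, then reduce over d.
#    x: (b, i, d, 1) * w: (b, 1, d, h) -> (b, i, d, h), summed over dim 2.
out_pointwise = (x.unsqueeze(-1) * w.unsqueeze(1)).sum(dim=2)

for name, out in [("bmm", out_bmm), ("pointwise", out_pointwise)]:
    diff = (out - out_einsum).abs().max().item()
    print(f"{name}: max |diff| vs einsum = {diff:.3e}")
```

All three agree to float32 tolerance but are not guaranteed to be bitwise identical; any nonzero difference printed here is ordinary rounding. The issue's concern is how such small differences grow after softmax, two residual additions, and LayerNorm.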

Error logs

Outputs silently differ; no error or warning is raised.
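Because nothing is raised, a mismatch like this is easy to miss. One way to make it loud in a test is torch.testing.assert_close, which raises an AssertionError summarizing the mismatched elements. A minimal sketch, with two hand-made tensors standing in for the eager and compiled outputs (the 1e-3 perturbation is illustrative, not a value from the issue):

```python
import torch

ref = torch.zeros(4, 16, 512)
comp = ref.clone()
comp[0, 0, 0] += 1e-3  # stand-in for a compiled-vs-eager discrepancy

try:
    # Default float32 tolerances are rtol=1.3e-6, atol=1e-5.
    torch.testing.assert_close(comp, ref)
    mismatch_detected = False
except AssertionError as err:
    mismatch_detected = True
    # The message reports the mismatched-element count and the
    # greatest absolute and relative differences.
    print(err)
```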

Versions

PyTorch version: 2.12.0.dev20260327+cu126
Python: 3.10.12
OS: Ubuntu 22.04.5 LTS (WSL2)
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
CUDA: 12.6

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @ezyang @msaroufim @bdhirsh @anijain2305

topic: fuzzer

extent analysis

TL;DR

The numerical differences between the eager and compiled modes can be mitigated by disabling the replace_einsum_to_pointwise pass in the Inductor backend or using a different backend that does not introduce such optimizations.

Guidance

  • Investigate the replace_einsum_to_pointwise pass in the Inductor backend to understand how it rewrites the batched einsum operations and determine if there are any configuration options to disable or modify this behavior.
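  • Compare against a torch.compile backend that skips Inductor's graph passes, such as aot_eager, or against plain eager mode, to confirm whether the numerical differences come from Inductor. (The default torch.compile backend is Inductor itself.)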
  • Review the model architecture and the use of batch-specific nn.Parameter weight tensors to determine if there are any alternative implementations that could reduce the impact of the numerical differences.
  • Evaluate the tolerance of the model to numerical differences and determine if the observed differences are within an acceptable range for the specific use case.
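For the last point, a common way to decide whether a compiled result is actually less accurate, rather than just rounded differently, is to compare both float32 results against a float64 reference computed from the same inputs. A sketch on the einsum projection alone, where a broadcast-multiply-and-sum reformulation stands in for the compiled path (an assumption; substitute the real compiled output in practice):

```python
import torch

def rel_err(approx: torch.Tensor, exact: torch.Tensor) -> float:
    """Relative error in the Frobenius norm, measured against a float64 reference."""
    return ((approx.double() - exact).norm() / exact.norm()).item()

torch.manual_seed(0)
x = torch.randn(4, 16, 512)
w = torch.randn(4, 512, 512) * 0.02

# High-precision reference from the same float32 inputs.
exact = torch.einsum("bid,bdh->bih", x.double(), w.double())

eager_fp32 = torch.einsum("bid,bdh->bih", x, w)
# Stand-in for a kernel that accumulates in a different order.
other_fp32 = (x.unsqueeze(-1) * w.unsqueeze(1)).sum(dim=2)

print(f"eager fp32 rel. error: {rel_err(eager_fp32, exact):.2e}")
print(f"other fp32 rel. error: {rel_err(other_fp32, exact):.2e}")
```

If both errors are of the same order (a small multiple of float32 epsilon, about 1.2e-7), the eager-vs-compiled gap is rounding noise rather than a miscompilation.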

Example

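# torch.compile has no keyword for disabling individual Inductor passes:
# torch.compile(model, backend="inductor", disable_replace_einsum_to_pointwise=True)
# raises a TypeError. To isolate Inductor, compile with a backend that runs the
# captured graph without Inductor's passes and compare its output to eager:
compiled_aot = torch.compile(model, backend="aot_eager")

Note: If the aot_eager output matches eager while the inductor output does not, the mismatch originates in Inductor's lowering or graph passes. Whether a configuration switch exists for replace_einsum_to_pointwise specifically depends on the Inductor version and is not confirmed here.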

Notes

  • The numerical differences are introduced by the replace_einsum_to_pointwise pass, which rewrites the batched einsum operations as matmul operations, changing the accumulation order.
  • The use of batch-specific nn.Parameter weight tensors triggers the batched einsum optimization path, which may contribute to the numerical differences.
  • The observed numerical differences may be specific to the Inductor backend and the model architecture, and further investigation is needed to determine the root cause and develop a robust solution.

Recommendation

Apply a workaround by disabling the replace_einsum_to_pointwise pass or using a different backend, as the numerical differences may be specific to the Inductor backend and the model architecture.
