pytorch: Fix `torch.compile` producing different output for a model using `torch.einsum` projections with an element-wise addition pattern

pytorch/pytorch#178882 (fetched 2026-04-08 01:57:14)

🐛 Describe the bug

torch.compile with the Inductor backend produces output that differs from eager mode for a Transformer-style model that uses torch.einsum('b i d, b d h -> b i h', ...) for its projections, each followed by an element-wise bias addition. The pattern is intended to exercise Inductor's replace_einsum_to_pointwise optimization pass, which attempts to replace batched einsum operations with more efficient pointwise or matmul operations.
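
The einsum signature "b i d, b d h -> b i h" is a batched matrix multiply: for each batch index b it computes x[b] @ W[b]. The sketch below checks that equivalence against torch.bmm on CPU; the small shapes and variable names are illustrative assumptions, not taken from the reproducer.

```python
import torch

torch.manual_seed(0)

# Illustrative shapes: batch b=2, sequence i=3, model dim d=h=4.
x = torch.randn(2, 3, 4)     # "b i d"
w = torch.randn(2, 4, 4)     # "b d h": one weight matrix per batch element
bias = torch.randn(2, 3, 4)

via_einsum = torch.einsum("b i d, b d h -> b i h", x, w) + bias
via_bmm = torch.bmm(x, w) + bias

# Same math, possibly different kernels, so compare with a tolerance rather than ==.
print(torch.allclose(via_einsum, via_bmm, atol=1e-6, rtol=1e-5))
```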

Minimal reproducer

import os
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8, batch_size=4, seq_len=16):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Batch-specific projection parameters
        self.q_proj = nn.Parameter(torch.randn(batch_size, d_model, d_model))
        self.k_proj = nn.Parameter(torch.randn(batch_size, d_model, d_model))
        self.v_proj = nn.Parameter(torch.randn(batch_size, d_model, d_model))
        self.out_proj = nn.Parameter(torch.randn(batch_size, d_model, d_model))

        self.q_bias = nn.Parameter(torch.randn(batch_size, seq_len, d_model))
        self.k_bias = nn.Parameter(torch.randn(batch_size, seq_len, d_model))
        self.v_bias = nn.Parameter(torch.randn(batch_size, seq_len, d_model))
        self.out_bias = nn.Parameter(torch.randn(batch_size, seq_len, d_model))

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4), nn.ReLU(), nn.Linear(d_model * 4, d_model)
        )

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        x_norm = self.norm1(x)

        # einsum projections + element-wise add (the target pattern)
        q = torch.einsum("b i d, b d h -> b i h", x_norm, self.q_proj) + self.q_bias[:, :seq_len, :]
        k = torch.einsum("b i d, b d h -> b i h", x_norm, self.k_proj) + self.k_bias[:, :seq_len, :]
        v = torch.einsum("b i d, b d h -> b i h", x_norm, self.v_proj) + self.v_bias[:, :seq_len, :]

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        out = torch.einsum("b i d, b d h -> b i h", out, self.out_proj) + self.out_bias[:, :seq_len, :]
        out = x + out
        out = out + self.ffn(self.norm2(out))
        return out


device = "cuda"
torch.manual_seed(42)
model = MultiHeadAttention(d_model=512, num_heads=8, batch_size=4, seq_len=16).to(device).eval()
x = torch.randn(4, 16, 512, device=device)

with torch.no_grad():
    eager_out = model(x)

torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
with torch.no_grad():
    comp_out = compiled(x)

diff = (eager_out.float() - comp_out.float()).abs().max().item()
print(f"Max diff: {diff}")
print(f"Match: {torch.allclose(eager_out, comp_out, atol=1e-5, rtol=1e-4)}")
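
Eager and compiled fp32 outputs need not match bit-for-bit even without a bug, because kernels may reduce in different orders. One way to judge whether a difference exceeds ordinary rounding is to measure both fp32 outputs against a float64 eager run. The helper below is a hypothetical sketch (fp64_error_report is not a PyTorch API) and assumes the model and input can be cast with .double(); the demo uses a small stand-in Linear layer on CPU.

```python
import copy

import torch


def fp64_error_report(model, x, eager_out, compiled_out):
    """Hypothetical helper: max abs error of eager and compiled fp32 outputs
    versus a float64 eager run. Deep-copies the model so the original stays fp32."""
    ref_model = copy.deepcopy(model).double()
    with torch.no_grad():
        ref = ref_model(x.double())
    eager_err = (eager_out.double() - ref).abs().max().item()
    compiled_err = (compiled_out.double() - ref).abs().max().item()
    return eager_err, compiled_err


# Stand-in demo; with the reproducer, pass model, x, eager_out, and comp_out instead.
torch.manual_seed(0)
demo_model = torch.nn.Linear(8, 8).eval()
demo_x = torch.randn(2, 8)
with torch.no_grad():
    demo_out = demo_model(demo_x)
# Simulate a "compiled" output that is off by 1e-3.
eager_err, compiled_err = fp64_error_report(demo_model, demo_x, demo_out, demo_out + 1e-3)
print(f"eager vs fp64: {eager_err:.2e}, compiled vs fp64: {compiled_err:.2e}")
```

If the compiled error is within a small multiple of the eager error, the mismatch is more likely accumulation-order noise than a miscompilation; a much larger compiled error points to a real bug.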

Behavior summary

| Mode | Output | Match? |
|---|---|---|
| Eager | Reference | n/a |
| torch.compile | Differs | No |

Root cause hypothesis

Inductor's replace_einsum_to_pointwise pass rewrites einsum('b i d, b d h -> b i h', A, B) as torch.bmm(A, B) or an equivalent matmul. The rewrite changes the order of the floating-point accumulation (particularly with the batch-specific parameter tensors), and the resulting small rounding differences are then amplified by the subsequent attention computation.
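
Floating-point addition is not associative, so two mathematically identical reductions can round differently. The sketch below is a generic illustration, not a trace of what Inductor actually emits: it sums the same float32 values in two orders and measures each against a float64 sum. Whether the two fp32 results differ on a given machine depends on the reduction kernels used.

```python
import torch

torch.manual_seed(0)
values = torch.randn(4096, dtype=torch.float32) * 100.0

# Order 1: one flat reduction. Order 2: 64 partial sums, then a sum of the partials.
flat = values.sum()
blocked = values.view(64, 64).sum(dim=1).sum()
reference = values.double().sum()

print(f"flat - blocked:        {(flat.double() - blocked.double()).abs().item():.3e}")
print(f"flat error vs fp64:    {(flat.double() - reference).abs().item():.3e}")
print(f"blocked error vs fp64: {(blocked.double() - reference).abs().item():.3e}")
```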

Error logs

No error — outputs silently differ.

Versions

PyTorch version: 2.12.0.dev20260327+cu126
CUDA used to build PyTorch: 12.6
OS: Ubuntu 22.04.5 LTS (x86_64) — WSL2
Python version: 3.10.12
GPU: NVIDIA GeForce RTX 3060 Laptop GPU

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

Assuming the replace_einsum_to_pointwise pass is the cause, the issue can be mitigated by disabling that pass in the Inductor backend or by compiling with a backend that does not apply it.

Guidance

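  • Verify the hypothesis that the replace_einsum_to_pointwise pass is the root cause by disabling it and checking whether the outputs then match.
  • Consider compiling with a backend that skips Inductor entirely, such as backend="aot_eager" or backend="eager", to avoid the optimization pass. Inductor is itself the default torch.compile backend, so omitting the backend argument does not avoid the pass.
  • Compare both outputs against a float64 eager reference to check whether the difference is within ordinary floating-point accumulation-order error. The projection weights are drawn from unscaled torch.randn, so activations are large and the fixed atol=1e-5 in the reproducer may be stricter than fp32 rounding allows.
  • Investigate whether other optimization passes in the Inductor backend could also contribute to the difference.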
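
The guidance above can be turned into a quick bisection: compile the same model with backend="eager" (Dynamo graph capture only), backend="aot_eager" (adds AOTAutograd tracing, still no Inductor code generation), and backend="inductor", then see at which stage the mismatch first appears. This is a sketch assuming a model and input like those in the reproducer; compare_backends is a hypothetical helper name, and the demo uses a small CPU stand-in model.

```python
import torch
import torch._dynamo


def compare_backends(model, x, backends=("eager", "aot_eager", "inductor"),
                     atol=1e-5, rtol=1e-4):
    """Hypothetical helper: compile `model` with each backend and report the
    max abs difference from the uncompiled output, plus an allclose verdict."""
    with torch.no_grad():
        reference = model(x)
    results = {}
    for backend in backends:
        torch._dynamo.reset()  # drop cached graphs so each backend recompiles
        compiled = torch.compile(model, backend=backend)
        with torch.no_grad():
            out = compiled(x)
        max_diff = (reference - out).abs().max().item()
        results[backend] = (max_diff, torch.allclose(reference, out, atol=atol, rtol=rtol))
    return results


torch.manual_seed(0)
demo = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
demo_x = torch.randn(2, 8)
# "inductor" is left out of this demo because it needs a working C++/Triton toolchain.
results = compare_backends(demo, demo_x, backends=("eager", "aot_eager"))
for name, (diff, ok) in results.items():
    print(f"{name:>10}: max diff {diff:.3e}, allclose={ok}")
```

If "aot_eager" matches but "inductor" does not, the discrepancy comes from Inductor itself. Some PyTorch builds also ship a debug backend, "aot_eager_decomp_partition", that applies Inductor's decompositions without its code generation; if available, it can narrow the search further.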

Example

No model-code change is strictly required: the proposed mitigation configures the compiler backend rather than modifying the model.
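
Although the analysis frames the mitigation as a backend change, a model-side experiment can help test the hypothesis: express each projection-plus-bias as a single torch.baddbmm call, which computes bias + x @ W per batch element. If the compiled model matches eager with this form, that would support the idea that the einsum-plus-add pattern triggers the discrepancy; this has not been verified for this issue. The function name project is hypothetical.

```python
import torch


def project(x, weight, bias):
    """Hypothetical drop-in for
    torch.einsum("b i d, b d h -> b i h", x, weight) + bias[:, :seq_len, :]

    x: (b, i, d), weight: (b, d, h), bias: (b, >=i, h).
    """
    seq_len = x.shape[1]
    # baddbmm(input, batch1, batch2) = input + batch1 @ batch2 (beta=1, alpha=1 by default)
    return torch.baddbmm(bias[:, :seq_len, :], x, weight)


torch.manual_seed(0)
x = torch.randn(2, 3, 4)
weight = torch.randn(2, 4, 5)
bias = torch.randn(2, 6, 5)  # longer than seq_len, then sliced as in the reproducer

expected = torch.einsum("b i d, b d h -> b i h", x, weight) + bias[:, :3, :]
print(torch.allclose(project(x, weight, bias), expected, atol=1e-6, rtol=1e-5))
```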

Notes

As reported, the discrepancy occurs with the Inductor backend and is attributed to the replace_einsum_to_pointwise optimization pass; the workaround may not apply to other backends or optimization passes.

Recommendation

Apply a workaround: disable the replace_einsum_to_pointwise optimization pass or compile with a different backend. Because the pass is the suspected cause, bypassing it either way is expected to mitigate the discrepancy.
