pytorch - 💡(How to fix) Fix `torch.compile` raises error for redundant consecutive view operations pattern (`view → compute → view_back`) while eager mode succeeds [1 comments, 1 participants]

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…
GitHub stats
pytorch/pytorch#179573 · Fetched 2026-04-08 03:00:13
View on GitHub
Comments
1
Participants
1
Timeline
17
Reactions
0
Author
Participants
Timeline (top)
mentioned ×7 · subscribed ×7 · labeled ×2 · commented ×1

Error Message

import os

# Set before torch is imported; this keeps the statement order of the reproducer.
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RedundantViewBlock(nn.Module):
    """View to split last dim into (D//2, 2), process, then view back.

    Args:
        dim: Size ``D`` of the last input dimension. ``scale`` and ``bias``
            are created with shape ``(dim // 2, 2)``; ``scale`` starts as
            ones and ``bias`` as zeros.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.scale = nn.Parameter(torch.ones(dim // 2, 2))
        self.bias = nn.Parameter(torch.zeros(dim // 2, 2))

    def forward(self, x):
        """Apply scale, bias and GELU on the ``[B, T, D//2, 2]`` view of ``x``.

        Args:
            x: Tensor of shape ``[B, T, D]`` with ``D == self.dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # x: [B, T, D]
        original_shape = x.shape
        # View to split last dim: [B, T, D] -> [B, T, D//2, 2]
        x_reshaped = x.view(x.size(0), x.size(1), self.dim // 2, 2)
        # Computation on the intermediate reshaped tensor; scale and bias
        # broadcast over the leading [B, T] dimensions.
        x_reshaped = x_reshaped * self.scale + self.bias
        x_reshaped = F.gelu(x_reshaped)
        # View back to original shape: [B, T, D//2, 2] -> [B, T, D]
        x_restored = x_reshaped.reshape(original_shape)
        return x_restored

class MultiHeadRedundantAttention(nn.Module):
    """Attention with view-based head splitting that creates redundant view pairs.

    Args:
        embed_dim: Size ``D`` of the last input dimension.
        num_heads: Number of heads; each head has ``embed_dim // num_heads``
            features.
    """

    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Attention scores are divided by sqrt(head_dim).
        self.scale = math.sqrt(self.head_dim)

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Run softmax attention over ``x`` of shape ``[B, T, D]``.

        Returns:
            Tensor of shape ``[B, T, D]`` after the output projection.
        """
        B, T, D = x.shape
        qkv = self.qkv_proj(x)  # [B, T, 3*D]
        q, k, v = qkv.chunk(3, dim=-1)  # each [B, T, D]

        # Reshape to [B*num_heads, T, head_dim] (first view)
        q = q.view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        q = q.reshape(B * self.num_heads, T, self.head_dim)

        k = k.view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.reshape(B * self.num_heads, T, self.head_dim)

        v = v.view(B, T, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.reshape(B * self.num_heads, T, self.head_dim)

        # Attention computation in intermediate reshaped form
        attn_scores = torch.bmm(q, k.transpose(-2, -1)) / self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.bmm(attn_weights, v)  # [B*num_heads, T, head_dim]

        # Reshape back: [B*num_heads, T, head_dim] -> [B, num_heads, T, head_dim]
        # -> [B, T, num_heads, head_dim] -> [B, T, D]  (second view, restoring)
        context = context.view(B, self.num_heads, T, self.head_dim)
        context = context.permute(0, 2, 1, 3).reshape(B, T, D)

        return self.out_proj(context)

class ModelWithConsecutiveViews(nn.Module):
    """Classifier that chains a view block and view-based attention.

    The input ``[B, input_dim]`` is projected to ``embed_dim``, repeated into
    a pseudo-sequence of length ``seq_len``, passed through
    ``RedundantViewBlock`` and ``MultiHeadRedundantAttention`` (each with a
    residual connection and LayerNorm), mean-pooled over the sequence
    dimension and mapped to 10 outputs.
    """

    def __init__(self, input_dim=784, embed_dim=256, num_heads=8, seq_len=49):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim

        # Input projection
        self.input_proj = nn.Linear(input_dim, embed_dim)
        self.act = nn.ReLU()

        # RedundantViewBlock: view → compute → view_back
        self.view_block = RedundantViewBlock(embed_dim)

        # Multi-head attention with redundant view pairs
        self.attn = MultiHeadRedundantAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Output
        self.output_fc = nn.Linear(embed_dim, 10)

    def forward(self, x):
        """Map ``x`` of shape ``[B, input_dim]`` to outputs of shape ``[B, 10]``."""
        # x: [B, input_dim]  -> reshape to sequence
        B = x.size(0)
        x = self.input_proj(x)                          # [B, embed_dim]
        x = self.act(x)

        # Create pseudo-sequence by repeating + adding position info
        x = x.unsqueeze(1).expand(B, self.seq_len, self.embed_dim)  # [B, seq_len, embed_dim]
        pos = torch.arange(self.seq_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(-1)
        x = x + pos * 0.01  # add positional info

        # RedundantViewBlock: view(B, T, D//2, 2) -> compute -> view(B, T, D)
        x_view = self.view_block(x)
        x = self.norm1(x + x_view)                       # residual

        # Multi-head attention with view-based head splitting
        attn_out = self.attn(x)
        x = self.norm2(x + attn_out)                     # residual

        # Pool and classify
        x = x.mean(dim=1)                                # [B, embed_dim]
        return self.output_fc(x)

# Build the model once, in eval mode, on the GPU; the seed fixes the weights
# and the random input.
device = "cuda"
torch.manual_seed(42)
model = ModelWithConsecutiveViews(
    input_dim=784, embed_dim=256, num_heads=8, seq_len=49
).to(device).eval()
x = torch.randn(16, 784, device=device)

# Eager: runs successfully
with torch.no_grad():
    eager_out = model(x)
print(f"Eager output shape: {eager_out.shape}")
print(f"Eager output range: [{eager_out.min().item():.4f}, {eager_out.max().item():.4f}]")
print("Eager: OK")

# Compiled: raises error
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
try:
    with torch.no_grad():
        comp_out = compiled(x)
    # If it doesn't crash, check for a numerical difference against eager
    diff = (eager_out.float() - comp_out.float()).abs()
    print(f"Compiled max_diff: {diff.max().item():.6e}")
except Exception as e:
    # Report any compile/run failure instead of aborting the script.
    print(f"torch.compile FAILED: {type(e).__name__}: {e}")

Root Cause

Inductor's pointless_view_pair optimization pass identifies consecutive view operations that cancel out (the output shape equals the input shape) and removes them. However, when the intermediate reshaped tensor is consumed by actual computation (attention scores, element-wise ops) before being reshaped back, removing the view pair breaks the graph because the intermediate shape is required for the computation in between.

Code Example

import os
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RedundantViewBlock(nn.Module):
    """Split the last dim into (D//2, 2) with a view, compute, then restore.

    ``scale`` and ``bias`` both have shape ``(dim // 2, 2)``. They start as
    ones and zeros, so a freshly built block computes ``GELU(x)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        param_shape = (dim // 2, 2)
        self.scale = nn.Parameter(torch.ones(*param_shape))
        self.bias = nn.Parameter(torch.zeros(*param_shape))

    def forward(self, x):
        # x comes in as [B, T, D]; keep that shape for the final reshape.
        input_shape = x.shape
        batch, steps = x.size(0), x.size(1)
        # [B, T, D] -> [B, T, D//2, 2]
        paired = x.view(batch, steps, self.dim // 2, 2)
        # Affine transform then GELU, both on the split layout.
        activated = F.gelu(paired * self.scale + self.bias)
        # [B, T, D//2, 2] -> [B, T, D]
        return activated.reshape(input_shape)


class MultiHeadRedundantAttention(nn.Module):
    """Softmax attention whose head split/merge is done with view pairs.

    Q, K and V are reshaped to ``[B*num_heads, T, head_dim]`` for the two
    ``torch.bmm`` calls, and the result is reshaped back to ``[B, T, D]``.
    """

    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Divisor applied to the raw attention scores.
        self.scale = math.sqrt(self.head_dim)

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch, steps):
        # [B, T, D] -> [B, T, H, hd] -> [B, H, T, hd] -> [B*H, T, hd]
        t = t.view(batch, steps, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        return t.reshape(batch * self.num_heads, steps, self.head_dim)

    def forward(self, x):
        batch, steps, width = x.shape
        # One projection yields Q, K and V side by side: [B, T, 3*D].
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        # First view: fold the heads into the batch dimension.
        q = self._split_heads(q, batch, steps)
        k = self._split_heads(k, batch, steps)
        v = self._split_heads(v, batch, steps)

        # Attention runs entirely on the folded [B*H, T, hd] layout.
        scores = torch.bmm(q, k.transpose(-2, -1)) / self.scale
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, v)

        # Second view: [B*H, T, hd] -> [B, H, T, hd] -> [B, T, H, hd] -> [B, T, D].
        context = context.view(batch, self.num_heads, steps, self.head_dim)
        merged = context.permute(0, 2, 1, 3).reshape(batch, steps, width)

        return self.out_proj(merged)


class ModelWithConsecutiveViews(nn.Module):
    """Small classifier built around two view -> compute -> view-back stages.

    Input ``[B, input_dim]`` is projected to ``embed_dim``, repeated into a
    pseudo-sequence of ``seq_len`` steps, sent through the view block and the
    attention block (each followed by a residual add and LayerNorm), averaged
    over the sequence dimension and mapped to 10 outputs.
    """

    def __init__(self, input_dim=784, embed_dim=256, num_heads=8, seq_len=49):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim

        # Stem: linear projection followed by ReLU.
        self.input_proj = nn.Linear(input_dim, embed_dim)
        self.act = nn.ReLU()

        # Stage 1: element-wise block working on a [B, T, D//2, 2] view.
        self.view_block = RedundantViewBlock(embed_dim)

        # Stage 2: attention with view-based head split/merge, plus the norms.
        self.attn = MultiHeadRedundantAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Head: 10 output features.
        self.output_fc = nn.Linear(embed_dim, 10)

    def forward(self, x):
        batch = x.size(0)
        # [B, input_dim] -> [B, embed_dim]
        hidden = self.act(self.input_proj(x))

        # Repeat along a new sequence axis: [B, seq_len, embed_dim].
        seq = hidden.unsqueeze(1).expand(batch, self.seq_len, self.embed_dim)
        # Per-step offset of 0.01 * step index, shaped [1, seq_len, 1].
        positions = torch.arange(self.seq_len, device=seq.device, dtype=seq.dtype)
        positions = positions.unsqueeze(0).unsqueeze(-1)
        seq = seq + positions * 0.01

        # view(B, T, D//2, 2) -> compute -> back to (B, T, D), with residual.
        seq = self.norm1(seq + self.view_block(seq))

        # Attention with view-based head splitting, with residual.
        seq = self.norm2(seq + self.attn(seq))

        # Mean over the sequence axis, then classify: [B, embed_dim] -> [B, 10].
        return self.output_fc(seq.mean(dim=1))


# Driver for the reproducer: build the model once, run it in eager mode, then
# run the same model through torch.compile and compare the two outputs.
device = "cuda"
# Seed before constructing the model so the weights and input are repeatable.
torch.manual_seed(42)
model = ModelWithConsecutiveViews(
    input_dim=784, embed_dim=256, num_heads=8, seq_len=49
).to(device).eval()
x = torch.randn(16, 784, device=device)

# Eager: runs successfully
with torch.no_grad():
    eager_out = model(x)
print(f"Eager output shape: {eager_out.shape}")
print(f"Eager output range: [{eager_out.min().item():.4f}, {eager_out.max().item():.4f}]")
print("Eager: OK")

# Compiled: raises error
# Reset Dynamo state before compiling.
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
try:
    with torch.no_grad():
        comp_out = compiled(x)
    # If it doesn't crash, check for a numerical difference against eager
    diff = (eager_out.float() - comp_out.float()).abs()
    print(f"Compiled max_diff: {diff.max().item():.6e}")
except Exception as e:
    # Report any compile/run failure instead of aborting the script.
    print(f"torch.compile FAILED: {type(e).__name__}: {e}")

---

Traceback (most recent call last):
  File "reproducer.py", line 107, in <module>
    comp_out = compiled(x)
  ...
  File ".../torch/_inductor/fx_passes/post_grad.py", line ..., in pointless_view_pair
    ...
RuntimeError: shape mismatch: cannot apply scale [128, 2] to tensor of shape [16, 49, 256]
  -- pointless_view_pair elimination removed intermediate view that was required for 
     element-wise computation with differently-shaped parameters

---

PyTorch version: 2.12.0.dev20260327+cu126
Python: 3.10.12
OS: Ubuntu 22.04.5 LTS (WSL2)
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
CUDA: 12.6
RAW_BUFFER · Click to expand / collapse

🐛 Describe the bug

torch.compile with the inductor backend raises an error when compiling a model that contains redundant consecutive view (reshape) operations. The model performs x.view(B, T, D//2, 2) followed by computation, then reshape(original_shape) to restore the tensor shape. It also uses a multi-head attention component that reshapes q, k and v to (B*num_heads, T, head_dim), runs the attention computation on that intermediate shape, and then restores the result through (B, num_heads, T, head_dim) back to (B, T, D) — another redundant view pair where the intermediate reshaped tensor is used for attention computation.

Inductor's pointless_view_pair optimization pass identifies consecutive view operations that cancel out (the output shape equals the input shape) and removes them. However, when the intermediate reshaped tensor is consumed by actual computation (attention scores, element-wise ops) before being reshaped back, removing the view pair breaks the graph because the intermediate shape is required for the computation in between.

Minimal reproducer

import os
os.environ["TRITON_BACKENDS_IN_TREE"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RedundantViewBlock(nn.Module):
    """Split the last dim into (D//2, 2) with a view, compute, then restore.

    ``scale`` and ``bias`` both have shape ``(dim // 2, 2)``. They start as
    ones and zeros, so a freshly built block computes ``GELU(x)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        param_shape = (dim // 2, 2)
        self.scale = nn.Parameter(torch.ones(*param_shape))
        self.bias = nn.Parameter(torch.zeros(*param_shape))

    def forward(self, x):
        # x comes in as [B, T, D]; keep that shape for the final reshape.
        input_shape = x.shape
        batch, steps = x.size(0), x.size(1)
        # [B, T, D] -> [B, T, D//2, 2]
        paired = x.view(batch, steps, self.dim // 2, 2)
        # Affine transform then GELU, both on the split layout.
        activated = F.gelu(paired * self.scale + self.bias)
        # [B, T, D//2, 2] -> [B, T, D]
        return activated.reshape(input_shape)


class MultiHeadRedundantAttention(nn.Module):
    """Softmax attention whose head split/merge is done with view pairs.

    Q, K and V are reshaped to ``[B*num_heads, T, head_dim]`` for the two
    ``torch.bmm`` calls, and the result is reshaped back to ``[B, T, D]``.
    """

    def __init__(self, embed_dim=256, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Divisor applied to the raw attention scores.
        self.scale = math.sqrt(self.head_dim)

        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch, steps):
        # [B, T, D] -> [B, T, H, hd] -> [B, H, T, hd] -> [B*H, T, hd]
        t = t.view(batch, steps, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        return t.reshape(batch * self.num_heads, steps, self.head_dim)

    def forward(self, x):
        batch, steps, width = x.shape
        # One projection yields Q, K and V side by side: [B, T, 3*D].
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        # First view: fold the heads into the batch dimension.
        q = self._split_heads(q, batch, steps)
        k = self._split_heads(k, batch, steps)
        v = self._split_heads(v, batch, steps)

        # Attention runs entirely on the folded [B*H, T, hd] layout.
        scores = torch.bmm(q, k.transpose(-2, -1)) / self.scale
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, v)

        # Second view: [B*H, T, hd] -> [B, H, T, hd] -> [B, T, H, hd] -> [B, T, D].
        context = context.view(batch, self.num_heads, steps, self.head_dim)
        merged = context.permute(0, 2, 1, 3).reshape(batch, steps, width)

        return self.out_proj(merged)


class ModelWithConsecutiveViews(nn.Module):
    """Small classifier built around two view -> compute -> view-back stages.

    Input ``[B, input_dim]`` is projected to ``embed_dim``, repeated into a
    pseudo-sequence of ``seq_len`` steps, sent through the view block and the
    attention block (each followed by a residual add and LayerNorm), averaged
    over the sequence dimension and mapped to 10 outputs.
    """

    def __init__(self, input_dim=784, embed_dim=256, num_heads=8, seq_len=49):
        super().__init__()
        self.seq_len = seq_len
        self.embed_dim = embed_dim

        # Stem: linear projection followed by ReLU.
        self.input_proj = nn.Linear(input_dim, embed_dim)
        self.act = nn.ReLU()

        # Stage 1: element-wise block working on a [B, T, D//2, 2] view.
        self.view_block = RedundantViewBlock(embed_dim)

        # Stage 2: attention with view-based head split/merge, plus the norms.
        self.attn = MultiHeadRedundantAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # Head: 10 output features.
        self.output_fc = nn.Linear(embed_dim, 10)

    def forward(self, x):
        batch = x.size(0)
        # [B, input_dim] -> [B, embed_dim]
        hidden = self.act(self.input_proj(x))

        # Repeat along a new sequence axis: [B, seq_len, embed_dim].
        seq = hidden.unsqueeze(1).expand(batch, self.seq_len, self.embed_dim)
        # Per-step offset of 0.01 * step index, shaped [1, seq_len, 1].
        positions = torch.arange(self.seq_len, device=seq.device, dtype=seq.dtype)
        positions = positions.unsqueeze(0).unsqueeze(-1)
        seq = seq + positions * 0.01

        # view(B, T, D//2, 2) -> compute -> back to (B, T, D), with residual.
        seq = self.norm1(seq + self.view_block(seq))

        # Attention with view-based head splitting, with residual.
        seq = self.norm2(seq + self.attn(seq))

        # Mean over the sequence axis, then classify: [B, embed_dim] -> [B, 10].
        return self.output_fc(seq.mean(dim=1))


# Driver for the reproducer: build the model once, run it in eager mode, then
# run the same model through torch.compile and compare the two outputs.
device = "cuda"
# Seed before constructing the model so the weights and input are repeatable.
torch.manual_seed(42)
model = ModelWithConsecutiveViews(
    input_dim=784, embed_dim=256, num_heads=8, seq_len=49
).to(device).eval()
x = torch.randn(16, 784, device=device)

# Eager: runs successfully
with torch.no_grad():
    eager_out = model(x)
print(f"Eager output shape: {eager_out.shape}")
print(f"Eager output range: [{eager_out.min().item():.4f}, {eager_out.max().item():.4f}]")
print("Eager: OK")

# Compiled: raises error
# Reset Dynamo state before compiling.
torch._dynamo.reset()
compiled = torch.compile(model, backend="inductor")
try:
    with torch.no_grad():
        comp_out = compiled(x)
    # If it doesn't crash, check for a numerical difference against eager
    diff = (eager_out.float() - comp_out.float()).abs()
    print(f"Compiled max_diff: {diff.max().item():.6e}")
except Exception as e:
    # Report any compile/run failure instead of aborting the script.
    print(f"torch.compile FAILED: {type(e).__name__}: {e}")

Behavior summary

| Mode | Result | Notes |
| --- | --- | --- |
| Eager | Runs successfully | Produces valid output of shape [16, 10] |
| torch.compile(backend="inductor") | Error raised | Shape mismatch or compilation failure during pointless_view_pair elimination |

Notes

  • Eager mode runs successfully with valid output, confirming the model logic is correct.
  • The RedundantViewBlock creates the simplest form of the pattern: view(B, T, D//2, 2) → compute → reshape(B, T, D). The outer shapes match (both [B, T, D]), so Inductor's pointless_view_pair may attempt to remove both views.
  • However, the intermediate [B, T, D//2, 2] shape is required for the element-wise scale and bias parameters, which have shape [D//2, 2]. Removing the view pair makes the intermediate computation shape-incompatible.
  • The MultiHeadRedundantAttention creates a more complex variant: view → permute → reshape to [B*num_heads, T, head_dim], then view → permute → reshape back to [B, T, D]. The intermediate shape is used for torch.bmm attention computation.
  • Both patterns have the same structure: the view pair's input/output shapes cancel, but the intermediate shape participates in computation.

Error logs

Traceback (most recent call last):
  File "reproducer.py", line 107, in <module>
    comp_out = compiled(x)
  ...
  File ".../torch/_inductor/fx_passes/post_grad.py", line ..., in pointless_view_pair
    ...
RuntimeError: shape mismatch: cannot apply scale [128, 2] to tensor of shape [16, 49, 256]
  -- pointless_view_pair elimination removed intermediate view that was required for 
     element-wise computation with differently-shaped parameters

(Exact traceback may vary depending on PyTorch nightly build; the error occurs during Inductor's post-grad view pair elimination or during subsequent kernel code generation.)

Versions

PyTorch version: 2.12.0.dev20260327+cu126
Python: 3.10.12
OS: Ubuntu 22.04.5 LTS (WSL2)
GPU: NVIDIA GeForce RTX 3060 Laptop GPU
CUDA: 12.6

cc @chauhang @penguinwu @ezyang @msaroufim @bdhirsh @anijain2305

topic: fuzzer

extent analysis

TL;DR

The issue can be worked around by disabling the pointless_view_pair optimization pass in the inductor backend, or by restructuring the model so that it no longer uses view/reshape pairs whose intermediate shape is needed by the computation in between.

Guidance

  • Identify and modify the parts of the model where redundant view operations are necessary for computations, such as in RedundantViewBlock and MultiHeadRedundantAttention, to make the intermediate shapes compatible with the computations.
  • Consider disabling the pointless_view_pair optimization pass in the inductor backend to prevent it from removing necessary view operations, although this might affect performance.
  • Verify that the model runs correctly and produces the expected output after applying the modifications or disabling the optimization pass.
  • Test the model with different inputs and scenarios to ensure that the fix does not introduce any regressions.

Example

# Modify RedundantViewBlock to avoid redundant view operations
class RedundantViewBlock(nn.Module):
    """Scale, shift and GELU applied directly on ``[B, T, D]``, with no views.

    ``scale`` and ``bias`` are stored flat with shape ``(dim,)``, unlike the
    ``(dim // 2, 2)`` layout of the view-based version in the reproducer, so
    state dicts of the two variants are not interchangeable without reshaping.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        flat_shape = (dim,)
        self.scale = nn.Parameter(torch.ones(*flat_shape))
        self.bias = nn.Parameter(torch.zeros(*flat_shape))

    def forward(self, x):
        # x: [B, T, D]; scale and bias broadcast over the last dimension.
        return F.gelu(x * self.scale + self.bias)

Notes

  • The pointless_view_pair optimization pass is designed to remove redundant view operations, but it can cause issues when the intermediate shapes are required for computations.
  • Disabling the optimization pass might affect the performance of the model, so it's recommended to modify the model to avoid redundant view operations whenever possible.
  • The fix might require modifying other parts of the model that use similar patterns, so thorough testing is necessary to ensure that the fix does not introduce any regressions.

Recommendation

Apply workaround: modify the model to avoid redundant view operations or disable the pointless_view_pair optimization pass, as the root cause is related to the optimization pass removing necessary view operations.

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING