pytorch - ✅(Solved) Fix `torch.compile(backend="inductor")` silently succeeds on `torch.matmul` (≥3D) with mismatched dtypes (float16 @ float32) where eager raises RuntimeError [1 pull requests, 1 comments, 2 participants]

pytorch/pytorch#177630 · Fetched 2026-04-08 00:47:09
Fix Action

Fixed

PR fix notes

PR #177858: Fix Inductor bmm mixed-dtype error handling

Description (problem / solution / changelog)

Fix #177630

Summary

Root cause

Inductor's custom aten.bmm lowering did not enforce the same-dtype contract that eager bmm and batched matmul require. That let the native matmul and autotuned kernel paths normalize mixed low-precision inputs and silently succeed where eager raises.

Proposed fix

Add an explicit dtype validation at the start of the Inductor aten.bmm lowering, along with the existing out_dtype contract checks, before selecting any backend. Add regression coverage for direct compiled bmm and decomposed higher-rank matmul with mixed dtypes.

Why this is the right long term fix

Higher-rank torch.matmul already decomposes through aten.bmm, so restoring eager's dtype contract in the shared bmm lowering fixes both user-visible surfaces at the point where Inductor diverged. The regression tests keep that contract pinned without changing unrelated matmul behavior.

Drafted via Codex, published after manual review by @bobrenjc93

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • test/inductor/test_torchinductor.py (modified, +50/-0)
  • torch/_meta_registrations.py (modified, +177/-2)
  • torch/fx/experimental/proxy_tensor.py (modified, +34/-12)


🐛 Describe the bug

torch.compile(backend="inductor") silently succeeds when executing torch.matmul with ≥3D inputs of mismatched dtypes (float16 and float32), while eager mode correctly raises RuntimeError: expected scalar type Half but found Float.

The compiled result is a float32 tensor numerically identical to the result of explicitly casting the float16 input to float32 first — confirming Inductor performs an implicit dtype promotion that bypasses the strict same-dtype constraint of the eager kernels.

This is NOT a timeout or numerical precision issue — it is a behavioral inconsistency where torch.compile changes the semantics of an operation that should raise an error.

Root cause

torch.matmul for ≥3D inputs decomposes into torch.bmm internally, via the aten.matmul.default decomposition (in torch/_decomp/decompositions.py, wrapped by out_wrapper from torch/_prims_common/wrappers.py). The decomposition reshapes the inputs from ≥3D to 3D, then calls tensor1_expanded.bmm(tensor2_expanded). Inductor's handling of aten.bmm (the @pw_cast_for_opmath-decorated decomposition and the tuned_bmm lowering under torch/_inductor/) implicitly promotes float16 → float32 without enforcing the same-dtype constraint that the eager CUDA/CPU kernels require.

The aot_eager backend correctly preserves the RuntimeError, confirming the bug is specifically in Inductor's lowering path.

Related: torch.bmm (3D) has the same root cause. torch.mm (2D) is NOT affected — it correctly raises RuntimeError in both eager and compiled modes.
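
The reshape-then-bmm path described above can be illustrated in eager mode. This is a minimal sketch, assuming both inputs share the same dtype and the same batch dimensions so that no broadcasting is needed. The variable names are illustrative, and the real decomposition also handles expansion and other ranks.

```python
import torch

# Same-dtype 4D inputs, so eager matmul succeeds (CPU is enough here).
a = torch.randn(2, 4, 8, 8)
b = torch.randn(2, 4, 8, 64)

# Fold the leading batch dims (2, 4) into one batch dim of size 8,
# run a single 3D bmm, then restore the original batch dims.
a3 = a.reshape(-1, 8, 8)      # shape (8, 8, 8)
b3 = b.reshape(-1, 8, 64)     # shape (8, 8, 64)
out_bmm = torch.bmm(a3, b3).reshape(2, 4, 8, 64)

out_matmul = torch.matmul(a, b)
matches = torch.allclose(out_matmul, out_bmm, atol=1e-5)
print(f"matmul matches reshape + bmm: {matches}")
```

Because the higher-rank call funnels into bmm this way, any dtype check missing from the bmm path is also missing for ≥3D matmul.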

Minimal reproducer

import torch

# 4D matmul with mismatched dtypes
a = torch.randn(2, 4, 8, 8, device="cuda", dtype=torch.float16)
b = torch.randn(2, 4, 8, 64, device="cuda", dtype=torch.float32)

# Eager: raises RuntimeError
try:
    out_eager = torch.matmul(a, b)
    print(f"eager: OK dtype={out_eager.dtype}")
except RuntimeError as e:
    print(f"eager: ERROR — {e}")
# Output: eager: ERROR — expected scalar type Half but found Float

# Compiled with inductor: silently succeeds
torch._dynamo.reset()

@torch.compile(backend="inductor", fullgraph=True)
def compiled_matmul(x, y):
    return torch.matmul(x, y)

try:
    out_compiled = compiled_matmul(a, b)
    print(f"compile: OK dtype={out_compiled.dtype}")
except Exception as e:
    print(f"compile: ERROR — {e}")
# Output: compile: OK dtype=torch.float32

# Verify the compiled output is an implicit fp16→fp32 promotion
ref = torch.matmul(a.float(), b)
print(f"max diff vs explicit promotion: {(out_compiled - ref).abs().max().item()}")
# Output: max diff vs explicit promotion: 0.0

Full model-level reproducer (as found by fuzzer)

This pattern arises naturally in attention models that use mixed-precision casting for numerically stable softmax:

import torch
import torch.nn as nn
import math

class SingleHeadAttention(nn.Module):
    def __init__(self, embed_dim=512, num_embeddings=10000):
        super().__init__()
        self.embed_dim = embed_dim
        self.embedding = nn.Embedding(num_embeddings, embed_dim)
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        x_emb = self.embedding(x)
        query = self.query_proj(x_emb)
        key = self.key_proj(x_emb)
        value = self.value_proj(x_emb)
        batch_size, seq_len, embed_dim = query.shape
        query = query.view(batch_size, seq_len, 1, embed_dim).permute(0, 2, 1, 3)
        key = key.view(batch_size, seq_len, 1, embed_dim).permute(0, 2, 1, 3)
        value = value.view(batch_size, seq_len, 1, embed_dim).permute(0, 2, 1, 3)
        query = query / math.sqrt(query.size(-1))
        attn_scores = query @ key.transpose(-2, -1)
        attn_scores = attn_scores.float()         # upcast to float32 for stable softmax
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.half()         # downcast back to float16
        context = attn_weights @ value             # half @ float => error in eager, OK in compile
        context = context.permute(0, 2, 1, 3).reshape(batch_size, seq_len, embed_dim)
        return self.out_proj(context)

model = SingleHeadAttention().cuda()  # model weights are float32
x = torch.randint(0, 10000, (2, 16), dtype=torch.long, device="cuda")

# Eager: RuntimeError
try:
    model(x)
    print("eager: OK")
except RuntimeError as e:
    print(f"eager: ERROR — {e}")

# Compiled: silently succeeds
torch._dynamo.reset()
compiled_model = torch.compile(model, backend="inductor", fullgraph=True)
try:
    out = compiled_model(x)
    print(f"compile: OK dtype={out.dtype}")
except Exception as e:
    print(f"compile: ERROR — {e}")
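
In eager mode the model above fails because attn_weights is cast to float16 while value stays float32. One way to keep both operands of the final matmul in the same dtype, assuming the intent is only to run the softmax in float32, is to cast the weights back to value.dtype instead of hard-coding .half(). The following is a minimal sketch; the function name attention_context is hypothetical and does not come from the issue.

```python
import math
import torch

def attention_context(query, key, value):
    """Scaled dot-product attention with a float32 softmax.

    The softmax output is cast back to value.dtype, so the final
    matmul always sees two operands of the same dtype.
    """
    scores = (query / math.sqrt(query.size(-1))) @ key.transpose(-2, -1)
    weights = torch.softmax(scores.float(), dim=-1)  # upcast for stability
    weights = weights.to(value.dtype)                # match value, not a fixed dtype
    return weights @ value

# float32 inputs, mirroring the float32 model weights in the reproducer.
q = torch.randn(2, 1, 16, 512)
k = torch.randn(2, 1, 16, 512)
v = torch.randn(2, 1, 16, 512)
context = attention_context(q, k, v)
print(context.shape, context.dtype)
```

Written this way, the function behaves the same in eager and compiled modes regardless of whether Inductor enforces the dtype contract.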

Ablation

Setting | Behavior | Consistent with eager?
backend="inductor" | Silently succeeds (float32) | No
backend="inductor", fullgraph=False | Silently succeeds (float32) | No
backend="aot_eager" | RuntimeError | Yes
Eager (no compile) | RuntimeError | (baseline)
CPU device (no CUDA) | Same bug: eager errors, compile succeeds | No
torch.mm(fp16, fp32) (2D) | Both eager and compile raise RuntimeError | Yes
torch.bmm(fp16, fp32) (3D) | Same bug: eager errors, compile succeeds | No (same root cause)

Behavior summary

Operation | Eager | torch.compile(backend="inductor") | Consistent?
torch.mm(fp16, fp32) (2D) | RuntimeError | RuntimeError | Yes
torch.bmm(fp16, fp32) (3D) | RuntimeError | Succeeds (float32) | No
torch.matmul(fp16, fp32) (4D) | RuntimeError | Succeeds (float32) | No
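
The eager column of this summary can be reproduced on CPU with a small probe. This is only a sketch: probe is a hypothetical helper, and the compiled column is not exercised here because it depends on the installed PyTorch build.

```python
import torch

def probe(op, x, y):
    """Return 'RuntimeError' if op(x, y) raises it, else the result dtype."""
    try:
        return f"OK dtype={op(x, y).dtype}"
    except RuntimeError:
        return "RuntimeError"

# One (operation, fp16 input, fp32 input) case per table row.
cases = {
    "mm (2D)": (torch.mm, torch.randn(8, 8).half(), torch.randn(8, 64)),
    "bmm (3D)": (torch.bmm, torch.randn(4, 8, 8).half(), torch.randn(4, 8, 64)),
    "matmul (4D)": (
        torch.matmul,
        torch.randn(2, 4, 8, 8).half(),
        torch.randn(2, 4, 8, 64),
    ),
}

eager_results = {name: probe(op, x, y) for name, (op, x, y) in cases.items()}
for name, result in eager_results.items():
    print(f"{name}: eager -> {result}")
```

To fill in the compiled column, the same probe can be given a function wrapped with torch.compile, as the minimal reproducer does, with torch._dynamo.reset() called between cases.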

Discovery context

This was discovered via WhiteFox differential fuzzing on torch._inductor.fx_passes.fuse_attention SDPA patterns. The fuzzer-generated attention model uses an explicit mixed-precision casting pattern (from _sfdp_pattern_9): upcast attention scores to float32 for stable softmax, then downcast weights to float16 before the final attn_weights @ value matmul. With model weights in float32, the half @ float matmul triggers the bug.
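
The differential check that surfaced this bug comes down to running the same computation two ways and comparing outcomes, including whether an exception is raised. A minimal sketch of that idea follows. It is not WhiteFox's implementation, and outcome and same_outcome are hypothetical helper names. Plain Python callables stand in for the eager and compiled functions so the example runs without PyTorch or a GPU.

```python
def outcome(fn, *args):
    """Summarize a call as ('ok', result) or ('error', exception type name)."""
    try:
        return ("ok", fn(*args))
    except Exception as exc:
        return ("error", type(exc).__name__)

def same_outcome(reference_fn, candidate_fn, *args):
    """True when both callables agree on erroring vs. succeeding.

    Two errors must have the same exception type; two successes must
    return equal results.
    """
    return outcome(reference_fn, *args) == outcome(candidate_fn, *args)

# Stand-ins: the reference rejects mixed types, the candidate promotes.
def strict_add(x, y):
    if type(x) is not type(y):
        raise TypeError("mixed operand types")
    return x + y

def lenient_add(x, y):
    return float(x) + float(y)

print(same_outcome(strict_add, lenient_add, 1.0, 2.0))  # both succeed
print(same_outcome(strict_add, lenient_add, 1, 2.0))    # one raises, one succeeds
```

With tensors, the equality test on successful results would need to be replaced by a comparison such as torch.allclose plus a dtype check, since == on tensors is elementwise.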

Error logs

Eager mode (correct behavior):

RuntimeError: expected scalar type Half but found Float

torch.compile(backend="inductor") (incorrect — should raise the same error):

(no error — silently returns torch.float32 tensor of shape [2, 4, 8, 64])

Versions

Collecting environment information...
PyTorch version: 2.12.0.dev20260316+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov  4 2025, 08:48:33) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Laptop GPU
Nvidia driver version: 546.30
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] torch==2.12.0.dev20260316+cu126
[pip3] triton==3.6.0+git9844da95

cc @nairbv @mruberry @ezyang @eellison @bdhirsh @bobrenjc93 @aorenste @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @msaroufim @anijain2305

topic: fuzzer

extent analysis

Fix Plan

To stop torch.compile(backend="inductor") from silently succeeding on torch.matmul with ≥3D inputs of mismatched dtypes, Inductor's lowering path has to enforce the same-dtype constraint that eager enforces.

The concrete steps, following the description of PR #177858:

  • Add an explicit dtype check at the start of the Inductor aten.bmm lowering, alongside the existing out_dtype checks and before any backend is selected, and raise a RuntimeError on a mismatch.
  • Add regression tests for directly compiled torch.bmm and for higher-rank torch.matmul with mixed dtypes.

Example code (an illustrative sketch of the check, not the code from the PR; check_same_dtype and checked_bmm are hypothetical names):

import torch

def check_same_dtype(mat1, mat2):
    # Mirror eager's contract: both bmm operands must share a dtype.
    if mat1.dtype != mat2.dtype:
        raise RuntimeError(
            f"expected scalar type {mat1.dtype} but found {mat2.dtype}"
        )

def checked_bmm(mat1, mat2):
    # Validate before any backend or kernel selection would happen.
    check_same_dtype(mat1, mat2)
    return torch.bmm(mat1, mat2)

Verification

To verify that the fix worked, we can run the minimal reproducer and check that torch.compile(backend="inductor") raises a RuntimeError when executing torch.matmul with ≥3D inputs of mismatched dtypes.

Example code:

import torch

# 4D matmul with mismatched dtypes
a = torch.randn(2, 4, 8, 8, device="cuda", dtype=torch.float16)
b = torch.randn(2, 4, 8, 64, device="cuda", dtype=torch.float32)

# Compiled with inductor: should raise RuntimeError
torch._dynamo.reset()

@torch.compile(backend="inductor", fullgraph=True)
def compiled_matmul(x, y):
    return torch.matmul(x, y)

try:
    out_compiled = compiled_matmul(a, b)
    print(f"compile: OK dtype={out_compiled.dtype}")
except RuntimeError as e:
    print(f"compile: ERROR — {e}")
# Expected once the fix is in place: compile: ERROR, followed by a dtype-mismatch message (the exact wording may differ from eager's)

Extra Tips

To prevent similar issues in the future:

  • Test compiled code with mixed dtypes on both CPU and CUDA, and compare against eager for error cases as well as for numerical results.
  • Use differential fuzzing tools such as WhiteFox to detect behavioral inconsistencies between eager and compiled modes.
  • When adding or changing an Inductor lowering, check that it enforces the same input contracts (such as matching dtypes) as the eager kernel it replaces.
