pytorch: torch.compile Inductor produces larger bf16 numerical drift for RoPE rotate-half pattern

GitHub stats
pytorch/pytorch#183122 · Fetched 2026-05-11 03:12:45
Comments: 0 · Participants: 1 · Timeline: 75 · Reactions: 0
Timeline (top): mentioned ×34, subscribed ×34, labeled ×7

🐛 Describe the bug

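torch.compile with the Inductor backend produces a larger-than-expected bf16 numerical difference from eager mode for a RoPE-style rotate-half pattern. The reproduction below reports a maximum absolute difference of 0.015625, above the script's threshold of 1.5 / 256 (about 0.00586).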

The function uses torch.cat, slicing, broadcasting, multiplication, and addition:

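import torch

def rotate_half(x):
    # Swap the two halves of the last dim and negate the original second half.
    h = x.shape[-1] // 2
    return torch.cat((-x[..., h:], x[..., :h]), dim=-1)

def rope(q, cos, sin):
    # unsqueeze(1) adds a size-1 dim so cos/sin broadcast across q's dim 1.
    return q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)

torch.manual_seed(99)

dt = torch.bfloat16
q = torch.randn(2, 3, 5, 8, dtype=dt)
cos = torch.randn(2, 5, 8, dtype=dt)
sin = torch.randn(2, 5, 8, dtype=dt)

# Eager-mode result, used as the baseline for the comparison.
eager = rope(q, cos, sin)

# Clear any cached compilation state before compiling.
torch._dynamo.reset()
compiled = torch.compile(
    rope,
    backend="inductor",
    fullgraph=True,
    dynamic=True,
)(q, cos, sin)

# Compare in float32 so the subtraction itself adds no bf16 rounding.
diff = (eager.float() - compiled.float()).abs().max().item()

print(f"bf16 max_diff = {diff:.6f}")
# Threshold chosen by the reporter: 1.5 / 256 is about 0.00586.
print("BUG" if diff > 1.5 / 256 else "OK")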
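The expression `q * cos + rotate_half(q) * sin` involves two multiplies and one add, so the result can depend on where bf16 rounding happens. The following is a minimal pure-Python sketch of that effect, not a description of what Inductor does for this kernel (the issue does not confirm that). The helpers `to_bf16`, `rope_elem_per_op`, and `rope_elem_single_round` are hypothetical names introduced here, the rounding routine handles only finite normal-range values, and the four inputs are hand-picked bf16-representable numbers rather than values taken from the reproduction.

```python
import struct

def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (ties to even).

    Illustrative only: finite, normal-range inputs; NaN/inf are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add half of the discarded range, plus the kept LSB for ties-to-even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000  # keep sign, exponent, and the top 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def rope_elem_per_op(q, cos, rot, sin):
    # Round after each multiply and again after the add.
    return to_bf16(to_bf16(q * cos) + to_bf16(rot * sin))

def rope_elem_single_round(q, cos, rot, sin):
    # Keep intermediates in higher precision and round once at the end.
    return to_bf16(q * cos + rot * sin)

# Hand-picked inputs (assumption: chosen to expose the effect).
q, cos, rot, sin = 2.5, 1.0078125, 0.125, 0.046875

per_op = rope_elem_per_op(q, cos, rot, sin)
single = rope_elem_single_round(q, cos, rot, sin)

print(f"per-op rounding = {per_op}")
print(f"single rounding = {single}")
print(f"difference      = {abs(per_op - single)}")
```

With these inputs the two strategies land on adjacent bfloat16 values, so a difference of one bfloat16 step can arise from rounding order alone. Whether that explains the reported drift would need to be checked against a float32 reference on the actual tensors.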

Error logs

bf16 max_diff = 0.015625
BUG
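For context on the logged value: bfloat16 keeps 8 significant bits, so the gap between neighbouring representable values grows with magnitude. The sketch below computes that gap and compares it with the reported difference and the script's threshold. The helper name `bf16_spacing` is hypothetical, and only finite, nonzero, normal-range inputs are assumed.

```python
import math

def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values near x (normal range only)."""
    # frexp gives x = m * 2**e with 0.5 <= |m| < 1, so x lies in
    # [2**(e-1), 2**e). With 8 significant bits the gap is 2**(e-1-7).
    _, e = math.frexp(abs(x))
    return 2.0 ** (e - 8)

observed = 0.015625    # max_diff from the log above
threshold = 1.5 / 256  # threshold used by the reproduction script

print(f"threshold            = {threshold:.6f}")
print(f"spacing near 1.0     = {bf16_spacing(1.0):.6f}")
print(f"spacing near 2.0     = {bf16_spacing(2.0):.6f}")
print(f"observed / spacing(2) = {observed / bf16_spacing(2.0):.1f}")
```

Under this reading, 0.015625 is exactly one bfloat16 step for a value with magnitude in [2, 4), while the 1.5 / 256 threshold is smaller than a single step for any value with magnitude of at least 1. The issue does not say which output element produced the maximum, so this is context for interpreting the number, not a diagnosis.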

Versions

PyTorch version: 2.11.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.10.20 (main, Mar 11 2026, 17:46:40) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080 Laptop GPU
Nvidia driver version: 545.92
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] onnx==1.21.0
[pip3] onnx2torch==1.5.15
[pip3] onnxruntime==1.23.2
[pip3] torch==2.11.0
[pip3] torchvision==0.26.0
[pip3] triton==3.6.0

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo
