pytorch - ✅(Solved) Fix [inductor] `x.to(torch.bfloat16).float()` precision cast silently elided during fusion [1 pull requests, 1 participants]

GitHub stats
pytorch/pytorch#179561 · Fetched 2026-04-08 03:00:25
Comments: 0 · Participants: 1 · Timeline: 44 · Reactions: 0
Timeline (top): mentioned ×18, subscribed ×18, labeled ×6, cross-referenced ×1

When a computation graph contains an explicit fp32 → bfloat16 → fp32 round-trip cast (e.g. x.to(torch.bfloat16).float()), Inductor treats the round-trip as a no-op and eliminates both casts during graph lowering. This produces silently incorrect results because the bfloat16 cast intentionally truncates the mantissa from 23 bits to 7 bits.

The same issue also affects fp16 round-trips (x.half().float()).
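
To make concrete why the round-trip is not a no-op, here is a minimal pure-Python sketch that emulates fp32 → bf16 → fp32 on the raw bit pattern. The helper name `bf16_round_trip` is hypothetical, and the sketch assumes round-to-nearest-even and finite inputs; it illustrates the number format and is not PyTorch's implementation.

```python
import struct

def bf16_round_trip(value: float) -> float:
    """Emulate fp32 -> bf16 -> fp32 for a finite value (NaN handling omitted)."""
    # Reinterpret the value as its 32-bit IEEE-754 pattern.
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    # Round to nearest, ties to even, at bit 16 (assumed rounding mode).
    bias = 0x7FFF + ((bits >> 16) & 1)
    # Keep sign, exponent, and the top 7 mantissa bits; zero the low 16 bits.
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for v in (1.0, 1.01, 3.14159):
    print(v, "->", bf16_round_trip(v))
```

Values that need more than 7 mantissa bits come back changed (1.01 becomes 1.0078125), so eliding both casts changes the result.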

Fix Action

Fixed

PR fix notes

PR #179572: [inductor] Preserve explicit low-precision FP casts during fusion

Description (problem / solution / changelog)

Fixes #179561

Summary

Inductor's compute-type promotion (codegen_upcast_to_fp32) silently elides explicit dtype casts like x.to(bf16).float(). When enabled (default), the bf16 downcast becomes x.to(tl.float32) — a no-op since x is already fp32 in the fused kernel. The mantissa truncation never happens.

Root cause: _convert_element_type gates use_compute_types on config.emulate_precision_casts (default False), so explicit user casts are treated as compute-type promotions rather than precision-narrowing ops.

Fix: Always set use_compute_types=False when the cast involves bf16/fp16, so the real low-precision cast is emitted in the Triton kernel. emulate_precision_casts continues to control FP fusion and pointwise barriers independently — this avoids the perf regression from disabling enable_fp_fusion globally.
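
The gating change described above can be pictured with a toy model. Everything in this sketch is an assumption made for illustration: the function names, the string dtype labels, and the fallback for non-bf16/fp16 casts are hypothetical and are not the code in torch/_inductor/lowering.py.

```python
LOW_PRECISION = {"bfloat16", "float16"}

def use_compute_types_before(src, dst, emulate_precision_casts=False):
    """Toy model of the old gate: depends only on the config flag."""
    return not emulate_precision_casts

def use_compute_types_after(src, dst, emulate_precision_casts=False):
    """Toy model of the fix: never promote when bf16/fp16 is involved."""
    if src in LOW_PRECISION or dst in LOW_PRECISION:
        return False
    # Assumed unchanged for other dtypes; the PR notes do not spell this out.
    return not emulate_precision_casts

def emitted_dtype(dst, use_compute_types):
    """Dtype the fused kernel casts to, following the PR description."""
    if use_compute_types and dst in LOW_PRECISION:
        return "float32"  # promoted to the compute type: a no-op on fp32 data
    return dst

for name, gate in (("before", use_compute_types_before),
                   ("after", use_compute_types_after)):
    print(name, "->", emitted_dtype("bfloat16", gate("float32", "bfloat16")))
```

In this model the default configuration emits a float32 cast before the fix and a real bfloat16 cast after it, which matches the behaviour the PR notes report.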

The reproducer from the issue now matches eager:

def fn(x, w):
    x = torch.matmul(x, w)
    x = x.to(torch.bfloat16).float()  # mantissa truncation preserved
    x = x * torch.sigmoid(x)
    return x.sum(dim=1)

Before: max_diff = 0.109 (bf16 cast elided). After: max_diff ≈ 0 (matches eager).

Test plan

  • Added test_explicit_precision_cast_not_elided in test/inductor/test_cuda_repro.py — verifies fp32→bf16→fp32 round-trip preserves truncation without any config override
  • Existing test_emulate_precision_casts_* tests continue to pass (those tests now exercise a redundant code path since casts are always preserved)

Authored with Claude.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • test/inductor/test_cuda_repro.py (modified, +18/-0)
  • torch/_inductor/lowering.py (modified, +3/-4)


🐛 Describe the bug

[inductor] x.to(torch.bfloat16).float() precision cast silently elided during fusion

Summary

When a computation graph contains an explicit fp32 → bfloat16 → fp32 round-trip cast (e.g. x.to(torch.bfloat16).float()), Inductor treats the round-trip as a no-op and eliminates both casts during graph lowering. This produces silently incorrect results because the bfloat16 cast intentionally truncates the mantissa from 23 bits to 7 bits.

The same issue also affects fp16 round-trips (x.half().float()).

Reproducer

import torch
import torch._dynamo

torch.manual_seed(0)
x = torch.randn(4, 32, 32, device='cuda')
w = torch.randn(32, 32, device='cuda')

def fn(x, w):
    x = torch.matmul(x, w)
    x = x.to(torch.bfloat16).float()  # intentional precision truncation
    x = x * torch.sigmoid(x)
    return x.sum(dim=1)

# Eager (correct)
eager_out = fn(x, w)

# Inductor (incorrect — cast is elided)
torch._dynamo.reset()
compiled_fn = torch.compile(fn, backend='inductor')
inductor_out = compiled_fn(x.clone(), w.clone())

# aot_eager (correct — preserves cast)
torch._dynamo.reset()
aot_fn = torch.compile(fn, backend='aot_eager')
aot_out = aot_fn(x.clone(), w.clone())

print('eager vs inductor: ', (eager_out - inductor_out).abs().max().item())
# => 0.109 (WRONG — should be ~0)
print('eager vs aot_eager:', (eager_out - aot_out).abs().max().item())
# => 0.0 (correct)

# Proof that inductor output matches "no cast at all":
def fn_no_cast(x, w):
    x = torch.matmul(x, w)
    # bf16 cast removed
    x = x * torch.sigmoid(x)
    return x.sum(dim=1)

no_cast_out = fn_no_cast(x, w)
print('inductor vs fn_without_cast:', (inductor_out - no_cast_out).abs().max().item())
# => 1.5e-5 (inductor behaves as if cast never existed)

Evidence from generated Triton code

Setting TORCH_LOGS=output_code confirms the cast is removed. The generated Triton kernel header is:

# Topologically Sorted Source Nodes: [sigmoid, x_2, sum_1]
# Original ATen: [aten.sigmoid, aten.mul, aten.sum]

The aten._to_copy nodes for bfloat16 and float32 are absent from the fused kernel. The Triton kernel directly loads the fp32 matmul output and applies sigmoid → mul → sum without any intermediate dtype conversion.

Impact

This affects any user code that uses explicit dtype casts for precision control, which is common in:

  • Mixed-precision training: x.to(torch.bfloat16).float() is used to simulate low-precision forward passes
  • Quantization-aware training: explicit cast round-trips model quantization noise
  • Numerical stability testing: intentional precision reduction to test robustness

The diff scales with tensor values — larger matmul outputs (from larger hidden dimensions) produce larger errors:

Shape            max_diff (bf16)   max_diff (fp16)
(4, 32, 32)      0.109             0.018
(8, 69, 53)      0.264             0.025
(3, 257, 129)    0.704
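
Both trends in the table (larger values give larger errors, and fp16 errs less than bf16) follow from mantissa width. The sketch below is illustrative only: `round_to_mantissa_bits` is a hypothetical helper that rounds a Python double to a given number of explicit mantissa bits, ignoring exponent-range limits such as fp16 overflow and subnormals.

```python
import math

def round_to_mantissa_bits(x: float, explicit_bits: int) -> float:
    """Round x to (explicit_bits + 1) significant bits, ties to even."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)              # x = m * 2**e with 0.5 <= |m| < 1
    scale = 2 ** (explicit_bits + 1)  # +1 for the implicit leading bit
    return math.ldexp(round(m * scale) / scale, e)

values = [1 + 1 / 3, 16 + 1 / 3, 256 + 1 / 3]
bf16_errors = [abs(v - round_to_mantissa_bits(v, 7)) for v in values]   # bf16: 7 bits
fp16_errors = [abs(v - round_to_mantissa_bits(v, 10)) for v in values]  # fp16: 10 bits

for v, b, h in zip(values, bf16_errors, fp16_errors):
    print(f"x={v:10.4f}  bf16 error={b:.6f}  fp16 error={h:.6f}")
```

The absolute error grows with the magnitude of the input because the spacing between representable values doubles with each power of two, and the three extra mantissa bits in fp16 keep its error below bf16's for in-range values.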

Additional notes

  • 100% reproducible across all random seeds, all shapes, all config presets (default, aggressive_fusion, no_fusion, freezing)
  • Not affected by torch._inductor.config.emulate_precision_casts — setting it to True does not fix the issue
  • aot_eager is correct — the cast is preserved in the AOT Autograd graph, and only eliminated during Inductor lowering/codegen
  • The fp16 round-trip (x.half().float()) has the same problem but with smaller magnitude

Versions

  • PyTorch: 2.6.0+cu124
  • Triton: 3.2.0
  • CUDA: 12.4
  • GPU: Tesla T4
  • Python: 3.10.19
  • OS: Linux 5.4.0-42-generic x86_64

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

On builds that do not include the fix from PR #179572, the issue can be worked around by compiling with the aot_eager backend instead of inductor.

Guidance

  • The problem arises from Inductor treating the round-trip cast as a no-op and eliminating both casts during graph lowering, resulting in silently incorrect results.
  • To verify the issue, compare the results of the eager and inductor outputs, as shown in the reproducer code.
  • As a temporary workaround, use the aot_eager backend, which preserves the cast and produces correct results.
  • Be aware that this issue affects not only bfloat16 but also fp16 round-trips, although with smaller magnitude.

Example

No user-side code fix applies: the bug is in Inductor's lowering, and the reproducer above already shows how to detect it.

Notes

The issue is reported as 100% reproducible across seeds, shapes, and config presets, and setting torch._inductor.config.emulate_precision_casts to True does not fix it. The aot_eager backend is a viable workaround, but it does not generate fused Inductor kernels, so it will usually be slower than inductor.

Recommendation

For builds without the fix, compile with the aot_eager backend: it preserves the precision cast and matches eager output. This matters wherever explicit dtype casts control precision, as in mixed-precision training, quantization-aware training, and numerical stability testing.
