pytorch - [Inductor] `F.threshold` bf16 miscompilation: scalar threshold not cast to tensor dtype

🐛 Describe the bug

torch.compile(F.threshold) produces wrong results for bfloat16 tensors when the threshold value is not exactly representable in bf16. The compiled code compares the bf16 tensor against the float64 threshold, while eager mode casts the threshold to the tensor's dtype (bf16) before comparison.

This is the same root cause that was previously fixed for hardshrink and softshrink, but F.threshold was missed.

Minimal reproducer

import torch
import torch.nn.functional as F

x = torch.tensor([0.10009765625], dtype=torch.bfloat16, device='cuda')
# 0.10009765625 is the bf16 representation of 0.1

eager = F.threshold(x, 0.1, 0.0)
print(f"eager:    {eager}")    # tensor([0.], device='cuda:0', dtype=torch.bfloat16)  ← correct

torch._dynamo.reset()
compiled = torch.compile(F.threshold)(x, 0.1, 0.0)
print(f"compiled: {compiled}")  # tensor([0.1001], device='cuda:0', dtype=torch.bfloat16)  ← WRONG

Expected: compiled should match eager, i.e. tensor([0.]).
Actual: compiled returns tensor([0.1001]).

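The bug reproduces for threshold values that are not exactly representable in bf16 and that round to a larger value when cast to bf16 (0.01, 0.1, 0.2, 0.3, etc.). For such a threshold, an input equal to bf16(threshold) is greater than the exact threshold but not greater than the rounded one, so eager and compiled mode take opposite branches. fp32/fp64 tensors are NOT affected.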
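Not every non-representable threshold flips a result. Round-to-nearest picks the bf16 grid point closest to t, so the eager comparison (x > bf16(t)) and the compiled one (x > t, per this report) can only disagree for bf16 inputs with t < x <= bf16(t). Such an input exists only when rounding goes up, and then it is bf16(t) itself. The sketch below checks this in plain Python. `to_bf16` and `mismatch_input` are hypothetical helpers; `to_bf16` assumes round-half-to-even through float32 as a stand-in for PyTorch's cast, and does not call PyTorch.

```python
import struct


def to_bf16(x: float) -> float:
    """Hypothetical helper: round x to bfloat16 (via float32, round-half-to-even).

    NaN/inf are not handled; this is only meant for small finite thresholds.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Keep the top 16 bits, rounding the dropped 16 bits to nearest even.
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def mismatch_input(t: float):
    """Return the bf16 input on which eager and compiled disagree, or None.

    Eager keeps x when x > bf16(t); the compiled path (as modeled from the
    report) keeps x when x > t. They differ only for bf16 x with
    t < x <= bf16(t), i.e. only x = bf16(t), and only if rounding went up.
    """
    r = to_bf16(t)
    return r if r > t else None


# 0.7 rounds down in bf16 and 0.25 is exact, so both report None.
for t in (0.01, 0.1, 0.2, 0.3, 0.7, 0.25):
    print(t, to_bf16(t), mismatch_input(t))
```

Under these assumptions, a threshold such as 0.7 is not representable in bf16 either, but it rounds down and so should not trigger the mismatch.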

Root cause

The threshold decomposition in torch/_refs/nn/functional/__init__.py (line 949):

return torch.where(a <= threshold, value, a)

When Inductor compiles this for a bf16 tensor a, the scalar threshold stays in float64 precision during the comparison. In eager mode, the comparison casts threshold to the tensor's dtype (bf16).

For threshold=0.1:

  • Eager: bf16(x) > bf16(0.1) → 0.10009765625 > 0.10009765625 → False → output = 0.0
  • Compiled: bf16(x) > float64(0.1) → 0.10009765625 > 0.1 → True → output = x
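The two comparisons for threshold=0.1 can be replayed without a GPU. This is a minimal sketch that emulates them in plain Python rather than running Inductor: `to_bf16` is a hypothetical helper assuming float32-then-round-half-to-even bf16 rounding, and the compiled path is modeled, as this report describes it, as comparing against the unrounded scalar.

```python
import struct


def to_bf16(x: float) -> float:
    """Hypothetical helper: round x to bfloat16 (via float32, round-half-to-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


threshold, value = 0.1, 0.0
x = to_bf16(0.1)  # the reproducer's input element, 0.10009765625

# F.threshold keeps x when x > threshold, otherwise writes value.
eager = x if x > to_bf16(threshold) else value  # scalar cast to bf16 first
compiled = x if x > threshold else value        # scalar kept at full precision

print(x, eager, compiled)  # 0.10009765625 0.0 0.10009765625
```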

Relation to existing fixes

hardshrink and softshrink had the exact same scalar precision issue and were fixed (they now match eager in bf16). Their current decompositions in the same file work correctly:

# hardshrink (line 527) — FIXED, works correctly:
return torch.where(torch.abs(a) <= lambd, 0, a)

# softshrink (line 533) — FIXED, works correctly

# threshold (line 949) — NOT FIXED:
return torch.where(a <= threshold, value, a)
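One way to read the missing fix is that the threshold decomposition should compare against the scalar rounded to `a.dtype`, as eager does. The sketch below is not the actual PyTorch patch (this report does not show how hardshrink/softshrink were fixed); it emulates `torch.where(a <= threshold, value, a)` on a list of bf16 values in plain Python, with a hypothetical `round_scalar` flag to contrast the current and the rounded behavior. `to_bf16` and `threshold_ref` are hypothetical helpers, and `value` is left unrounded because 0.0 is exact in bf16.

```python
import struct


def to_bf16(x: float) -> float:
    """Hypothetical helper: round x to bfloat16 (via float32, round-half-to-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def threshold_ref(a, threshold, value, round_scalar):
    """Emulate torch.where(a <= threshold, value, a) elementwise on bf16 values.

    round_scalar=True models casting the scalar to the tensor dtype first.
    """
    t = to_bf16(threshold) if round_scalar else threshold
    return [value if ai <= t else ai for ai in a]


a = [to_bf16(v) for v in (0.05, 0.1, 0.2)]
print(threshold_ref(a, 0.1, 0.0, round_scalar=False))  # [0.0, 0.10009765625, 0.2001953125]
print(threshold_ref(a, 0.1, 0.0, round_scalar=True))   # [0.0, 0.0, 0.2001953125]
```

With the scalar rounded first, the middle element is zeroed, which matches the eager result from the reproducer.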

Verification that hardshrink/softshrink are fixed:

import torch
import torch.nn.functional as F

x = torch.tensor([0.10009765625], dtype=torch.bfloat16, device='cuda')
torch._dynamo.reset()
print(torch.allclose(F.hardshrink(x, lambd=0.1),
                     torch.compile(F.hardshrink)(x, lambd=0.1)))  # True ← fixed

torch._dynamo.reset()
print(torch.allclose(F.softshrink(x, lambd=0.1),
                     torch.compile(F.softshrink)(x, lambd=0.1)))  # True ← fixed

Versions

  • PyTorch: 2.13.0.dev20260521+cu126
  • GPU: NVIDIA RTX A6000 (sm_86)
  • CUDA: 12.6 (bundled)
  • Python: 3.11
  • OS: Linux

cc @chauhang @penguinwu