pytorch: `torch.compile(..., backend="inductor")` returns stale float32 for dtype=None factory ops after torch.set_default_dtype(torch.float64)

GitHub stats
pytorch/pytorch#184405 · Fetched 2026-05-20 03:38:55
Comments: 0 · Participants: 1 · Timeline: 35 · Reactions: 0
Timeline (top): mentioned ×15, subscribed ×15, labeled ×5

Root Cause

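torch.compile with the inductor backend appears to capture the default floating-point dtype at compile time for factory ops that omit the dtype argument, and the compiled function keeps returning that dtype after torch.set_default_dtype changes the default. This is a silent correctness issue: the compiled result has a different dtype from eager mode.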

🐛 Describe the bug

torch.compile with the inductor backend appears to capture or reuse the default floating-point dtype from compile time for a factory op whose dtype argument is omitted.

In the repro below, the function is compiled while the global default dtype is torch.float32. After compilation, the global default dtype is changed to torch.float64. Eager mode then correctly returns a torch.float64 tensor, but the compiled function still returns torch.float32.

This is a silent correctness issue because the compiled result has a different dtype from eager mode.

Reproduction

import os
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("TORCHINDUCTOR_DISABLE_PROGRESS", "1")

import torch
torch.set_num_threads(1)
torch._dynamo.reset()

# 1. Test multiple factory APIs that rely on the implicit global dtype
def test_factory_apis():
    return {
        "ones": torch.ones(2),
        "zeros": torch.zeros(2),
        "tensor": torch.tensor([1.0, 2.0]),
        "arange": torch.arange(2.0),
        "empty": torch.empty(2)
    }

compiled_test = torch.compile(test_factory_apis, backend="inductor")

# 2. Phase 1: Compile under float32 (Dynamo bakes in dtype and caches the graph)
torch.set_default_dtype(torch.float32)
_ = compiled_test()  # Warmup & compile

# 3. Phase 2: Change global dtype and compare Eager vs Compiled behaviors
torch.set_default_dtype(torch.float64)
eager_results = test_factory_apis()
compiled_results = compiled_test()

# 4. Print a markdown-friendly report
print(f"PyTorch Version: {torch.__version__}\n")
print("| API Name   | Eager (Expected) | Compiled (Actual) | Bug Present? |")
print("|------------|------------------|-------------------|--------------|")

for api in eager_results.keys():
    eager_dt = str(eager_results[api].dtype).replace("torch.", "")
    comp_dt = str(compiled_results[api].dtype).replace("torch.", "")
    has_bug = "❌ YES" if eager_dt != comp_dt else "✅ NO"
    
    print(f"| {api:<10} | {eager_dt:<16} | {comp_dt:<17} | {has_bug:<12} |")

Output:

PyTorch Version: 2.10.0+cpu

| API Name   | Eager (Expected) | Compiled (Actual) | Bug Present? |
|------------|------------------|-------------------|--------------|
| ones       | float64          | float32           | ❌ YES        |
| zeros      | float64          | float32           | ❌ YES        |
| tensor     | float64          | float32           | ❌ YES        |
| arange     | float64          | float32           | ❌ YES        |
| empty      | float64          | float32           | ❌ YES        |
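The printed report can also be expressed as a programmatic check, which is easier to reuse in a test suite. The helper below is a minimal sketch, not part of the original issue: `dtype_mismatches` is a hypothetical name, and it only assumes that each value exposes a `.dtype` attribute, as the tensors in `eager_results` and `compiled_results` do.

```python
from types import SimpleNamespace


def dtype_mismatches(eager: dict, compiled: dict) -> dict:
    """Return {name: (eager_dtype, compiled_dtype)} for entries whose dtypes differ."""
    return {
        name: (eager[name].dtype, compiled[name].dtype)
        for name in eager
        if eager[name].dtype != compiled[name].dtype
    }


# Stand-in objects (not real tensors) that mirror two rows of the table above.
demo_eager = {
    "ones": SimpleNamespace(dtype="float64"),
    "zeros": SimpleNamespace(dtype="float64"),
}
demo_compiled = {
    "ones": SimpleNamespace(dtype="float32"),
    "zeros": SimpleNamespace(dtype="float32"),
}
demo_report = dtype_mismatches(demo_eager, demo_compiled)
```

With the repro's own dictionaries, `assert not dtype_mismatches(eager_results, compiled_results)` would fail on the output shown above and pass once eager and compiled dtypes agree.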

Versions

PyTorch Version: 2.10.0+cpu

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @jataylo @azahed98
