pytorch - 💡(How to fix) Fix torch.compile and eager disagree on NaN/Inf behavior for the same float32 Conv2d input (CUDA) [5 comments, 4 participants]

GitHub stats
pytorch/pytorch#178133 · Fetched 2026-04-08 01:16:26
Comments: 5 · Participants: 4 · Timeline events: 15 · Reactions: 0
Timeline (top): commented ×5, labeled ×5, mentioned ×2, subscribed ×2



🐛 Describe the bug

I found a reproducible inconsistency between eager and torch.compile on CUDA for the same single Conv2d layer (model.block1.0), same input, and same weights.

  • modes: cuda+float32+eager vs cuda+float32+compile
  • input min/max: 9.96758e+28 / 3.40282e+38
  • weight min/max: -0.0641492 / 0.0641493
  • bias min/max: -0.0587277 / 0.0636346
  • eager output: has_nan=True, has_inf=True
  • compile output: has_nan=False, has_inf=False

The input values are very large (3.40282e+38 is the largest finite float32 value), but they are still valid float32 data. Even with such inputs, switching execution mode should not produce a semantic inconsistency in which one mode returns NaN/Inf and the other does not for the same operation and parameters.
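One plausible explanation, not confirmed in the thread, involves how the convolution sums its terms. Each output element of this 3-to-64 channel, 9×9 convolution sums 3·9·9 = 243 products. With inputs up to about 3.4e38 and weights up to about 0.064, a single product can reach roughly 2.2e37. A partial sum can therefore overflow float32 even when positive and negative terms would largely cancel in the exact result. Float32 addition is not associative, and the eager (cuDNN) path and the compiled path may accumulate in different orders or use different algorithms.

The scalar sketch below shows the same three terms producing Inf, a finite value, or NaN depending only on how they are grouped. The value 3.0e38 is illustrative and is not taken from the reproduction files.

```python
import torch

# Illustrative value close to float32's largest finite number (~3.40282e+38).
big = torch.tensor(3.0e38, dtype=torch.float32)

# Same terms (big, big, -big), two groupings.
sum_a = (big + big) - big   # big + big overflows to +inf first; inf - big stays inf
sum_b = big + (big - big)   # cancellation happens first; the result stays finite

# Once partial sums of opposite sign have both overflowed, inf - inf gives NaN.
sum_c = (big + big) - (big + big)

print(sum_a.item(), sum_b.item(), sum_c.item())
```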

import os
import torch
import torch.nn as nn

# Reproduction files (sd.pt, input.pt) are expected next to this script.
_DIR = os.path.dirname(os.path.abspath(__file__))
SD_PATH = os.path.join(_DIR, "sd.pt")
INPUT_PATH = os.path.join(_DIR, "input.pt")
BASELINE = "cuda+float32+eager"
SWITCH = "cuda+float32+compile"

def has_nan_inf(t: torch.Tensor):
    # Returns (contains NaN, contains Inf) for a tensor on any device.
    t = t.detach().float().cpu()
    return bool(torch.isnan(t).any().item()), bool(torch.isinf(t).any().item())

x = torch.load(INPUT_PATH, map_location="cpu").float()
sd = torch.load(SD_PATH, map_location="cpu")
# Accept plain, DataParallel ("module.") and torch.compile ("_orig_mod.") key prefixes.
w = next((sd[k] for k in ("model.block1.0.weight", "module.model.block1.0.weight", "_orig_mod.model.block1.0.weight") if k in sd), None)
b = next((sd[k] for k in ("model.block1.0.bias", "module.model.block1.0.bias", "_orig_mod.model.block1.0.bias") if k in sd), None)
if w is None or b is None:
    raise KeyError("Cannot find model.block1.0 weight/bias in sd.pt")

print(f"modes: {BASELINE} vs {SWITCH}")
print(f"input min={x.min().item():.6g} max={x.max().item():.6g}")
print(f"weight min={w.min().item():.6g} max={w.max().item():.6g}")
print(f"bias min={b.min().item():.6g} max={b.max().item():.6g}")

# Rebuild the single layer model.block1.0 and load its parameters.
conv = nn.Conv2d(3, 64, kernel_size=9, padding=4, bias=True).to("cuda", torch.float32).eval()
with torch.no_grad():
    conv.weight.copy_(w.float())
    conv.bias.copy_(b.float())
    x_cuda = x.to("cuda", torch.float32)
    y_eager = conv(x_cuda)                                   # eager execution
    y_compile = torch.compile(conv, dynamic=False)(x_cuda)   # compiled execution, same module and input

e_nan, e_inf = has_nan_inf(y_eager)
c_nan, c_inf = has_nan_inf(y_compile)
print(f"base(has_nan={e_nan},has_inf={e_inf}) switch(has_nan={c_nan},has_inf={c_inf})")

Output:

modes: cuda+float32+eager vs cuda+float32+compile
input min=9.96758e+28 max=3.40282e+38
weight min=-0.0641492 max=0.0641493
bias min=-0.0587277 max=0.0636346
base(has_nan=True,has_inf=True) switch(has_nan=False,has_inf=False)

The reproduction files are here: https://drive.google.com/file/d/1WJ10_8v5txv93M9FOFLYwpQhghwyEN_Y/view?usp=drive_link

Versions

PyTorch version: 2.10.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.10.19 | packaged by conda-forge | (main, Jan 26 2026, 23:45:08) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-90-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.6.20
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090

Nvidia driver version: 560.35.03
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 40 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Vendor ID: GenuineIntel
Model name: QEMU Virtual CPU version 2.5+
CPU family: 15
Model: 107
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 1
Stepping: 1
BogoMIPS: 4190.15
Flags: fpu de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx lm constant_tsc nopl xtopology cpuid tsc_known_freq pni ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c hypervisor lahf_lm abm cpuid_fault pti bmi1 avx2 bmi2 avx512f avx512dq avx512cd avx512bw avx512vl
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB (48 instances)
L1i cache: 1.5 MiB (48 instances)
L2 cache: 192 MiB (48 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-47
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Retpoline
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] onnxruntime-gpu==1.23.2
[pip3] optree==0.18.0
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.10.0+cu126
[pip3] torchaudio==2.11.0.dev20260127+cu126
[pip3] torchvision==0.25.0+cu126
[pip3] triton==3.6.0+git9844da95
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.6.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.6.80 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.0.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.7.77 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.1.2 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.4.2 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.85 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.6.77 pypi_0 pypi
[conda] optree 0.18.0 pypi_0 pypi
[conda] pytorch-triton 3.2.0+git4b3bb1f8 pypi_0 pypi
[conda] torch 2.10.0+cu126 pypi_0 pypi
[conda] torchaudio 2.11.0.dev20260127+cu126 pypi_0 pypi
[conda] torchvision 0.25.0+cu126 pypi_0 pypi
[conda] triton 3.6.0+git9844da95 pypi_0 pypi

cc @chauhang @penguinwu
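collect_env reported "cuDNN version: Could not collect", even though the library list includes nvidia-cudnn-cu12==9.10.2.21. Eager CUDA convolutions normally go through cuDNN, so it may help to record which cuDNN PyTorch actually loaded. This is a minimal sketch using standard torch.backends.cudnn queries. On a CPU-only build, the version is reported as None.

```python
import torch

# Report what the running PyTorch build actually uses, independent of collect_env.
print("torch:", torch.__version__)
print("CUDA used to build:", torch.version.cuda)
print("cuDNN available:", torch.backends.cudnn.is_available())
print("cuDNN version:", torch.backends.cudnn.version())  # None when cuDNN cannot be loaded
print("cuDNN enabled:", torch.backends.cudnn.enabled)
print("cuDNN benchmark:", torch.backends.cudnn.benchmark)  # affects which conv algorithm is picked
```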

extent analysis

Fix Plan

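The discrepancy is most likely float32 overflow rather than instability introduced by torch.compile: it is the eager output that contains NaN/Inf. The input reaches the largest finite float32 value, so partial sums inside the convolution can exceed the float32 range. Whether they do can depend on the accumulation order and algorithm each execution path uses. To work around it, we can try the following:

  • Clip input values: Clip the input values to a range where the convolution cannot overflow.
  • Use a different convolution implementation: Try calling torch.nn.functional.conv2d directly to see if the issue persists. Note that nn.Conv2d already calls F.conv2d, so this mainly rules out the module wrapper. Exercising a genuinely different kernel requires, for example, disabling cuDNN.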
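nn.Conv2d already dispatches to torch.nn.functional.conv2d, so calling F.conv2d directly exercises the same kernels. One way to get a genuinely different CUDA implementation is to disable cuDNN for a single call with the torch.backends.cudnn.flags context manager. This is sketched here as an assumption, not as a workaround confirmed in the thread. conv_without_cudnn is a hypothetical helper name, and padding=4 matches the reproduction's layer.

```python
import torch
import torch.nn.functional as F

def conv_without_cudnn(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run the repro's convolution with cuDNN temporarily disabled."""
    # With cuDNN disabled, CUDA tensors fall back to PyTorch's native convolution kernels;
    # the previous cuDNN settings are restored when the context exits.
    with torch.no_grad(), torch.backends.cudnn.flags(enabled=False):
        return F.conv2d(x, weight, bias, padding=4)
```

For example, conv_without_cudnn(x_cuda, conv.weight, conv.bias) can be checked with the has_nan_inf helper from the reproduction script and compared with y_eager and y_compile. If the non-cuDNN result also overflows, that suggests the overflow comes from the data rather than from one specific kernel.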

Here are the concrete steps:

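  1. Clip input values:
    • Use torch.clamp to clip the input values to a range where the convolution cannot overflow.
    • Clipping changes the data, so the outputs correspond to the clipped input rather than the original one.
    • Example code:

```python
# Reuses `conv` and `x` from the reproduction script above.
x_clipped = torch.clamp(x, min=-1e10, max=1e10)
y_eager = conv(x_clipped.to("cuda", torch.float32))
y_compile = torch.compile(conv, dynamic=False)(x_clipped.to("cuda", torch.float32))
```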

  2. Use a different convolution implementation:
    • Replace nn.Conv2d with torch.nn.functional.conv2d.
    • Example code:

```python
import torch.nn.functional as F

# Reuses `conv` and `x_cuda` from the reproduction script above.
weight = conv.weight
bias = conv.bias
with torch.no_grad():
    y_eager = F.conv2d(x_cuda, weight, bias, padding=4)
    y_compile = torch.compile(lambda t: F.conv2d(t, weight, bias, padding=4), dynamic=False)(x_cuda)
```
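The ±1e10 bound in step 1 is arbitrary. A bound can instead be derived from the layer's weights. If every |x| ≤ M, then each output channel o satisfies |y_o| ≤ M·Σ|w_o| + |b_o|, and every partial sum of a plain accumulation obeys the same bound. The sketch below uses a hypothetical helper, safe_input_bound, to pick M from that inequality with a safety margin.

It assumes direct accumulation. Transform-based algorithms such as Winograd or FFT can form larger intermediates, and the default margin of 1e-3 is an arbitrary choice rather than a proven bound for them.

```python
from typing import Optional

import torch
import torch.nn.functional as F

def safe_input_bound(weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
                     margin: float = 1e-3) -> float:
    """Hypothetical helper: largest |x| for which a directly summed conv output stays finite.

    For output channel o: |y_o| <= M * sum(|w[o]|) + |b_o| whenever every |x| <= M,
    and every partial sum obeys the same bound. `margin` (arbitrary) leaves headroom.
    """
    fmax = torch.finfo(torch.float32).max
    abs_w_sum = weight.detach().abs().sum(dim=(1, 2, 3)).max().item()  # worst output channel
    b_max = bias.detach().abs().max().item() if bias is not None else 0.0
    return margin * (fmax - b_max) / abs_w_sum
```

For the reproduction, m = safe_input_bound(conv.weight, conv.bias) followed by x_cuda.clamp(-m, m) replaces the fixed ±1e10. As with any clipping, the outputs then describe the clipped input, not the original one.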

Verification

To verify the fix, run the modified code and check that y_eager and y_compile agree and that neither contains NaN or Inf values.
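Agreement between y_eager and y_compile does not by itself show which result is right. A hedged way to check is to recompute the layer in float64, whose exponent range is far larger than float32's, and compare each output against it. compare_to_float64 is a hypothetical helper. It assumes padding=4 as in the reproduction, and it computes the reference on the CPU, which may be slow for large inputs.

```python
import torch
import torch.nn.functional as F

def compare_to_float64(y: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
                       bias: torch.Tensor, padding: int = 4) -> dict:
    """Hypothetical helper: compare a float32 conv output with a float64 CPU reference."""
    ref = F.conv2d(x.detach().double().cpu(), weight.detach().double().cpu(),
                   bias.detach().double().cpu(), padding=padding)
    y = y.detach().double().cpu()
    fmax = torch.finfo(torch.float32).max
    ref_in_range = ref.abs() <= fmax             # exact value fits in float32's finite range
    y_finite = torch.isfinite(y)
    both_finite = ref_in_range & y_finite
    err = (y - ref).abs()[both_finite]
    scale = ref.abs()[both_finite]
    normalized_err = (err.max() / scale.max().clamp_min(1e-300)).item() if err.numel() else 0.0
    return {
        "ref_out_of_float32_range": int((~ref_in_range).sum()),
        "nonfinite_where_ref_in_range": int((ref_in_range & ~y_finite).sum()),
        "max_err_over_max_ref": normalized_err,
    }
```

For example, call compare_to_float64(y_eager, x_cuda, conv.weight, conv.bias), then make the same call with y_compile. Where ref_out_of_float32_range is nonzero, the exact result is beyond float32's largest finite value, so an Inf at those positions is an expected float32 outcome rather than an error.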

Extra Tips

  • Test the fix with different input magnitudes and models, since whether the convolution overflows depends on both the input range and the weights.
  • If the issue persists, try a different PyTorch or CUDA version to check whether it is version-specific.
  • This behavior is already reported upstream as pytorch/pytorch#178133. If the workarounds above do not help, add new findings there rather than filing a duplicate.
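The first tip can be made concrete by sweeping the input's magnitude and recording the first scale at which each execution mode produces NaN or Inf. first_nonfinite_scale is a hypothetical helper, and the scale grid in the usage note is an assumption.

```python
from typing import Callable, Iterable, Optional

import torch

def first_nonfinite_scale(fn: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor,
                          scales: Iterable[float]) -> Optional[float]:
    """Hypothetical helper: first scale s (in the given order) for which fn(x * s) has NaN or Inf."""
    with torch.no_grad():
        for s in scales:
            if not torch.isfinite(fn(x * s)).all():
                return s
    return None  # every tested scale produced a finite output
```

For example, set x_unit = x_cuda / x_cuda.abs().max() and scales = [10.0 ** k for k in range(30, 39)]. Then compare first_nonfinite_scale(conv, x_unit, scales) with first_nonfinite_scale(torch.compile(conv, dynamic=False), x_unit, scales). Different thresholds for the two modes would show where their overflow behavior starts to diverge.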

