pytorch - 💡 (How to fix) [inductor] Silent incorrect gradients with interpolate, conv2d, max_pool2d and torch.where

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…

Code Example

import torch
import torch.nn.functional as F


def inner(x, weight, bias):
    """Upsample, crop, convolve and max-pool ``x``.

    Returns a tuple ``(conv, pool, threshold)``:

    * ``conv`` -- ``F.conv2d`` (padding 1) applied to a bilinearly upsampled
      (scale 1.5 x 1.25, ``align_corners=False``) and cropped copy of ``x``;
    * ``pool`` -- 3x3 max pooling of ``conv`` with stride 1 and padding 1;
    * ``threshold`` -- mean of ``pool`` over dims 2 and 3, kept as singleton
      dims so it broadcasts back against ``conv``.
    """
    upsampled = F.interpolate(
        x, scale_factor=(1.5, 1.25), mode="bilinear", align_corners=False
    )
    # Keep rows 1..6 and columns 1..7 of the upsampled spatial map.
    cropped = upsampled[:, :, 1:7, 1:8]

    conv = F.conv2d(cropped, weight, bias, padding=1)
    # Positional order: kernel_size=3, stride=1, padding=1.
    pooled = F.max_pool2d(conv, 3, 1, 1)
    return conv, pooled, pooled.mean(dim=(2, 3), keepdim=True)


def make_mask(x, weight, bias):
    """Return a boolean mask marking where the conv output exceeds its threshold.

    Runs ``inner`` and compares ``conv`` element-wise against the broadcast
    ``threshold``; the pooled tensor is not used here.
    """
    outputs = inner(x, weight, bias)
    conv, threshold = outputs[0], outputs[2]
    return torch.gt(conv, threshold)


def fn(x, weight, bias, mask):
    """Select, element-wise, between the conv output and its max-pooled version.

    Where ``mask`` is True the result takes the value from ``conv``;
    elsewhere it takes the value from ``pool`` (both produced by ``inner``).
    """
    conv_out, pooled, _ = inner(x, weight, bias)
    selected = torch.where(mask, conv_out, pooled)
    return selected


def clone_inputs(xs):
    """Return fresh leaf copies of ``xs`` that require gradients.

    Each tensor is detached from any existing autograd graph and cloned, so
    gradients accumulated on the copies never touch the originals.
    """
    fresh = []
    for tensor in xs:
        leaf = tensor.detach().clone()
        leaf.requires_grad_(True)
        fresh.append(leaf)
    return fresh


def run(f, xs, mask):
    """Call ``f(*xs, mask)``, backpropagate the sum of its output, collect grads.

    Returns ``(output, grads)``: ``output`` is the detached result of ``f``
    and ``grads`` is a list holding a cloned copy of each ``.grad`` in ``xs``
    order.
    """
    result = f(*xs, mask)
    result.sum().backward()
    detached = result.detach()
    grads = [t.grad.detach().clone() for t in xs]
    return detached, grads


def max_diff(a, b):
    """Return the largest absolute element-wise difference between ``a`` and ``b``.

    ``a`` and ``b`` must be broadcast-compatible tensors. The result is a
    Python number via ``.item()``. When the difference has no elements,
    ``0.0`` is returned instead of raising, because ``Tensor.max()`` without
    a ``dim`` rejects empty inputs.
    """
    diff = (a - b).abs()
    if diff.numel() == 0:
        return 0.0
    return diff.max().item()


def main(device="cpu"):
    """Reproduce the eager-vs-inductor gradient mismatch and print a report.

    Steps:

    1. Build seeded inputs (``x``, ``weight``, ``bias``) on ``device``.
    2. Compute the boolean mask with both eager and compiled ``make_mask`` and
       report how many elements disagree.
    3. Use the *eager* mask for both runs, so the mask cannot explain any
       difference. Run ``fn`` eagerly and compiled, backpropagate
       ``out.sum()``, then print the forward difference and, per input, the
       gradient differences.

    Args:
        device: Device on which the inputs are created. Defaults to
            ``"cpu"``, the value previously hard-coded. Pass e.g. ``"cuda"``
            to exercise the GPU code path.
    """
    print("torch:", torch.__version__)
    print("cuda:", torch.version.cuda)
    # torch.version.cuda reports the CUDA version torch was built with, not
    # where the inputs live, so print the device explicitly.
    print("device:", device)

    torch.manual_seed(0)

    # Every run below works on fresh leaf clones of these seeded tensors.
    inputs = [
        torch.randn(2, 3, 6, 7, device=device, requires_grad=True),
        torch.randn(4, 3, 3, 3, device=device, requires_grad=True),
        torch.zeros(4, device=device, requires_grad=True),
    ]

    inputs_eager = clone_inputs(inputs)
    inputs_comp = clone_inputs(inputs)

    # Reset Dynamo so the compile below starts from a clean cache.
    torch._dynamo.reset()
    compiled_mask = torch.compile(make_mask, backend="inductor")

    mask_eager = make_mask(*inputs_eager).detach()
    mask_comp = compiled_mask(*inputs_comp).detach()

    print("mask mismatches:", (mask_eager != mask_comp).sum().item(), "/", mask_eager.numel())

    # Feed the same (eager) mask to both runs so any gradient difference
    # comes from the compiled fn itself, not from a differing mask.
    fixed_mask = mask_eager

    inputs_eager = clone_inputs(inputs)
    inputs_comp = clone_inputs(inputs)

    out_eager, grads_eager = run(fn, inputs_eager, fixed_mask)

    torch._dynamo.reset()
    compiled_fn = torch.compile(fn, backend="inductor")
    out_comp, grads_comp = run(compiled_fn, inputs_comp, fixed_mask)

    print("forward max diff:", max_diff(out_eager, out_comp))

    # Per input (x, weight, bias), report the worst absolute gradient error
    # and how many elements are off by more than 1e-3.
    for i, (ge, gc) in enumerate(zip(grads_eager, grads_comp)):
        diff = max_diff(ge, gc)
        mismatches = ((ge - gc).abs() > 1e-3).sum().item()
        print(f"grad[{i}] max diff:", diff)
        print(f"grad[{i}] mismatches > 1e-3:", mismatches, "/", ge.numel())

# Script entry point: run the reproducer when the file is executed directly
# (e.g. ``python bug.py``), but not when it is imported.
if __name__ == "__main__":
    main()

---

(torch-nightly) xyt19@Oasis:/tmp$ python bug.py
torch: 2.13.0.dev20260521+cu130
cuda: 13.0
mask mismatches: 0 / 336
forward max diff: 1.9073486328125e-06
grad[0] max diff: 54.59376907348633
grad[0] mismatches > 1e-3: 210 / 252
grad[1] max diff: 43.85358810424805
grad[1] mismatches > 1e-3: 108 / 108
grad[2] max diff: 77.0
grad[2] mismatches > 1e-3: 4 / 4
RAW_BUFFERClick to expand / collapse

🐛 Describe the bug

When a function combining F.interpolate, F.conv2d, F.max_pool2d, and torch.where is compiled with torch.compile(backend="inductor"), the forward pass matches eager mode, but the gradients computed in the backward pass are badly wrong. The maximum absolute differences from eager reach about 54.6 for the input, 43.9 for the weight, and 77.0 for the bias.

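This is a silent correctness issue for training. No error is raised and the forward output matches eager mode (max diff ~1.9e-06), yet the gradients for the input, weight, and bias differ by up to 77.0. The mask cannot be the cause: it is computed once in eager mode and passed unchanged to both the eager and the compiled runs. The eager and compiled masks also agree on all 336 elements.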

Reproducer

import torch
import torch.nn.functional as F


def inner(x, weight, bias):
    """Upsample, crop, convolve and max-pool ``x``.

    Returns a tuple ``(conv, pool, threshold)``:

    * ``conv`` -- ``F.conv2d`` (padding 1) applied to a bilinearly upsampled
      (scale 1.5 x 1.25, ``align_corners=False``) and cropped copy of ``x``;
    * ``pool`` -- 3x3 max pooling of ``conv`` with stride 1 and padding 1;
    * ``threshold`` -- mean of ``pool`` over dims 2 and 3, kept as singleton
      dims so it broadcasts back against ``conv``.
    """
    upsampled = F.interpolate(
        x, scale_factor=(1.5, 1.25), mode="bilinear", align_corners=False
    )
    # Keep rows 1..6 and columns 1..7 of the upsampled spatial map.
    cropped = upsampled[:, :, 1:7, 1:8]

    conv = F.conv2d(cropped, weight, bias, padding=1)
    # Positional order: kernel_size=3, stride=1, padding=1.
    pooled = F.max_pool2d(conv, 3, 1, 1)
    return conv, pooled, pooled.mean(dim=(2, 3), keepdim=True)


def make_mask(x, weight, bias):
    """Return a boolean mask marking where the conv output exceeds its threshold.

    Runs ``inner`` and compares ``conv`` element-wise against the broadcast
    ``threshold``; the pooled tensor is not used here.
    """
    outputs = inner(x, weight, bias)
    conv, threshold = outputs[0], outputs[2]
    return torch.gt(conv, threshold)


def fn(x, weight, bias, mask):
    """Select, element-wise, between the conv output and its max-pooled version.

    Where ``mask`` is True the result takes the value from ``conv``;
    elsewhere it takes the value from ``pool`` (both produced by ``inner``).
    """
    conv_out, pooled, _ = inner(x, weight, bias)
    selected = torch.where(mask, conv_out, pooled)
    return selected


def clone_inputs(xs):
    """Return fresh leaf copies of ``xs`` that require gradients.

    Each tensor is detached from any existing autograd graph and cloned, so
    gradients accumulated on the copies never touch the originals.
    """
    fresh = []
    for tensor in xs:
        leaf = tensor.detach().clone()
        leaf.requires_grad_(True)
        fresh.append(leaf)
    return fresh


def run(f, xs, mask):
    """Call ``f(*xs, mask)``, backpropagate the sum of its output, collect grads.

    Returns ``(output, grads)``: ``output`` is the detached result of ``f``
    and ``grads`` is a list holding a cloned copy of each ``.grad`` in ``xs``
    order.
    """
    result = f(*xs, mask)
    result.sum().backward()
    detached = result.detach()
    grads = [t.grad.detach().clone() for t in xs]
    return detached, grads


def max_diff(a, b):
    """Return the largest absolute element-wise difference between ``a`` and ``b``.

    ``a`` and ``b`` must be broadcast-compatible tensors. The result is a
    Python number via ``.item()``. When the difference has no elements,
    ``0.0`` is returned instead of raising, because ``Tensor.max()`` without
    a ``dim`` rejects empty inputs.
    """
    diff = (a - b).abs()
    if diff.numel() == 0:
        return 0.0
    return diff.max().item()


def main(device="cpu"):
    """Reproduce the eager-vs-inductor gradient mismatch and print a report.

    Steps:

    1. Build seeded inputs (``x``, ``weight``, ``bias``) on ``device``.
    2. Compute the boolean mask with both eager and compiled ``make_mask`` and
       report how many elements disagree.
    3. Use the *eager* mask for both runs, so the mask cannot explain any
       difference. Run ``fn`` eagerly and compiled, backpropagate
       ``out.sum()``, then print the forward difference and, per input, the
       gradient differences.

    Args:
        device: Device on which the inputs are created. Defaults to
            ``"cpu"``, the value previously hard-coded. Pass e.g. ``"cuda"``
            to exercise the GPU code path.
    """
    print("torch:", torch.__version__)
    print("cuda:", torch.version.cuda)
    # torch.version.cuda reports the CUDA version torch was built with, not
    # where the inputs live, so print the device explicitly.
    print("device:", device)

    torch.manual_seed(0)

    # Every run below works on fresh leaf clones of these seeded tensors.
    inputs = [
        torch.randn(2, 3, 6, 7, device=device, requires_grad=True),
        torch.randn(4, 3, 3, 3, device=device, requires_grad=True),
        torch.zeros(4, device=device, requires_grad=True),
    ]

    inputs_eager = clone_inputs(inputs)
    inputs_comp = clone_inputs(inputs)

    # Reset Dynamo so the compile below starts from a clean cache.
    torch._dynamo.reset()
    compiled_mask = torch.compile(make_mask, backend="inductor")

    mask_eager = make_mask(*inputs_eager).detach()
    mask_comp = compiled_mask(*inputs_comp).detach()

    print("mask mismatches:", (mask_eager != mask_comp).sum().item(), "/", mask_eager.numel())

    # Feed the same (eager) mask to both runs so any gradient difference
    # comes from the compiled fn itself, not from a differing mask.
    fixed_mask = mask_eager

    inputs_eager = clone_inputs(inputs)
    inputs_comp = clone_inputs(inputs)

    out_eager, grads_eager = run(fn, inputs_eager, fixed_mask)

    torch._dynamo.reset()
    compiled_fn = torch.compile(fn, backend="inductor")
    out_comp, grads_comp = run(compiled_fn, inputs_comp, fixed_mask)

    print("forward max diff:", max_diff(out_eager, out_comp))

    # Per input (x, weight, bias), report the worst absolute gradient error
    # and how many elements are off by more than 1e-3.
    for i, (ge, gc) in enumerate(zip(grads_eager, grads_comp)):
        diff = max_diff(ge, gc)
        mismatches = ((ge - gc).abs() > 1e-3).sum().item()
        print(f"grad[{i}] max diff:", diff)
        print(f"grad[{i}] mismatches > 1e-3:", mismatches, "/", ge.numel())


# Script entry point: run the reproducer when the file is executed directly
# (e.g. ``python bug.py``), but not when it is imported.
if __name__ == "__main__":
    main()

Actual Output

(torch-nightly) xyt19@Oasis:/tmp$ python bug.py
torch: 2.13.0.dev20260521+cu130
cuda: 13.0
mask mismatches: 0 / 336
forward max diff: 1.9073486328125e-06
grad[0] max diff: 54.59376907348633
grad[0] mismatches > 1e-3: 210 / 252
grad[1] max diff: 43.85358810424805
grad[1] mismatches > 1e-3: 108 / 108
grad[2] max diff: 77.0
grad[2] mismatches > 1e-3: 4 / 4

Expected Behavior

The compiled function should produce the same gradients as eager mode, up to floating-point tolerance (e.g., a max absolute difference below 1e-4).

Versions

PyTorch version: 2.13.0.dev20260521+cu130 Is debug build: False CUDA used to build PyTorch: 13.0 ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64) GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0 Clang version: 18.1.3 (1ubuntu1) CMake version: version 3.28.3 Libc version: glibc-2.39

Python version: 3.10.20 (main, Mar 11 2026, 17:46:40) [GCC 14.3.0] (64-bit runtime) Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39 Is CUDA available: True CUDA runtime version: 12.0.140 Nvidia driver version: 596.49 cuDNN version: Probably one of the following: /usr/lib/x86_64-linux-gnu/libcudnn.so.9.21.1 /usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.21.1 /usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.21.1 /usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.21.1 /usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.21.1 /usr/lib/x86_64-linux-gnu/libcudnn_engines_tensor_ir.so.9.21.1 /usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.21.1 /usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.21.1 /usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.21.1 Is XPU available: False HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True Caching allocator config: N/A Versions of relevant libraries: [pip3] numpy==2.2.6 [pip3] nvidia-cublas==13.1.1.3 [pip3] nvidia-cuda-cupti==13.0.85 [pip3] nvidia-cuda-nvrtc==13.0.88 [pip3] nvidia-cuda-runtime==13.0.96 [pip3] nvidia-cudnn-cu13==9.20.0.48 [pip3] nvidia-cufft==12.0.0.61 [pip3] nvidia-curand==10.4.0.35 [pip3] nvidia-cusolver==12.0.4.66 [pip3] nvidia-cusparse==12.6.3.3 [pip3] nvidia-cusparselt-cu13==0.8.1 [pip3] nvidia-nccl-cu13==2.29.7 [pip3] nvidia-nvjitlink==13.0.88 [pip3] nvidia-nvtx==13.0.85 [pip3] torch==2.13.0.dev20260521+cu130 [pip3] torchaudio==2.11.0.dev20260525+cu130 [pip3] torchvision==0.28.0.dev20260525+cu130 [pip3] triton==3.7.0+git88b227e2 [conda] numpy 2.2.6 pypi_0 pypi [conda] nvidia-cublas 13.1.1.3 pypi_0 pypi [conda] nvidia-cuda-cupti 13.0.85 pypi_0 pypi [conda] nvidia-cuda-nvrtc 13.0.88 pypi_0 pypi [conda] nvidia-cuda-runtime 13.0.96 pypi_0 pypi [conda] nvidia-cudnn-cu13 9.20.0.48 pypi_0 pypi [conda] nvidia-cufft 12.0.0.61 pypi_0 pypi [conda] nvidia-curand 10.4.0.35 pypi_0 pypi [conda] nvidia-cusolver 12.0.4.66 pypi_0 pypi [conda] nvidia-cusparse 12.6.3.3 pypi_0 pypi [conda] nvidia-cusparselt-cu13 0.8.1 pypi_0 pypi [conda] nvidia-nccl-cu13 2.29.7 pypi_0 pypi [conda] nvidia-nvjitlink 13.0.88 pypi_0 pypi [conda] nvidia-nvtx 13.0.85 pypi_0 pypi [conda] torch 2.13.0.dev20260521+cu130 pypi_0 pypi [conda] torchaudio 2.11.0.dev20260525+cu130 pypi_0 pypi [conda] torchvision 0.28.0.dev20260525+cu130 pypi_0 pypi [conda] triton 3.7.0+git88b227e2 pypi_0 pypi

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING