pytorch: Fix ONNX export mismatch for `F.interpolate(size=(1, 1), mode="bilinear", align_corners=False)`

pytorch/pytorch#183523 (fetched 2026-05-14 03:28:33)

I found a mismatch between PyTorch eager execution and the ONNX model exported by torch.onnx.export for a small F.interpolate case.

The model only applies:

torch.nn.functional.interpolate(
    x,
    size=(1, 1),
    mode="bilinear",
    align_corners=False,
)

For an input of shape (1, 3, 10, 10), PyTorch eager matches a manual half_pixel coordinate transformation. However, the exported ONNX Resize node uses:

coordinate_transformation_mode = "pytorch_half_pixel"

This causes the exported ONNX model to produce a different result from PyTorch eager. If I patch only the ONNX Resize attribute from "pytorch_half_pixel" to "half_pixel", the ONNXRuntime output matches PyTorch eager exactly.

Root Cause

torch.onnx.export lowers this F.interpolate call to an ONNX Resize node with coordinate_transformation_mode = "pytorch_half_pixel". That mode follows the half_pixel formula only when the output length along an axis is greater than 1; when the output length is 1, it uses source coordinate 0. PyTorch eager, in contrast, matches a manual half_pixel transformation even for size=(1, 1).

For an input of shape (1, 3, 10, 10), half_pixel gives a source coordinate of (0 + 0.5) * (10 / 1) - 0.5 = 4.5 on each spatial axis, so eager interpolates around the center of the input while the exported model samples at coordinate 0. This is why the exported ONNX model produces a different result from PyTorch eager.
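To make the difference concrete, here is a minimal sketch of the two coordinate transforms along one axis. It follows the same formulas as source_index in the reproducer; the function name `source_coord` is introduced here for illustration only.

```python
def source_coord(out_i, in_size, out_size, mode):
    """Map an output index to a (fractional) input coordinate along one axis."""
    half_pixel = (out_i + 0.5) * (in_size / out_size) - 0.5
    if mode == "half_pixel":
        return half_pixel
    if mode == "pytorch_half_pixel":
        # Special case: an output axis of length 1 maps to coordinate 0.
        return half_pixel if out_size > 1 else 0.0
    raise ValueError(f"unsupported mode: {mode}")


# The case from this issue: 10 input samples, 1 output sample.
print(source_coord(0, 10, 1, "half_pixel"))          # 4.5
print(source_coord(0, 10, 1, "pytorch_half_pixel"))  # 0.0

# With more than one output sample the two modes agree.
print(source_coord(0, 10, 5, "half_pixel"))          # 0.5
print(source_coord(0, 10, 5, "pytorch_half_pixel"))  # 0.5
```

Under these formulas, the modes diverge only on the length-1 output axis, which is exactly the size=(1, 1) case reported here.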

Fix / Workaround

As a workaround, rewrite the coordinate_transformation_mode attribute of the exported Resize nodes from "pytorch_half_pixel" to "half_pixel". With only that attribute changed, the ONNXRuntime output matches PyTorch eager exactly.

def patch_resize_coord_mode(model, new_mode):
    # Work on a copy so the originally exported model stays untouched.
    patched = copy.deepcopy(model)
    for node in patched.graph.node:
        if node.op_type == "Resize":
            # Keep every other attribute, then append the new coordinate mode.
            kept = [a for a in node.attribute if a.name != "coordinate_transformation_mode"]
            del node.attribute[:]
            node.attribute.extend(kept)
            node.attribute.append(helper.make_attribute("coordinate_transformation_mode", new_mode))
    return patched

# As in the reproducer's main(): load the export, patch it, and save the result.
# onnx_path and patched_path are the temporary file paths created there.
exported = onnx.load(onnx_path)
patched = patch_resize_coord_mode(exported, "half_pixel")
onnx.save(patched, patched_path)
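Because the two modes differ only when an output axis has length 1, the patch should change results only for Resize nodes with such an axis. The sketch below is a hypothetical helper (`axes_affected` is not part of ONNX or PyTorch) that lists the axes where the modes disagree for given input and output spatial shapes. It assumes the formulas from the reproducer's source_index.

```python
def axes_affected(in_shape, out_shape):
    """Return indices of axes where half_pixel and pytorch_half_pixel differ."""
    affected = []
    for axis, (in_size, out_size) in enumerate(zip(in_shape, out_shape)):
        if out_size != 1:
            continue  # both modes use the same formula when out_size > 1
        # half_pixel coordinate of the single output sample on this axis
        half_pixel = 0.5 * in_size - 0.5
        if half_pixel != 0.0:  # pytorch_half_pixel would use 0.0 here
            affected.append(axis)
    return affected


print(axes_affected((10, 10), (1, 1)))  # [0, 1]
print(axes_affected((10, 10), (5, 1)))  # [1]
print(axes_affected((1, 10), (1, 5)))   # []
```

An empty result suggests that switching the attribute leaves that node's sampling coordinates unchanged, though this sketch only covers the coordinate formula, not other Resize attributes.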

Code Example

import copy
import json
import platform
import tempfile

import numpy as np
import torch
import torch.nn.functional as F
import onnx
from onnx import helper

SEED = 0
RTOL = 1e-5
ATOL = 1e-6
OPSET = 18


class Model(torch.nn.Module):
    def forward(self, x):
        return F.interpolate(
            x,
            size=(1, 1),
            mode="bilinear",
            align_corners=False,
        )


def compare(a, b):
    if isinstance(a, torch.Tensor):
        a = a.detach().cpu().numpy()
    if isinstance(b, torch.Tensor):
        b = b.detach().cpu().numpy()

    a = np.asarray(a)
    b = np.asarray(b)

    abs_diff = np.abs(a.astype(np.float64) - b.astype(np.float64))
    denom = np.maximum(np.maximum(np.abs(a), np.abs(b)), 1e-12)
    rel_diff = abs_diff / denom

    return {
        "allclose": bool(np.allclose(a, b, rtol=RTOL, atol=ATOL, equal_nan=False)),
        "shape_a": list(a.shape),
        "shape_b": list(b.shape),
        "dtype_a": str(a.dtype),
        "dtype_b": str(b.dtype),
        "max_abs_diff": float(np.max(abs_diff)) if abs_diff.size else 0.0,
        "max_rel_diff": float(np.max(rel_diff)) if rel_diff.size else 0.0,
        "a_preview": [float(v) for v in a.reshape(-1)[:8]],
        "b_preview": [float(v) for v in b.reshape(-1)[:8]],
    }


def source_index(out_i, in_size, out_size, mode):
    if mode == "half_pixel":
        return (out_i + 0.5) * (in_size / out_size) - 0.5

    if mode == "pytorch_half_pixel":
        # ONNX's pytorch_half_pixel mode has this special case.
        # For output size 1, it uses source coordinate 0.
        if out_size > 1:
            return (out_i + 0.5) * (in_size / out_size) - 0.5
        return 0.0

    raise ValueError(f"unsupported mode: {mode}")


def manual_bilinear_nchw(x, out_h, out_w, mode):
    x = np.asarray(x, dtype=np.float32)
    n, c, in_h, in_w = x.shape
    y = np.empty((n, c, out_h, out_w), dtype=np.float32)

    for oh in range(out_h):
        ih = source_index(oh, in_h, out_h, mode)
        ih = min(max(ih, 0.0), float(in_h - 1))
        h0 = int(np.floor(ih))
        h1 = min(h0 + 1, in_h - 1)
        wh = ih - h0

        for ow in range(out_w):
            iw = source_index(ow, in_w, out_w, mode)
            iw = min(max(iw, 0.0), float(in_w - 1))
            w0 = int(np.floor(iw))
            w1 = min(w0 + 1, in_w - 1)
            ww = iw - w0

            v00 = x[:, :, h0, w0]
            v01 = x[:, :, h0, w1]
            v10 = x[:, :, h1, w0]
            v11 = x[:, :, h1, w1]

            top = v00 * (1.0 - ww) + v01 * ww
            bot = v10 * (1.0 - ww) + v11 * ww
            y[:, :, oh, ow] = top * (1.0 - wh) + bot * wh

    return y


def get_resize_attrs(model):
    attrs_list = []
    for node in model.graph.node:
        if node.op_type == "Resize":
            attrs = {}
            for a in node.attribute:
                v = helper.get_attribute_value(a)
                if isinstance(v, bytes):
                    v = v.decode("utf-8")
                attrs[a.name] = v

            attrs_list.append({
                "name": node.name,
                "op_type": node.op_type,
                "inputs": list(node.input),
                "outputs": list(node.output),
                "attrs": attrs,
            })
    return attrs_list


def patch_resize_coord_mode(model, new_mode):
    patched = copy.deepcopy(model)
    for node in patched.graph.node:
        if node.op_type == "Resize":
            kept = [a for a in node.attribute if a.name != "coordinate_transformation_mode"]
            del node.attribute[:]
            node.attribute.extend(kept)
            node.attribute.append(helper.make_attribute("coordinate_transformation_mode", new_mode))
    return patched


def run_ort(model_path, x):
    import onnxruntime as ort

    sess = ort.InferenceSession(
        model_path,
        providers=["CPUExecutionProvider"],
    )
    return sess.run(None, {sess.get_inputs()[0].name: x.detach().cpu().numpy()})[0]


def main():
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    model = Model().eval()
    x = torch.rand((1, 3, 10, 10), dtype=torch.float32)

    with torch.no_grad():
        torch_out = model(x)

    x_np = x.detach().cpu().numpy()

    manual_half_pixel = manual_bilinear_nchw(
        x_np,
        out_h=1,
        out_w=1,
        mode="half_pixel",
    )

    manual_pytorch_half_pixel = manual_bilinear_nchw(
        x_np,
        out_h=1,
        out_w=1,
        mode="pytorch_half_pixel",
    )

    with tempfile.TemporaryDirectory(prefix="torch_onnx_resize_issue_") as tmp:
        onnx_path = f"{tmp}/model.onnx"
        patched_path = f"{tmp}/model_half_pixel.onnx"

        torch.onnx.export(
            model,
            (x,),
            onnx_path,
            opset_version=OPSET,
            input_names=["input"],
            output_names=["output"],
            do_constant_folding=True,
        )

        exported = onnx.load(onnx_path)
        onnx.checker.check_model(exported)

        resize_attrs = get_resize_attrs(exported)

        patched = patch_resize_coord_mode(exported, "half_pixel")
        onnx.save(patched, patched_path)
        onnx.checker.check_model(onnx.load(patched_path))

        ort_exported = run_ort(onnx_path, x)
        ort_patched = run_ort(patched_path, x)

    print("Python:", platform.python_version())
    print("PyTorch:", torch.__version__)
    print("ONNX:", onnx.__version__)
    print("ONNX opset:", OPSET)
    print()

    print("PyTorch eager:", torch_out.detach().cpu().numpy().reshape(-1).tolist())
    print("manual half_pixel:", manual_half_pixel.reshape(-1).tolist())
    print("manual pytorch_half_pixel:", manual_pytorch_half_pixel.reshape(-1).tolist())
    print("exported Resize attrs:", json.dumps(resize_attrs, indent=2, sort_keys=True))
    print()

    print("torch_vs_manual_half_pixel:")
    print(json.dumps(compare(torch_out, manual_half_pixel), indent=2, sort_keys=True))
    print()

    print("torch_vs_manual_pytorch_half_pixel:")
    print(json.dumps(compare(torch_out, manual_pytorch_half_pixel), indent=2, sort_keys=True))
    print()

    print("torch_vs_ort_exported:")
    print(json.dumps(compare(torch_out, ort_exported), indent=2, sort_keys=True))
    print()

    print("torch_vs_ort_patched_half_pixel:")
    print(json.dumps(compare(torch_out, ort_patched), indent=2, sort_keys=True))


if __name__ == "__main__":
    main()

Actual behavior

PyTorch eager output:

[0.35544848442077637, 0.275476336479187, 0.32302480936050415]

Manual half_pixel output:

[0.35544848442077637, 0.2754763066768646, 0.32302480936050415]

Manual pytorch_half_pixel output:

[0.49625658988952637, 0.9703746438026428, 0.5210340619087219]

The PyTorch eager result matches manual half_pixel:

torch_vs_manual_half_pixel:
  allclose: true
  max_abs_diff: 2.9802322387695312e-08
  max_rel_diff: 1.0818469117381688e-07

It does not match manual pytorch_half_pixel:

torch_vs_manual_pytorch_half_pixel:
  allclose: false
  max_abs_diff: 0.6948983073234558
  max_rel_diff: 0.7161134225440313
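The size of the gap follows from where each mode samples. Below is a minimal pure-Python sketch, assuming a single 2D channel stored as nested lists and clamped bilinear sampling as in manual_bilinear_nchw (the helper name `bilinear_at` and the toy image are hypothetical). At coordinate (4.5, 4.5) the result is the mean of the four center pixels, while at (0, 0) it is just the top-left pixel.

```python
import math


def bilinear_at(img, y, x):
    """Sample a 2D list-of-lists at fractional (y, x) with edge clamping."""
    h, w = len(img), len(img[0])
    y = min(max(y, 0.0), float(h - 1))
    x = min(max(x, 0.0), float(w - 1))
    y0, x0 = math.floor(y), math.floor(x)
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    wy, wx = y - y0, x - x0
    top = img[y0][x0] * (1.0 - wx) + img[y0][x1] * wx
    bot = img[y1][x0] * (1.0 - wx) + img[y1][x1] * wx
    return top * (1.0 - wy) + bot * wy


# Toy 10x10 channel whose value encodes its position: img[r][c] = 10*r + c.
img = [[float(10 * r + c) for c in range(10)] for r in range(10)]

center = bilinear_at(img, 4.5, 4.5)  # half_pixel coordinate for size=(1, 1)
corner = bilinear_at(img, 0.0, 0.0)  # pytorch_half_pixel coordinate

print(center)  # 49.5, the mean of img[4][4], img[4][5], img[5][4], img[5][5]
print(corner)  # 0.0, exactly img[0][0]
```

On random input, the center average and the corner pixel are unrelated values, which is consistent with the large max_abs_diff reported above.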

The exported ONNX Resize node uses:

coordinate_transformation_mode = "pytorch_half_pixel"
mode = "linear"
nearest_mode = "floor"

As a result, ONNXRuntime output for the exported model differs from PyTorch eager:

torch_vs_ort_exported:
  allclose: false
  max_abs_diff: 0.6948983073234558
  max_rel_diff: 0.7161134225440313

PyTorch eager:
[0.35544848442077637, 0.275476336479187, 0.32302480936050415]

ONNXRuntime exported:
[0.49625658988952637, 0.9703746438026428, 0.5210340619087219]

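If I patch only the exported ONNX Resize attribute from "pytorch_half_pixel" to "half_pixel", ONNXRuntime matches PyTorch eager exactly: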

torch_vs_ort_patched_half_pixel:
  allclose: true
  max_abs_diff: 0.0
  max_rel_diff: 0.0
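The allclose verdicts above follow from NumPy's tolerance rule, |a - b| <= atol + rtol * |b|. The sketch below applies that rule to two of the reported values, using the reproducer's RTOL and ATOL; `within_tol` is a hypothetical scalar stand-in for np.allclose.

```python
RTOL = 1e-5
ATOL = 1e-6


def within_tol(a, b, rtol=RTOL, atol=ATOL):
    """Scalar version of the np.allclose criterion."""
    return abs(a - b) <= atol + rtol * abs(b)


# Channel 1: PyTorch eager vs manual half_pixel (difference of about 3e-08).
print(within_tol(0.275476336479187, 0.2754763066768646))  # True

# Channel 1: PyTorch eager vs ONNXRuntime output of the exported model.
print(within_tol(0.275476336479187, 0.9703746438026428))  # False
```

The first difference is float32 rounding noise well inside the tolerance; the second is a genuine numerical mismatch.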

Versions

PyTorch version:  2.11.0
Is debug build: True
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 12.3.0-1ubuntu1~22.04) 12.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.16 (main, Dec 11 2024, 16:24:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-59-generic-x86_64-with-glibc2.35
Is CUDA available: True

ONNX: 1.19.1
ONNX opset: 18
ONNXRuntime: 1.23.2
ONNXRuntime providers: AzureExecutionProvider, CPUExecutionProvider
cc @justinchuby @titaiwangms
