pytorch - ✅(Solved) Fix [MKLDNN] Convolution dtype checks do not match eager semantics across CUDA and MKLDNN backends [1 pull request, 1 participant]

GitHub stats
pytorch/pytorch#180548 (fetched 2026-04-17 08:26:31)
Comments: 0 · Participants: 1 · Timeline events: 90 · Reactions: 0
Timeline (top): mentioned ×42, subscribed ×42, labeled ×5, cross-referenced ×1


Root Cause

While fixing a conv bug in #179890, I noticed something interesting. If the input tensor is bfloat16 and the weight is fp32, the dense CPU backend rejects the call because it performs a strict dtype check, and CUDA does the same for a float16 input with an fp32 weight. The MKLDNN backend, however, accepts the bfloat16/fp32 combination. I'm not sure whether this is a deliberate optimization in the MKLDNN backend that explains why it behaves differently from CPU/CUDA. Feel free to correct me if I'm wrong :)


Fix Action

Fixed

PR fix notes

PR #179890: [inductor] [BugFix] add conv dtype match check for input_tensor and weight

Description (problem / solution / changelog)

Fixes #142463. Aligns Inductor's conv dtype checking with eager semantics.

Specifically, eager conv checks that the input tensor and bias dtypes match, but does not check the weight at that point: https://github.com/pytorch/pytorch/blob/c9db239ce48b7d136204412c88d18c51f8b71f20/aten/src/ATen/native/Convolution.cpp#L994-L1007

However, when bias=False, eager calls this function to perform a further, stricter check: https://github.com/pytorch/pytorch/blob/c9db239ce48b7d136204412c88d18c51f8b71f20/aten/src/ATen/native/Convolution.cpp#L827-L843

Unfortunately, under Inductor, when bias=False this seems to go through a fake-tensor path (not fully confirmed). As a result, this case passes even though it should be rejected.

Following previous fixes (#143762, #144313), we can add a dtype check in meta_registration.
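To illustrate why a dtype check can live in a meta registration, here is a minimal Python sketch. The helper `check_conv_input_weight_dtypes` is hypothetical: it is not the code from PR #179890 (which modified torch/_inductor/kernel/conv.py) and not a PyTorch API. It only reads `.dtype`, so it works on meta tensors, which carry shape and dtype but no data. It also ignores the MKLDNN-layout exception that eager's error message mentions.

```python
import torch


def check_conv_input_weight_dtypes(input, weight, bias=None):
    # Hypothetical helper (not a PyTorch API): reject mismatched dtypes
    # whether or not a bias is passed. Layout (dense vs MKLDNN) is ignored.
    if input.dtype != weight.dtype:
        raise RuntimeError(
            f"Input type ({input.dtype}) and weight type ({weight.dtype}) "
            "should be the same"
        )
    if bias is not None and bias.dtype != input.dtype:
        raise RuntimeError(
            f"Input type ({input.dtype}) and bias type ({bias.dtype}) "
            "should be the same"
        )


# Meta tensors have no storage, so only metadata is inspected,
# similar to what a meta registration sees during tracing.
x = torch.empty(1, 1, 4, 4, dtype=torch.bfloat16, device="meta")
w = torch.empty(1, 1, 1, 1, dtype=torch.float32, device="meta")

try:
    check_conv_input_weight_dtypes(x, w)  # bias=None is still checked
except RuntimeError as e:
    print("rejected:", e)

check_conv_input_weight_dtypes(x, w.to(torch.bfloat16))  # matching dtypes pass
```

The key design point is that the check does not depend on `bias`, which is exactly the gap the PR description identifies for the bias=False case.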

Test

python test/inductor/test_torchinductor.py -k test_convolution_errors_on_input_weight_dtype_mismatch
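The Inductor test from the PR is not reproduced here. As a hedged, eager-only companion, a regression test along these lines would pin down the dense-CPU behavior shown in the reproducer output. The class and method names are hypothetical, and the test asserts only `RuntimeError` (not the exact message), since message wording can differ between builds.

```python
import unittest

import torch
import torch.nn.functional as F


class ConvDtypeMismatchTest(unittest.TestCase):
    # Hypothetical eager-mode test; not the test added in PR #179890.
    def test_dense_cpu_rejects_bf16_input_fp32_weight(self):
        x = torch.randn(1, 1, 4, 4, dtype=torch.bfloat16)
        w = torch.randn(1, 1, 1, 1, dtype=torch.float32)
        with self.assertRaises(RuntimeError):
            F.conv2d(x, w)

    def test_matching_dtypes_succeed(self):
        x = torch.randn(1, 1, 4, 4)
        w = torch.randn(1, 1, 1, 1)
        # A 1x1 kernel with no padding keeps the 4x4 spatial size.
        self.assertEqual(F.conv2d(x, w).shape, (1, 1, 4, 4))
```

Saved as a module, it can be run with `python -m unittest <module_name>`.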

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • test/inductor/test_torchinductor.py (modified, +17/-0)
  • torch/_inductor/kernel/conv.py (modified, +5/-0)

Code Example

import torch
import torch.nn.functional as F


def run_case(name, x, weight):
    print(f"\n[{name}]")
    try:
        out = F.conv2d(x, weight)
        print("success", out.dtype, out.layout)
    except Exception as e:
        print(type(e).__name__)
        print(e)


weight = torch.randn(1, 1, 1, 1, dtype=torch.float32)

run_case(
    "cpu_dense_bf16_input_fp32_weight",
    torch.randn(1, 1, 4, 4, dtype=torch.bfloat16),
    weight,
)

if torch.backends.mkldnn.is_available():
    run_case(
        "cpu_mkldnn_bf16_input_fp32_weight",
        torch.randn(1, 1, 4, 4, dtype=torch.bfloat16).to_mkldnn(),
        weight,
    )

if torch.cuda.is_available():
    run_case(
        "cuda_fp16_input_fp32_weight",
        torch.randn(1, 1, 4, 4, device="cuda", dtype=torch.float16),
        weight.cuda(),
    )

---

[cpu_dense_bf16_input_fp32_weight]
RuntimeError
Input type (CPUBFloat16Type) and weight type (torch.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

[cpu_mkldnn_bf16_input_fp32_weight]
success torch.bfloat16 torch._mkldnn

[cuda_fp16_input_fp32_weight]
RuntimeError
Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

Versions

nightly20260409

cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal

extent analysis

TL;DR

To avoid the error, give the input tensor and weight the same dtype, or use the combination the error message explicitly allows: an MKLDNN input tensor with a dense weight.

Guidance

  • The error occurs because the input tensor and weight have different dtypes (bfloat16 or float16 input vs. float32 weight), which the dense CPU and CUDA backends reject.
  • To resolve it, either cast the input tensor to the weight's dtype (float32) or pass the input as an MKLDNN tensor with a dense weight.
  • The MKLDNN backend accepts this mixed combination in the reported nightly, but it is unclear whether that is an intentional optimization or a bug.
  • Verify the fix by re-running run_case with the modified input tensor and weight.
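The first workaround (cast the input to the weight's dtype) can be wrapped once instead of repeated at every call site. This is a minimal sketch; `conv2d_cast_to_weight_dtype` is a hypothetical helper, not a PyTorch API, and it changes only dtypes, leaving device and layout untouched.

```python
import torch
import torch.nn.functional as F


def conv2d_cast_to_weight_dtype(x, weight, bias=None, **kwargs):
    # Hypothetical wrapper: cast the input (and bias) to the weight's dtype
    # so F.conv2d never sees a dtype mismatch. Extra kwargs such as
    # stride/padding are forwarded unchanged.
    if x.dtype != weight.dtype:
        x = x.to(weight.dtype)
    if bias is not None and bias.dtype != weight.dtype:
        bias = bias.to(weight.dtype)
    return F.conv2d(x, weight, bias, **kwargs)


x = torch.randn(1, 1, 4, 4, dtype=torch.bfloat16)
weight = torch.randn(1, 1, 1, 1, dtype=torch.float32)
out = conv2d_cast_to_weight_dtype(x, weight)
print(out.dtype, tuple(out.shape))  # torch.float32 (1, 1, 4, 4)
```

Casting up to float32 favors accuracy at the cost of memory and speed; casting the weight down to the input's dtype is the opposite trade-off.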

Example

# Cast the input to float32 so it matches the fp32 weight
# (run_case and weight are defined in the reproducer above)
x = torch.randn(1, 1, 4, 4, dtype=torch.float32)
run_case("cpu_dense_fp32_input_fp32_weight", x, weight)

Notes

  • The behavior of the MKLDNN backend may be specific to this version (nightly20260409) and may change in future versions.
  • It's unclear whether this is a bug or an intended optimization in the MKLDNN backend.
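Because the MKLDNN behavior may differ between builds, code that depends on it could probe at runtime instead of assuming it. The sketch below is hypothetical (`mkldnn_accepts_mixed_conv_dtypes` is not a PyTorch API). It treats a failure to create a bfloat16 MKLDNN tensor (for example, on a CPU or build without bf16 support) the same as a rejected conv call.

```python
import torch
import torch.nn.functional as F


def mkldnn_accepts_mixed_conv_dtypes():
    """Hypothetical probe for the behavior described in this issue.

    Returns None if MKLDNN is unavailable, True if F.conv2d accepts a
    bfloat16 MKLDNN input with a float32 dense weight, and False otherwise.
    """
    if not torch.backends.mkldnn.is_available():
        return None
    try:
        # Conversion is inside the try: bf16 MKLDNN tensors may be
        # unsupported on some CPUs/builds, which also counts as False.
        x = torch.randn(1, 1, 4, 4, dtype=torch.bfloat16).to_mkldnn()
        w = torch.randn(1, 1, 1, 1, dtype=torch.float32)
        F.conv2d(x, w)
    except (RuntimeError, NotImplementedError):
        return False
    return True


print(mkldnn_accepts_mixed_conv_dtypes())
```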

Recommendation

Apply a workaround: either cast the input tensor to the weight's dtype, or pass the input as an MKLDNN tensor with a dense weight. The error is raised by the dtype mismatch itself, so removing the mismatch (or taking the MKLDNN path, which tolerates it in the reported nightly) avoids it.
