pytorch - ✅(Solved) Fix convolution_backward meta function uses grad_output instead of input to determine grad_input memory format [1 pull request, 2 comments, 3 participants]

GitHub stats
pytorch/pytorch#178092 (fetched 2026-04-08 01:16:44)
Comments: 2
Participants: 3
Timeline: 48
Reactions: 0
Timeline (top): mentioned ×16, subscribed ×16, labeled ×9, unlabeled ×3

Fix Action

Fix / Workaround

Introduced in #174793. The bug is currently masked by FakeTensor dispatch but would surface if that path changes.

PR fix notes

PR #178208: [meta] fix meta_convolution_backward

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #178208

Fixes https://github.com/pytorch/pytorch/issues/178092

Changed files

  • test/test_fake_tensor.py (modified, +24/-0)
  • torch/_meta_registrations.py (modified, +1/-1)


🐛 Describe the bug

Line 3700 in torch/_meta_registrations.py passes grad_output_ to _conv_memory_format when determining grad_input's memory format, but all backends (cuDNN, MKLDNN, MPS) use input_:

  • cuDNN (ConvShared.cpp:596): converts grad_output to input.suggest_memory_format() before use
  • MKLDNN (Convolution.cpp:1492): mkldnn_conv_use_channels_last(input, weight)
  • MPS (Convolution.mm:615): input.suggest_memory_format()
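
To make the layout difference concrete, here is a minimal pure-Python sketch (no PyTorch required) of the strides behind the two memory formats for the repro's input shape. The helper names `contiguous_strides`, `channels_last_strides`, and `layout_of` are illustrative assumptions, not PyTorch APIs; in PyTorch the same information comes from `Tensor.stride()` and `Tensor.is_contiguous(memory_format=...)`.

```python
def contiguous_strides(shape):
    """Strides (in elements) of a dense NCHW tensor: the last dim varies fastest."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def channels_last_strides(shape):
    """Strides of a dense 4-D tensor stored as NHWC but still indexed as NCHW."""
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)


def layout_of(shape, strides):
    """Classify a 4-D layout by comparing against the two dense stride patterns."""
    if strides == channels_last_strides(shape):
        return "channels_last"
    if strides == contiguous_strides(shape):
        return "contiguous"
    return "other"


input_shape = (2, 32, 8, 8)  # same shape as `inp` in the repro
print(contiguous_strides(input_shape))     # (2048, 64, 8, 1)
print(channels_last_strides(input_shape))  # (2048, 1, 256, 32)
```

The shape of grad_input is the same under both formats; only the strides differ. That is why picking the format from grad_output rather than input can go unnoticed until something checks strides.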

Repro:

import torch
from torch._meta_registrations import meta_convolution_backward

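grad_out = torch.empty(2, 64, 8, 8)  # contiguous (NCHW) gradient of the conv output
inp = torch.empty(2, 32, 8, 8, memory_format=torch.channels_last)  # channels_last input
w = torch.empty(64, 32, 3, 3)  # contiguous weight

# Positional args after w: bias_sizes, stride, padding, dilation, transposed,
# output_padding, groups, output_mask (grad_input, grad_weight, grad_bias).
result = meta_convolution_backward(
    grad_out, inp, w, [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]
)
# grad_input (result[0]) should follow the input's layout, as the backends do.
assert result[0].is_contiguous(memory_format=torch.channels_last)  # FAILS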

Fix: on line 3700, change _conv_memory_format(grad_output_, weight_) to _conv_memory_format(input_, weight_).

Introduced in #174793. The bug is currently masked by FakeTensor dispatch but would surface if that path changes.

cc @chauhang @penguinwu @bdhirsh @bobrenjc93 @aorenste @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @mruberry @jbschlosser @walterddr @mikaylagawarecki @pianpwk

Versions

PyTorch: 2.12.0a0+gitae75155 (built from main)
Python: 3.12.3
CUDA: 12.8 (RTX 5090)

extent analysis

Fix Plan

To fix the issue, meta_convolution_backward in torch/_meta_registrations.py must derive grad_input's memory format from the input tensor rather than from grad_output, matching what the cuDNN, MKLDNN, and MPS backends do.

  • Update line 3700 in torch/_meta_registrations.py to use input_ instead of grad_output_:
_conv_memory_format(input_, weight_)
  • No other changes are required.

Verification

To verify the fix, run the repro code again:

import torch
from torch._meta_registrations import meta_convolution_backward

grad_out = torch.empty(2, 64, 8, 8)  # contiguous
inp = torch.empty(2, 32, 8, 8, memory_format=torch.channels_last)
w = torch.empty(64, 32, 3, 3)

result = meta_convolution_backward(
    grad_out, inp, w, [0], [1,1], [1,1], [1,1], False, [0,0], 1, [True,True,False]
)
assert result[0].is_contiguous(memory_format=torch.channels_last)  # Should pass

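If the assertion passes, grad_input now takes its memory format from input rather than from grad_output. PR #178208 also adds a regression test in test/test_fake_tensor.py.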
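
Beyond the single repro, one could tabulate the layout the meta function reports for each combination of grad_output and input layouts. The sketch below is a hypothetical helper, not the test added in PR #178208 (its contents are not shown on this page). `FORMATS`, `layout_name`, and `grad_input_layouts` are names made up for illustration, and the private `meta_convolution_backward` is called with the same positional arguments as the repro. It prints results instead of asserting them, because the outcome depends on whether the installed build includes the fix.

```python
import itertools

import torch
from torch._meta_registrations import meta_convolution_backward

# Layouts to combine for grad_output and input (illustrative choice).
FORMATS = {
    "contiguous": torch.contiguous_format,
    "channels_last": torch.channels_last,
}


def layout_name(t):
    """Name the dense layout of a 4-D tensor (hypothetical helper)."""
    if t.is_contiguous(memory_format=torch.channels_last):
        return "channels_last"
    if t.is_contiguous():
        return "contiguous"
    return "other"


def grad_input_layouts():
    """Map (grad_output layout, input layout) to the meta grad_input layout."""
    results = {}
    for go_name, in_name in itertools.product(FORMATS, repeat=2):
        grad_out = torch.empty(2, 64, 8, 8, memory_format=FORMATS[go_name])
        inp = torch.empty(2, 32, 8, 8, memory_format=FORMATS[in_name])
        w = torch.empty(64, 32, 3, 3)  # weight stays contiguous, as in the repro
        out = meta_convolution_backward(
            grad_out, inp, w, [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1,
            [True, True, False],
        )
        results[(go_name, in_name)] = layout_name(out[0])
    return results


for (go_name, in_name), got in grad_input_layouts().items():
    print(f"grad_output={go_name:13s} input={in_name:13s} -> grad_input={got}")
```

Going by the issue's description, on a fixed build the grad_input column would be expected to follow the input column in every row when the weight is contiguous, as it is here. A row where it follows grad_output instead points to the unfixed behavior.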

Extra Tips

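  • Make sure the PyTorch you import contains the updated torch/_meta_registrations.py. The change is Python-only, so a source checkout installed in development mode should pick it up without recompiling.
  • If you are using a pre-built PyTorch package, you will need a release or nightly that includes PR #178208, or a build from source, to get the fix.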
