pytorch - ✅ (Solved) `torch.compile` returns different output stride/contiguity from eager for upsample_nearest3d on non-contiguous input [1 pull request, 1 comment, 2 participants]

GitHub stats
pytorch/pytorch#179272 (fetched 2026-04-08 02:43:43)
Comments
1
Participants
2
Timeline
78
Reactions
0
Timeline (top)
mentioned ×36, subscribed ×36, labeled ×5, commented ×1

PR fix notes

PR #179830: Fix upsample_nearest3d stride mismatch under torch.compile (#179272)

Description (problem / solution / changelog)

When F.interpolate(mode="nearest") is applied to a non-contiguous (channels_last_3d) 5D tensor, torch.compile produces a contiguous output while eager mode preserves the input's stride order.

Root cause: the _upsample_nearest decomposition in torch/_decomp/decompositions.py decomposes all upsample_nearest variants into aten._unsafe_index, which always returns a contiguous tensor. A memory format correction (result.contiguous(memory_format=...)) was applied afterward, but only for 4D tensors (if result.ndim == 4). 5D tensors got no correction, so the contiguous strides from _unsafe_index propagated through to the compiled output.

The 4D-only guard was there because the code was written when only upsample_nearest2d existed. When 1d/3d lowerings were added in PR #87158, the decomposition was not updated to handle 5D.

Fix: remove the ndim == 4 guard so memory format preservation applies unconditionally. For tensors whose ndim is neither 4 nor 5, suggest_memory_format returns contiguous_format, making the .contiguous() call a no-op. The 4D-specific CUDA heuristic (skip channels_last when n_channels < 4) is kept only for 4D to match eager behavior.
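The layout correction described above can be illustrated with public tensor APIs only. This is a minimal sketch, not the code in torch/_decomp/decompositions.py: `restore_input_layout` is a hypothetical helper, it omits the 4D CUDA heuristic, and a zero-filled tensor stands in for the contiguous result of `aten._unsafe_index`.

```python
import torch


def restore_input_layout(result, reference):
    """Hypothetical helper: return `result` in the memory format of `reference`.

    Illustration only; the real decomposition uses internal utilities.
    """
    if reference.ndim == 5 and reference.is_contiguous(memory_format=torch.channels_last_3d):
        return result.contiguous(memory_format=torch.channels_last_3d)
    if reference.ndim == 4 and reference.is_contiguous(memory_format=torch.channels_last):
        return result.contiguous(memory_format=torch.channels_last)
    # Any other rank or layout falls back to the default contiguous format.
    return result.contiguous()


# Same input as the reproducer: shape (2, 3, 4, 2, 2) in channels_last_3d layout.
x = torch.randn(2, 4, 2, 2, 3).permute(0, 4, 1, 2, 3)

# Stand-in for the contiguous tensor produced by the indexing step.
result = torch.zeros(2, 3, 6, 7, 8)

fixed = restore_input_layout(result, x)
print("before:", result.stride())
print("after: ", fixed.stride())
```

With the 4D-only guard in place, the 5D branch of a correction like this was never reached, which is why the contiguous strides reached the compiled output.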

Also register upsample_nearest1d/3d and their _exact variants in needs_realized_inputs for consistency with upsample_nearest2d.

Fixes #179272

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo

Changed files

  • test/inductor/test_torchinductor.py (modified, +11/-0)
  • torch/_decomp/decompositions.py (modified, +9/-5)
  • torch/_inductor/lowering.py (modified, +4/-0)

Code Example

import torch
import torch.nn.functional as F

def fn(x):
    return F.interpolate(x, size=(6, 7, 8), mode="nearest")

x = torch.randn(2, 4, 2, 2, 3).permute(0, 4, 1, 2, 3)

eager_out = fn(x)
compiled_out = torch.compile(fn)(x)

print("eager stride:", eager_out.stride())
print("eager contiguous:", eager_out.is_contiguous())

print("compiled stride:", compiled_out.stride())
print("compiled contiguous:", compiled_out.is_contiguous())

---

eager stride: (1008, 1, 168, 24, 3)
eager contiguous: False
compiled stride: (1008, 336, 56, 8, 1)
compiled contiguous: True

🐛 Describe the bug

torch.compile returns an output with different stride and contiguity from eager mode for upsample_nearest3d on a non-contiguous input.

The values are the same, but the output layout differs.

Reproducer:

import torch
import torch.nn.functional as F

def fn(x):
    return F.interpolate(x, size=(6, 7, 8), mode="nearest")

x = torch.randn(2, 4, 2, 2, 3).permute(0, 4, 1, 2, 3)

eager_out = fn(x)
compiled_out = torch.compile(fn)(x)

print("eager stride:", eager_out.stride())
print("eager contiguous:", eager_out.is_contiguous())

print("compiled stride:", compiled_out.stride())
print("compiled contiguous:", compiled_out.is_contiguous())

Output

eager stride: (1008, 1, 168, 24, 3)
eager contiguous: False
compiled stride: (1008, 336, 56, 8, 1)
compiled contiguous: True
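A layout mismatch like this is easy to miss because value comparisons still pass. The sketch below shows one way to check layout explicitly. `same_layout` is a hypothetical helper, and the two tensors are built by hand to mimic the strides reported above, so nothing is compiled here.

```python
import torch


def same_layout(a, b):
    """Hypothetical helper: True when two tensors share shape and strides."""
    return a.shape == b.shape and a.stride() == b.stride()


# Hand-built stand-ins with the strides printed in the report.
eager_like = torch.zeros(2, 3, 6, 7, 8).contiguous(memory_format=torch.channels_last_3d)
compiled_like = torch.zeros(2, 3, 6, 7, 8)

print(torch.equal(eager_like, compiled_like))  # the values agree
print(same_layout(eager_like, compiled_like))  # the layouts do not
```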

Versions

PyTorch version: 2.10.0+cpu

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

The mismatch can be avoided by making the input tensor contiguous before passing it to the compiled function; eager and compiled modes then both return a contiguous output. The upstream fix is PR #179830.

Guidance

  • Verify the contiguity of the input tensor x before calling the compiled function torch.compile(fn)(x).
  • Consider adding a .contiguous() call to the input tensor x to ensure it is contiguous, like x = x.contiguous().
  • Check whether your PyTorch build includes PR #179830, which addresses the decomposition that causes the mismatch.
  • Test the compiled function with a contiguous input tensor to see if the issue persists.

Example

# fn, torch, and F are defined as in the reproducer above.
x = torch.randn(2, 4, 2, 2, 3).permute(0, 4, 1, 2, 3)
x = x.contiguous()  # Copy into the default contiguous layout
compiled_out = torch.compile(fn)(x)
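The effect of this workaround on layout can be checked in eager mode alone. The sketch below assumes the same `fn` as the reproducer and calls only eager ops. The final step, converting back to channels_last_3d, is an optional addition for callers that need the original layout and is not part of the suggested workaround.

```python
import torch
import torch.nn.functional as F


def fn(x):
    return F.interpolate(x, size=(6, 7, 8), mode="nearest")


x = torch.randn(2, 4, 2, 2, 3).permute(0, 4, 1, 2, 3)
x_contig = x.contiguous()  # copy into the default contiguous layout

out = fn(x_contig)  # eager call; contiguous input gives contiguous output
print(out.stride(), out.is_contiguous())

# Optional: copy back to channels_last_3d if downstream code expects it.
out_cl = out.contiguous(memory_format=torch.channels_last_3d)
print(out_cl.stride())
```

Note that `.contiguous()` copies the data, so the workaround trades an extra copy of the input for consistent output strides.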

Notes

The mismatch appears only for non-contiguous input, so making the input contiguous sidesteps it. The root cause is given in the PR fix notes: the _upsample_nearest decomposition restored the input's memory format only for 4D results, so 5D outputs kept the contiguous strides returned by aten._unsafe_index.

Recommendation

Apply workaround: make the input tensor contiguous before passing it to the compiled function if you cannot yet use a build that contains PR #179830. Note that this also changes the eager output to a contiguous layout.
