pytorch: `torch.compile` crashes on `stft`/`fftn`/`rfftn`/`ifftn` backward (meta kernel stride mismatch)

pytorch/pytorch#182200 (fetched 2026-05-02 05:26:47)
Error Message

import torch

# 1. stft backward
x = torch.randn(128, device="cuda", requires_grad=True)

@torch.compile(backend="inductor")
def f_stft(a):
    return torch.stft(a, n_fft=32, hop_length=16, return_complex=True).abs().sum()

f_stft(x).backward()  # AssertionError: stride mismatch

Root Cause

The meta kernels for _fft_r2c and _fft_c2c (used by these ops) return strides inconsistent with what eager execution produces. During backward, the transposed/mismatched strides cause assert_size_stride failures.

For example, _fft_c2c on a (2, 4, 4, 4) input may produce meta strides like (64, 1, 16, 4) instead of the correct (64, 16, 4, 1).

This is the same family as #106623 (FFT meta strides). The fix for rfft/rfft2/etc. in #145977 appears to have addressed some variants but not these.

Code Example

import torch

# 1. stft backward
x = torch.randn(128, device="cuda", requires_grad=True)

@torch.compile(backend="inductor")
def f_stft(a):
    return torch.stft(a, n_fft=32, hop_length=16, return_complex=True).abs().sum()

f_stft(x).backward()  # AssertionError: stride mismatch

---

# 2. fftn backward (4D input)
x = torch.randn(2, 4, 4, 4, device="cuda", requires_grad=True)

@torch.compile(backend="inductor")
def f_fftn(a):
    return torch.fft.fftn(a).abs().sum()

f_fftn(x).backward()  # AssertionError: stride mismatch

---

# 3. rfftn backward (3D input)
x = torch.randn(4, 8, 8, device="cuda", requires_grad=True)

@torch.compile(backend="inductor")
def f_rfftn(a):
    return torch.fft.rfftn(a).abs().sum()

f_rfftn(x).backward()  # AssertionError: stride mismatch

---

# 4. ifftn backward (4D complex input)
x = torch.randn(2, 8, 8, 8, device="cuda", dtype=torch.cfloat, requires_grad=True)

@torch.compile(backend="inductor")
def f_ifftn(a):
    return torch.fft.ifftn(a).abs().sum()

f_ifftn(x).backward()  # AssertionError: stride mismatch
🐛 Describe the bug

torch.compile crashes with an AssertionError (stride mismatch) in the backward pass for several FFT operations. The forward pass compiles and runs correctly; only backward crashes.

Affected ops and their crash boundaries:

| Op | Crashes when | Minimal crashing input |
| --- | --- | --- |
| torch.stft | Always (any n_fft) | (128,) signal, n_fft=32 |
| torch.fft.fftn | Transforming >= 3 dims | (2, 4, 4, 4) input |
| torch.fft.rfftn | Transforming >= 2 dims | (4, 8, 8) input |
| torch.fft.ifftn | Transforming >= 3 dims | (2, 8, 8, 8) complex input |

Lower-dimensional variants (fft, rfft, fft2, rfft2, irfftn, hfft, ihfft) work correctly.

This is an aot_autograd-level issue — both inductor and aot_eager backends crash.

Related: #106623, #145977. The ops listed in #145977 (rfft, rfft2, etc.) appear to be fixed on current nightly, but stft/fftn/rfftn/ifftn still crash.

Minimal reproducer

import torch

# 1. stft backward
x = torch.randn(128, device="cuda", requires_grad=True)

@torch.compile(backend="inductor")
def f_stft(a):
    return torch.stft(a, n_fft=32, hop_length=16, return_complex=True).abs().sum()

f_stft(x).backward()  # AssertionError: stride mismatch

# 2. fftn backward (4D input)
x = torch.randn(2, 4, 4, 4, device="cuda", requires_grad=True)

@torch.compile(backend="inductor")
def f_fftn(a):
    return torch.fft.fftn(a).abs().sum()

f_fftn(x).backward()  # AssertionError: stride mismatch

# 3. rfftn backward (3D input)
x = torch.randn(4, 8, 8, device="cuda", requires_grad=True)

@torch.compile(backend="inductor")
def f_rfftn(a):
    return torch.fft.rfftn(a).abs().sum()

f_rfftn(x).backward()  # AssertionError: stride mismatch

# 4. ifftn backward (4D complex input)
x = torch.randn(2, 8, 8, 8, device="cuda", dtype=torch.cfloat, requires_grad=True)

@torch.compile(backend="inductor")
def f_ifftn(a):
    return torch.fft.ifftn(a).abs().sum()

f_ifftn(x).backward()  # AssertionError: stride mismatch

All of these work correctly in eager mode.

Root cause analysis

The meta kernels for _fft_r2c and _fft_c2c (used by these ops) return strides inconsistent with what eager execution produces. During backward, the transposed/mismatched strides cause assert_size_stride failures.

For example, _fft_c2c on a (2, 4, 4, 4) input may produce meta strides like (64, 1, 16, 4) instead of the correct (64, 16, 4, 1).

This is the same family as #106623 (FFT meta strides). The fix for rfft/rfft2/etc. in #145977 appears to have addressed some variants but not these.
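The stride claim can be inspected directly by calling the ATen op once on a real tensor and once on a fake tensor, then comparing the reported strides. The following is a rough sketch under assumptions: the helper `compare_strides` is hypothetical, the dim list `[0, 1, 2, 3]` is a guess at what `fftn` passes for a 4D input, and it runs on CPU, so the strides printed may differ from the CUDA example quoted above.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def compare_strides(op, x, *args):
    """Return (eager_strides, fake_strides) for a single ATen op call."""
    eager_out = op(x, *args)
    with FakeTensorMode() as mode:
        fake_x = mode.from_tensor(x)   # fake tensor: metadata only, no data
        fake_out = op(fake_x, *args)   # dispatches to the meta kernel
        fake_strides = tuple(fake_out.stride())
    return tuple(eager_out.stride()), fake_strides

x = torch.randn(2, 4, 4, 4, dtype=torch.cfloat)
# aten._fft_c2c(self, dim, normalization, forward); normalization=0 means no scaling
eager_s, fake_s = compare_strides(
    torch.ops.aten._fft_c2c.default, x, [0, 1, 2, 3], 0, True
)
print("eager:", eager_s)
print("fake: ", fake_s)
print("match:", eager_s == fake_s)
```

A `match: False` line would be consistent with the root cause described here; a match on CPU does not rule out a mismatch on CUDA, since the meta kernel may choose strides per device.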

Crash boundary details

Tested systematically across 17 FFT variants:

| Op | Status |
| --- | --- |
| fft, rfft, ifft, irfft | OK |
| fft2, rfft2, ifft2, irfft2 | OK |
| hfft, ihfft | OK |
| irfftn | OK |
| fftn (2D transform) | OK |
| ifftn (2D transform) | OK |
| stft | CRASH |
| fftn (3D+ transform) | CRASH |
| rfftn (2D+ transform) | CRASH |
| ifftn (3D+ transform) | CRASH |

Versions

  • PyTorch: 2.13.0.dev20260429+cu126
  • Triton: 3.2.0+git4b3bb1e8
  • CUDA: 12.6
  • GPU: Tesla T4 (sm_75)
  • Python: 3.11
  • Also tested: CPU backend (same crash)

cc @mruberry @ezyang @eellison @bdhirsh @bobrenjc93 @aorenste @chauhang @penguinwu

extent analysis

TL;DR

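torch.compile crashes with an AssertionError (stride mismatch) in the backward pass of several FFT operations because the meta kernels for _fft_r2c and _fft_c2c return strides that differ from eager execution. Earlier work (#145977) fixed some variants, but stft, fftn, rfftn, and ifftn still crash on the nightly build the issue was reported against, so upgrading PyTorch helps only once a build includes a fix for these ops. Until then, running the affected FFT calls in eager mode is the practical workaround.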

Guidance

  • The likely cause is that the meta kernels for _fft_r2c and _fft_c2c return strides inconsistent with eager execution, a known problem (#106623) that has been only partially fixed.
  • To confirm you are hitting this issue, run the minimal reproducers above and vary the input shape and FFT op; the forward pass should succeed and only backward should fail.
  • As a temporary workaround, run the affected FFT operations in eager mode instead of under torch.compile.
  • The fix for rfft/rfft2/etc. in #145977 covers some variants, but not stft, fftn, rfftn, or ifftn.

Example

The issue does not include a code-level fix. The reproducers above demonstrate the failure, and the eventual fix is expected to land in the meta kernels for _fft_r2c and _fft_c2c rather than in user code.

Notes

The issue was reported on PyTorch 2.13.0.dev20260429+cu126, on both CUDA and CPU, and may be resolved in later builds. Because the failure originates at the aot_autograd level, it affects both the inductor and aot_eager backends.

Recommendation

Run the affected FFT operations in eager mode until a PyTorch build that corrects the meta-kernel strides for these ops is available.
