pytorch - Fix opcheck stride mismatch in onnx.RotaryEmbedding.opset23 fake impl after #183002 [1 pull request]

Error Message

torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): test_faketensor failed with Stride mismatch! Strides are (96, 8, 24, 1) and (96, 32, 8, 1) (mismatched at 1)!

Root Cause

For a 4D input of shape (2, 3, 4, 8), the real impl does the following:

  1. permute(0, 2, 1, 3) → shape (2, 4, 3, 8)
  2. Computes the rotation via cat → new contiguous tensor with strides (96, 24, 8, 1)
  3. permute(0, 2, 1, 3) back → shape (2, 3, 4, 8) with strides (96, 8, 24, 1)

The fake impl (x.clone()) returns strides (96, 32, 8, 1), the standard contiguous layout for shape (2, 3, 4, 8).

The fake impl was always wrong, but the mismatch was hidden because opcheck validated strides only on CUDA tensors before #183002.

Fix Action

Fixed

🐛 Describe the bug [auto generated by Claude]

After #183002 enabled opcheck stride checking for CPU tensors, test_rotary_embedding_opcheck fails with a stride mismatch:

torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): test_faketensor failed with
Stride mismatch! Strides are (96, 8, 24, 1) and (96, 32, 8, 1) (mismatched at 1)!

The fake implementation in torch/onnx/ops/_impl.py returns x.clone(), which preserves the input's contiguous strides. The real implementation instead permutes the axes with (0, 2, 1, 3), computes the rotation (producing a new contiguous tensor in the permuted layout), and then permutes back, so its output has different strides.

Repro

Requires a build that includes #183002 (merged 2026-05-08):

python test/onnx/ops/test_ops.py NativeOnnxOpsTest.test_rotary_embedding_opcheck

To verify the stride mismatch on any build:

import torch
from torch.onnx.ops import _impl  # private module (leading underscore)

# 4D input plus position ids and 50-entry sin/cos caches
x = torch.rand(2, 3, 4, 8)
pos = torch.randint(0, 50, (2, 4)).long()
sin = torch.rand(50, 4)
cos = torch.rand(50, 4)

# Run the real and the fake implementation on the same inputs
real_out = _impl.rotary_embedding_23(x, cos, sin, pos)
fake_out = _impl._rotary_embedding_23_fake_impl(x, cos, sin, pos)

print("Real strides:", real_out.stride())  # (96, 8, 24, 1)
print("Fake strides:", fake_out.stride())  # (96, 32, 8, 1)
# Fails while the fake impl still returns x.clone()
assert real_out.stride() == fake_out.stride(), "Stride mismatch!"

Root cause

For a 4D input of shape (2, 3, 4, 8), the real impl does the following:

  1. permute(0, 2, 1, 3) → shape (2, 4, 3, 8)
  2. Computes the rotation via cat → new contiguous tensor with strides (96, 24, 8, 1)
  3. permute(0, 2, 1, 3) back → shape (2, 3, 4, 8) with strides (96, 8, 24, 1)

The fake impl (x.clone()) returns strides (96, 32, 8, 1), the standard contiguous layout for shape (2, 3, 4, 8).

The fake impl was always wrong, but the mismatch was hidden because opcheck validated strides only on CUDA tensors before #183002.
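
The stride numbers in the three steps above can be reproduced without PyTorch. The following is a minimal pure-Python sketch of the arithmetic only; the helpers contiguous_strides, permute, and first_mismatch are hypothetical names introduced for illustration and are not part of PyTorch or opcheck.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permute(shape, strides, dims):
    """Reorder shape and strides the way a permute view does (no data copy)."""
    return tuple(shape[d] for d in dims), tuple(strides[d] for d in dims)


def first_mismatch(a, b):
    """Index of the first position where two stride tuples differ, else None."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    return None


in_shape = (2, 3, 4, 8)
dims = (0, 2, 1, 3)

# Step 1: the permute changes the shape the computation sees.
work_shape, _ = permute(in_shape, contiguous_strides(in_shape), dims)
# Step 2: cat allocates a fresh contiguous tensor in the permuted shape.
work_strides = contiguous_strides(work_shape)
# Step 3: permuting back restores the shape but keeps the new memory order.
real_shape, real_strides = permute(work_shape, work_strides, dims)

# The fake impl cloned a contiguous input, so it reports contiguous strides.
fake_strides = contiguous_strides(in_shape)

print(work_shape, work_strides)
print(real_shape, real_strides)
print(fake_strides, first_mismatch(real_strides, fake_strides))
```

The first differing position is index 1 (8 versus 32), which lines up with the "(mismatched at 1)" suffix in the error message.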

Suggested fix

def _rotary_embedding_23_fake_impl(
    x, cos_cache, sin_cache, position_ids=None, *,
    interleaved=False, num_heads=0, rotary_embedding_dim=0,
):
    if x.dim() == 4:
        # Match the real impl, whose output is laid out in (0, 2, 1, 3) order
        return torch.empty_permuted(x.shape, (0, 2, 1, 3), dtype=x.dtype, device=x.device)
    return torch.empty_like(x)
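
To see why a physical layout of (0, 2, 1, 3) yields the strides the real impl produces, here is a small pure-Python model of the stride arithmetic. It assumes the layout tuple lists dimensions from outermost to innermost in memory; strides_for_layout is a hypothetical helper written for this sketch, not PyTorch's implementation of torch.empty_permuted.

```python
def strides_for_layout(shape, physical_layout):
    """Strides (in elements) when dims are stored in memory in the order
    given by physical_layout, from outermost to innermost."""
    strides = [0] * len(shape)
    step = 1
    # Walk from the innermost stored dimension outwards.
    for dim in reversed(physical_layout):
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


shape = (2, 3, 4, 8)

# Layout proposed in the suggested fix for 4D inputs.
permuted = strides_for_layout(shape, (0, 2, 1, 3))
# Identity layout, i.e. what a plain contiguous tensor reports.
default = strides_for_layout(shape, (0, 1, 2, 3))

print(permuted)
print(default)
```

Under this model the permuted layout gives (96, 8, 24, 1), the strides reported for the real output, while the identity layout gives (96, 32, 8, 1), the strides the old fake impl returned.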

Versions

Affects current main (post #183002, merged 2026-05-08).

cc @justinchuby @aorenste
