pytorch - ✅(Solved) Fix Wrong ROCM code execution in `ScaledBlas.cpp` at `check_swizzle_lengths()` [1 pull request, 1 participant]

GitHub stats
pytorch/pytorch#178687 (fetched 2026-04-08 01:45:08)
Comments: 0 · Participants: 1 · Timeline events: 48 · Reactions: 0
Timeline (top): mentioned ×17, subscribed ×17, labeled ×5, referenced ×5

Fix Action

Fixed

PR fix notes

PR #178688: [ROCm] Fix wrong ROCM code execution in ScaledBlas.cpp at check_swizzle_lengths()

Description (problem / solution / changelog)

This fixes #178687

At line 1231 of aten/src/ATen/native/cuda/ScaledBlas.cpp, check_swizzle_lengths() contains:

#ifdef ROCM
  // ROCM doesn't swizzle their formats - we don't care what's passed.
  return;
#else
  // Store implementations that care about swizzling, and how many swizzle arguments
  // they have to have
  // NOTE(slayton): auto here is unable to deduce the correct type..
  std::array<std::tuple<ScaledGemmImplementation, unsigned int>, 4> swizzled_impl = {{
    // {implementation, # required arguments}
    {ScaledGemmImplementation::MXFP8_MXFP8, 1},
    {ScaledGemmImplementation::NVFP4_NVFP4, 2},
    {ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE, 1},
    {ScaledGemmImplementation::MXFP4_MXFP4, 1}
  }};

However, ROCM is not a macro the build defines (PyTorch's ROCm builds define USE_ROCM), so #ifdef ROCM is always false. The early return; is therefore never compiled in, and the #else branch runs on every build, including ROCm, which adds unnecessary overhead. The guard should be changed to:

#ifdef USE_ROCM
  // ROCM doesn't swizzle their formats - we don't care what's passed.
  return;
#else
  // Store implementations that care about swizzling, and how many swizzle arguments
  // they have to have
  // NOTE(slayton): auto here is unable to deduce the correct type..
  std::array<std::tuple<ScaledGemmImplementation, unsigned int>, 4> swizzled_impl = {{
    // {implementation, # required arguments}
    {ScaledGemmImplementation::MXFP8_MXFP8, 1},
    {ScaledGemmImplementation::NVFP4_NVFP4, 2},
    {ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE, 1},
    {ScaledGemmImplementation::MXFP4_MXFP4, 1}
  }};
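To make the failure mode concrete, here is a minimal standalone sketch (not PyTorch code) that defines USE_ROCM by hand as a stand-in for a ROCm build. The helper names buggy_guard_takes_rocm_branch and fixed_guard_takes_rocm_branch are hypothetical. Because nothing defines ROCM, the buggy guard silently falls through to #else; the preprocessor emits no diagnostic when #ifdef tests a name that was never defined.

```cpp
// Stand-in for a ROCm build: PyTorch's ROCm builds define USE_ROCM, not ROCM.
// Defined by hand here only so this sketch is self-contained.
#define USE_ROCM 1

// Hypothetical helper: does the original guard take the early-return (ROCm) branch?
// ROCM is never defined, so this always selects the #else branch.
constexpr bool buggy_guard_takes_rocm_branch() {
#ifdef ROCM
  return true;
#else
  return false;
#endif
}

// Hypothetical helper: the same check using the corrected macro name.
constexpr bool fixed_guard_takes_rocm_branch() {
#ifdef USE_ROCM
  return true;
#else
  return false;
#endif
}

// Compile-time checks: the buggy guard misses the ROCm build, the fixed one matches it.
static_assert(!buggy_guard_takes_rocm_branch(), "#ifdef ROCM misses the ROCm build");
static_assert(fixed_guard_takes_rocm_branch(), "#ifdef USE_ROCM matches the ROCm build");
```

The translation unit compiles only because both static_asserts hold, which is exactly the mismatch described above.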

Contributed by Benedikt Johannes

cc @jerryzh168 @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang

Changed files

  • aten/src/ATen/native/cuda/ScaledBlas.cpp (modified, +1/-1)


🐛 Describe the bug

The issue body is identical to the PR description above: the #ifdef ROCM guard at line 1231 should be #ifdef USE_ROCM.

Versions

Latest

cc @jerryzh168 @jbschlosser @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang

extent analysis

Fix Plan

To fix the issue, change the preprocessor directive so that it tests USE_ROCM, the macro PyTorch defines for ROCm builds.

  • Replace #ifdef ROCM with #ifdef USE_ROCM in the affected code block.
  • No other changes are required.

Verification

To verify the fix, build PyTorch once for ROCm (where USE_ROCM is defined) and once for CUDA (where it is not). In the ROCm build, check_swizzle_lengths() should return immediately; in the CUDA build, it should still check the swizzle argument counts for the implementations listed in swizzled_impl.
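The following hypothetical sketch mirrors the structure of the fixed guard so the two builds can be compared without building PyTorch. swizzle_lengths_ok and the OTHER enumerator are invented for illustration; the sketch returns a bool where the real check_swizzle_lengths() presumably raises an error, and it assumes the check is an exact match against the required count stored in swizzled_impl.

```cpp
#include <array>
#include <tuple>

// Enumerators copied from the swizzled_impl table; OTHER is a hypothetical
// stand-in for implementations that do not use swizzled scales.
enum class ScaledGemmImplementation {
  MXFP8_MXFP8,
  NVFP4_NVFP4,
  NVFP4_NVFP4_SINGLE_SCALE,
  MXFP4_MXFP4,
  OTHER,
};

// {implementation, # required swizzle arguments}, as in ScaledBlas.cpp.
constexpr std::array<std::tuple<ScaledGemmImplementation, unsigned int>, 4> kSwizzledImpl = {{
    {ScaledGemmImplementation::MXFP8_MXFP8, 1u},
    {ScaledGemmImplementation::NVFP4_NVFP4, 2u},
    {ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE, 1u},
    {ScaledGemmImplementation::MXFP4_MXFP4, 1u},
}};

// Hypothetical stand-in for check_swizzle_lengths(): returns true if the
// number of swizzle arguments is acceptable for the given implementation.
constexpr bool swizzle_lengths_ok(ScaledGemmImplementation impl, unsigned int num_swizzles) {
#ifdef USE_ROCM
  // ROCm does not swizzle its scale formats, so every input is accepted.
  (void)impl;
  (void)num_swizzles;
  return true;
#else
  // Look the implementation up in the table and compare the required count.
  for (const auto& entry : kSwizzledImpl) {
    if (std::get<0>(entry) == impl) {
      return num_swizzles == std::get<1>(entry);
    }
  }
  // Implementations not listed in the table are not checked.
  return true;
#endif
}
```

Compiled without USE_ROCM (a CUDA-style build), swizzle_lengths_ok(ScaledGemmImplementation::NVFP4_NVFP4, 1u) is false and the same call with 2u is true; compiled with -DUSE_ROCM, every call returns true, matching the early return;.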

Extra Tips

  • Search the codebase for other #ifdef ROCM guards and change them to #ifdef USE_ROCM wherever they have the same mistake.
  • Consider adding a comment to explain the purpose of the USE_ROCM definition and its impact on the code behavior.
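One way to act on both tips is to turn the build macro into a typed constant in a single place, so call sites test a C++ identifier instead of a macro name. This is a hypothetical pattern, not something the PR does: kIsRocmBuild and should_check_swizzle() are illustrative names. A misspelled constant is a compile error, whereas #ifdef on a misspelled macro silently takes the #else branch. The trade-off is that, outside a template, both branches of an if constexpr must still compile on every platform, so code that includes HIP-only or CUDA-only headers still needs #ifdef USE_ROCM.

```cpp
// Hypothetical: translate the build macro into a typed constant exactly once.
// PyTorch's ROCm builds define USE_ROCM; other builds do not.
#ifdef USE_ROCM
inline constexpr bool kIsRocmBuild = true;
#else
inline constexpr bool kIsRocmBuild = false;
#endif

// Hypothetical helper: whether swizzle-length validation should run.
// ROCm does not swizzle its scale formats, so validation is skipped there.
constexpr bool should_check_swizzle() {
  if constexpr (kIsRocmBuild) {
    return false;
  } else {
    return true;
  }
}
```

A typo such as kIsRocmBuld at a call site fails to compile instead of quietly selecting the non-ROCm path.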
