pytorch - ✅(Solved) Fix Update Internal fbsource CK [2 pull requests, 5 comments, 3 participants]

GitHub stats
pytorch/pytorch#177548 (fetched 2026-04-08 00:47:32)
Comments: 5 · Participants: 3 · Timeline: 70 · Reactions: 0
Timeline (top): subscribed ×29, mentioned ×18, labeled ×9, commented ×5

PR fix notes

PR #172246: [ROCm][CK] Enable variable-length attention support for CK SDPA backend

Description (problem / solution / changelog)

Summary

Enables variable-length (varlen) attention support for the Composable Kernel (CK) SDPA backend on ROCm.

Changes

Forward pass (mha_varlen_fwd_ck.hip)

  • Fixed LSE tensor allocation: changed from 3D {batch_size, num_heads, max_seqlen_q} to 2D {num_heads, total_q} to match the layout CK group mode expects
  • Fixed nhead_stride_lse: use stride(0) instead of stride(1) for the 2D layout
  • Fixed batch_stride_lse: set to 0 (group-mode LSE has no batch dimension)
  • Fixed min_seqlen_q: changed from -1 to 1 (the smallest valid value for kernel dispatch)
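The LSE changes above can be made concrete with a small example. The following is a minimal pure-Python sketch of the two layouts for a toy batch. It assumes row-major (contiguous) storage. It also assumes that a group-mode kernel finds each sequence's LSE entries through cumulative sequence offsets rather than a batch stride. The name `cu_seqlens_q` is borrowed from the FlashAttention varlen convention for illustration, and the helpers are hypothetical, not PyTorch or CK APIs.

```python
from itertools import accumulate
from math import prod


def contiguous_strides(shape):
    """Row-major strides, in elements, for a tensor of the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


seqlens_q = [3, 5, 2]  # toy batch: three sequences of different lengths
num_heads = 2
cu_seqlens_q = [0, *accumulate(seqlens_q)]  # [0, 3, 8, 10]
total_q = cu_seqlens_q[-1]

# Old layout: padded to the longest sequence, {batch_size, num_heads, max_seqlen_q}.
old_shape = [len(seqlens_q), num_heads, max(seqlens_q)]
# New layout: packed over all query tokens, {num_heads, total_q}.
new_shape = [num_heads, total_q]

# In the 2D layout the head stride is stride(0); stride(1) is the per-token stride.
nhead_stride_lse = contiguous_strides(new_shape)[0]
batch_stride_lse = 0  # no batch dimension: sequence starts come from cu_seqlens_q


def lse_offset(head, seq, pos):
    """Flat offset of the LSE entry for token `pos` of sequence `seq` in head `head`."""
    return head * nhead_stride_lse + batch_stride_lse * seq + cu_seqlens_q[seq] + pos


print(prod(old_shape), prod(new_shape))  # 30 20
print(nhead_stride_lse, lse_offset(1, 2, 1))  # 10 19
```

The packed layout drops the padding (20 entries instead of 30 here). Reading stride(1) of the 2D tensor would return 1, which is the per-token stride. That is why the head stride moves to stride(0).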

Backward pass (mha_varlen_bwd_ck.hip)

  • Fixed philox seed/offset access: guarded with if (is_dropout) to avoid dtype mismatch when dropout is disabled

Test infrastructure (test/test_varlen_attention.py)

  • Added sdpa_backend parametrization to test both aotriton and ck backends on ROCm
  • Backend selection: ["aotriton", "ck"] when CK is available, ["aotriton"] otherwise
  • Uses preferred_rocm_fa_library() to switch backends per test
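The per-test backend switching described above can be sketched in plain Python. `ck_available`, `get_backend` and `set_backend` are hypothetical stand-ins. In the real test, availability comes from the platform check and the switch goes through preferred_rocm_fa_library(). The context manager restores the previous backend so one test cannot leak its choice into the next.

```python
from contextlib import contextmanager


def sdpa_backends(ck_available):
    """Backends to parametrize over on ROCm: add "ck" only when CK is usable."""
    return ["aotriton", "ck"] if ck_available else ["aotriton"]


@contextmanager
def use_backend(name, get_backend, set_backend):
    """Select `name` for the duration of one test, then restore the previous backend."""
    previous = get_backend()
    set_backend(name)
    try:
        yield
    finally:
        set_backend(previous)


# Hypothetical in-memory state standing in for the global backend preference.
state = {"backend": "aotriton"}
ran = []
for backend in sdpa_backends(ck_available=True):
    with use_backend(backend, lambda: state["backend"],
                     lambda b: state.update(backend=b)):
        ran.append(state["backend"])  # the test body would run here

print(ran, state["backend"])  # ['aotriton', 'ck'] aotriton
```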

Platform detection

  • Added PLATFORM_SUPPORTS_CK_SDPA in torch/testing/_internal/common_cuda.py
  • Added _is_ck_sdpa_available() in torch/csrc/Module.cpp for runtime CK availability check
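PR #172246 was reverted, so a given build may not have the `_is_ck_sdpa_available()` binding. Below is a minimal sketch of a defensive probe that treats a missing binding as "CK unavailable" instead of raising AttributeError. It assumes the binding lives on `torch._C`, as the changed-files list suggests (torch/csrc/Module.cpp and torch/_C/__init__.pyi.in). The helper name is hypothetical, and `SimpleNamespace` objects stand in for the extension module so the example runs without PyTorch.

```python
from types import SimpleNamespace


def ck_sdpa_available(c_module):
    """Return True only if the runtime CK check exists and reports CK as usable."""
    probe = getattr(c_module, "_is_ck_sdpa_available", None)
    if probe is None:
        return False  # build without the binding: treat CK as unavailable
    return bool(probe())


new_build = SimpleNamespace(_is_ck_sdpa_available=lambda: True)
old_build = SimpleNamespace()
print(ck_sdpa_available(new_build), ck_sdpa_available(old_build))  # True False
```

Against a real build, the call would be `ck_sdpa_available(torch._C)`.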

Fake tensor implementation (torch/nn/attention/varlen.py)

  • Updated _varlen_attn_fake to use standard [num_heads, total_q] logsumexp format (matches aotriton/CUDA)
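A fake (meta) implementation only has to report output shapes. Below is a minimal sketch of that shape contract. It assumes the packed varlen convention: query is `[total_q, num_heads, head_dim]`, and the output keeps the query's token and head dimensions with the value's head dim. Logsumexp uses the `[num_heads, total_q]` format stated above. The helper is hypothetical and does not reproduce `_varlen_attn_fake`.

```python
def varlen_attn_fake_shapes(query_shape, value_shape):
    """Shapes a fake varlen-attention kernel would report (hypothetical helper).

    query_shape: (total_q, num_heads, head_dim)    packed query tokens
    value_shape: (total_k, num_heads, head_dim_v)  packed key/value tokens
    """
    total_q, num_heads, _ = query_shape
    head_dim_v = value_shape[-1]
    out_shape = (total_q, num_heads, head_dim_v)
    # Logsumexp: one row per head, one column per packed query token,
    # with no batch or padding dimension.
    lse_shape = (num_heads, total_q)
    return out_shape, lse_shape


print(varlen_attn_fake_shapes((10, 2, 64), (12, 2, 64)))
# ((10, 2, 64), (2, 10))
```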

Dependencies: CK submodule bump

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang

Changed files

  • aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip (modified, +12/-1)
  • aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip (modified, +4/-2)
  • aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip (modified, +15/-4)
  • docs/source/backends.md (modified, +4/-0)
  • test/test_varlen_attention.py (modified, +38/-5)
  • torch/_C/__init__.pyi.in (modified, +1/-0)
  • torch/backends/cuda/__init__.py (modified, +11/-0)
  • torch/csrc/Module.cpp (modified, +8/-0)
  • torch/nn/attention/varlen.py (modified, +11/-9)
  • torch/testing/_internal/common_cuda.py (modified, +8/-0)
Issue description

We are attempting to fix the CK integration in PyTorch. However, this requires updating the composable_kernel submodule because of API changes. As we understand it, a Meta-internal CK instance does not match the new API, and this mismatch prevents us from making updates. One example is this PR: https://github.com/pytorch/pytorch/pull/172246

This PR was merged and then reverted due to "internal build breakages", citing this comment: https://github.com/pytorch/pytorch/pull/172246#issuecomment-4026784423. The failure in that comment comes from an out-of-date CK being built against the new CK-adjacent PyTorch changes.

We are looking to get CK updated internally so that we can fix issues and stand up unit tests for PyTorch's CK backend.

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang @seemethere @malfet @pytorch/pytorch-dev-infra @ZainRizvi @huydhn

Extent analysis

Fix Plan

To resolve the issue, the Meta-internal (fbsource) copy of composable_kernel must be updated to the revision that the PyTorch changes were written against, so that internal builds stop failing on API mismatches. Concrete steps:

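  • Bump the composable_kernel submodule (third_party/composable_kernel) to the revision that PR #172246 depends on.
  • Update the internal fbsource CK instance to the same revision so that it exposes the updated API.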
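  • Example code change: composable_kernel is a C++/HIP library, so there is no Python import to swap; API updates land in the .hip call sites. The forward-pass LSE fix from PR #172246 is one such change (shapes taken from the PR description; variable names are illustrative):
// Before: 3D, padded per batch entry
auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
// After: 2D over packed tokens, as CK group mode expects
auto softmax_lse = at::empty({num_heads, total_q}, opts.dtype(at::kFloat));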
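  • Re-land the test changes in test/test_varlen_attention.py so that the varlen tests run against both the aotriton and ck backends on ROCm.
  • Example test update (simplified; test and argument names are illustrative):
# Before: only the default backend was exercised on ROCm
def test_varlen_attn(self, device, dtype):
    ...

# After: parametrized over sdpa_backend, switching backends per test
def test_varlen_attn(self, device, dtype, sdpa_backend):
    torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend)
    ...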
  • Verify that the internal build succeeds against the updated CK instance before re-landing PR #172246.

Verification

To verify that the fix worked, run the following checks:

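  • Check that the composable_kernel submodule points at the expected commit, e.g. with git submodule status third_party/composable_kernel.
  • Run test/test_varlen_attention.py on a ROCm machine and confirm that both the aotriton and ck parametrizations pass.
  • Confirm that the internal build that failed in https://github.com/pytorch/pytorch/pull/172246#issuecomment-4026784423 now completes without errors.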
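The submodule check above can be scripted. Below is a minimal sketch, assuming it runs from the root of a PyTorch checkout with git on PATH. Each `git submodule status` line begins with a status character: a space when the checkout matches the recorded commit, `-` when uninitialized, `+` when the checked-out commit differs, and `U` on merge conflicts. The SHA-1 and path follow. `EXPECTED_CK_COMMIT` is a hypothetical placeholder, because this issue does not state the pinned revision.

```python
import subprocess

CK_PATH = "third_party/composable_kernel"
EXPECTED_CK_COMMIT = "<expected-sha>"  # hypothetical placeholder: fill in the pinned revision


def parse_submodule_status(line):
    """Split one `git submodule status` line into (status_char, sha, path)."""
    status, rest = line[0], line[1:]
    sha, path = rest.split()[:2]
    return status, sha, path


def check_ck_submodule(expected_sha, repo_root="."):
    """Raise RuntimeError unless the CK submodule is checked out at expected_sha."""
    out = subprocess.run(
        ["git", "submodule", "status", CK_PATH],
        cwd=repo_root, capture_output=True, text=True, check=True,
    ).stdout
    status, sha, _ = parse_submodule_status(out.rstrip("\n"))
    if status == "-":
        raise RuntimeError(f"{CK_PATH} is not initialized")
    if status in "+U" or not sha.startswith(expected_sha):
        raise RuntimeError(f"{CK_PATH} is at {sha}, expected {expected_sha}")
    return sha
```

Calling `check_ck_submodule(EXPECTED_CK_COMMIT)` from a CI step would surface a stale CK checkout as an explicit error, not as a later build failure.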

Extra Tips

  • Keep the documentation in sync with the API changes; PR #172246 already updates docs/source/backends.md.
  • Consider adding a check that the composable_kernel revision in use is the one PyTorch expects, so that a mismatch fails early with a clear message instead of as an opaque build error.
