pytorch - ✅ (Solved) Fix [CUDA] When my data size is between 32 and 128 (exclusive), argsort doesn't go through the WarpMergeSort branch. [1 pull request, 1 participant]

GitHub stats
pytorch/pytorch#183499 (fetched 2026-05-14 03:28:36)
Comments: 0
Participants: 1
Timeline events: 5
Reactions: 0
Timeline (top): closed ×2, cross-referenced ×1, referenced ×1, reopened ×1

Fix Action

Fixed

PR fix notes

PR #183527: Fix CUDA version check for warp merge sort

Description (problem / solution / changelog)

As referenced in the original code and PR #96223, CUDA 11.6 is explicitly required. The correct value of the CUDA_VERSION macro for CUDA 11.6 is 11060, following the standard encoding rule: CUDA_VERSION = major * 1000 + minor * 10. This aligns with the correct usage already implemented in pytorch/third_party/kineto/libkineto/src/CuptiCbidRegistry.cpp, which defines 11060 for CUDA 11.6.
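The encoding can be checked with a few lines of Python. This is a minimal sketch, not PyTorch code: the helper names `cuda_version_macro` and `has_warp_merge_sort` are hypothetical, and the two thresholds are the value reported in the issue (110600) and the value the PR description gives for CUDA 11.6 (11060).

```python
def cuda_version_macro(major: int, minor: int) -> int:
    """Encode a CUDA major.minor release as CUDA_VERSION (hypothetical helper)."""
    return major * 1000 + minor * 10


def has_warp_merge_sort(cuda_version: int, threshold: int) -> bool:
    """Model of a `CUDA_VERSION >= threshold` preprocessor check."""
    return cuda_version >= threshold


BUGGY_THRESHOLD = 110600  # value reported in the issue
FIXED_THRESHOLD = 11060   # CUDA 11.6 under the encoding above

for major, minor in [(11, 5), (11, 6), (12, 2)]:
    version = cuda_version_macro(major, minor)
    print(
        f"CUDA {major}.{minor}: CUDA_VERSION={version}, "
        f"buggy check={has_warp_merge_sort(version, BUGGY_THRESHOLD)}, "
        f"fixed check={has_warp_merge_sort(version, FIXED_THRESHOLD)}"
    )
```

With the 110600 threshold the check is false for every release in the loop, including CUDA 12.2 (12020); with 11060 it is true from CUDA 11.6 onward.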

This PR fixes the incorrect version check that existed previously. Performance benchmarks were conducted on the issue case before and after the fix to verify the performance improvement.

Before fix: sortCommon(MediumRadixSort{}, key, value, dim, descending), 8.544us
After fix: sortCommon(WarpMergeSort<128>{}, key, value, dim, descending), 4.096us

This is roughly a 2x speedup on the sort kernel. It is consistent with the conclusion stated in PR #96223: using WarpMergeSort under eligible conditions results in up to a 2x speedup for unstable sorts and up to a 15x speedup for stable sorts, depending on the input geometry.

Fixes #183499

before

| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | # of Calls |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| void at::native::radixSortKVInPlace<-2, -1, 32, 4, ...> | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 8.544us | 63.42% | 8.544us | 8.544us | 1 |
| Memcpy DtoD (Device -> Device) | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 3.713us | 27.56% | 3.713us | 1.238us | 3 |
| void (anonymous namespace)::elementwise_kernel_with_... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 1.216us | 9.03% | 1.216us | 1.216us | 1 |
| cudaMemcpyAsync | 3.00% | 63.241us | 98.03% | 2.068ms | 689.455us | 0.000us | 0.00% | 0.000us | 0.000us | 3 |
| Activity Buffer Request | 95.03% | 2.005ms | 95.03% | 2.005ms | 2.005ms | 0.000us | 0.00% | 0.000us | 0.000us | 1 |
| cudaLaunchKernel | 1.41% | 29.713us | 1.41% | 29.713us | 14.857us | 0.000us | 0.00% | 0.000us | 0.000us | 2 |
| cudaDeviceSynchronize | 0.56% | 11.893us | 0.56% | 11.893us | 5.947us | 0.000us | 0.00% | 0.000us | 0.000us | 2 |

Self CPU time total: 2.110ms
Self CUDA time total: 13.473us

after

| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | # of Calls |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| void at::native::warpMergeSortKVInPlace<-2, -1, 128,...> | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 4.096us | 45.88% | 4.096us | 4.096us | 1 |
| Memcpy DtoD (Device -> Device) | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 3.615us | 40.50% | 3.615us | 1.205us | 3 |
| void (anonymous namespace)::elementwise_kernel_with_... | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 1.216us | 13.62% | 1.216us | 1.216us | 1 |
| cudaMemcpyAsync | 3.39% | 72.069us | 96.96% | 2.062ms | 687.416us | 0.000us | 0.00% | 0.000us | 0.000us | 3 |
| Activity Buffer Request | 93.57% | 1.990ms | 93.57% | 1.990ms | 1.990ms | 0.000us | 0.00% | 0.000us | 0.000us | 1 |
| cudaLaunchKernel | 1.42% | 30.218us | 1.42% | 30.218us | 15.109us | 0.000us | 0.00% | 0.000us | 0.000us | 2 |
| cudaDeviceGetAttribute | 0.10% | 2.151us | 0.10% | 2.151us | 0.538us | 0.000us | 0.00% | 0.000us | 0.000us | 4 |
| cudaFuncGetAttributes | 0.64% | 13.692us | 0.64% | 13.692us | 13.692us | 0.000us | 0.00% | 0.000us | 0.000us | 1 |
| cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags | 0.30% | 6.370us | 0.30% | 6.370us | 0.398us | 0.000us | 0.00% | 0.000us | 0.000us | 16 |
| cudaDeviceSynchronize | 0.57% | 12.169us | 0.57% | 12.169us | 6.085us | 0.000us | 0.00% | 0.000us | 0.000us | 2 |

Self CPU time total: 2.127ms
Self CUDA time total: 8.927us
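The speedup figures can be recomputed from the two profiles. This short sketch uses only the timings quoted above; the variable names are illustrative. It separates the speedup of the sort kernel itself from the speedup of total CUDA time, which also includes the unchanged memcpy and elementwise kernels.

```python
# Timings in microseconds, copied from the before/after profiles.
before_kernel_us = 8.544   # radixSortKVInPlace
after_kernel_us = 4.096    # warpMergeSortKVInPlace
before_total_us = 13.473   # Self CUDA time total, before
after_total_us = 8.927     # Self CUDA time total, after

kernel_speedup = before_kernel_us / after_kernel_us
total_speedup = before_total_us / after_total_us

print(f"sort kernel speedup: {kernel_speedup:.2f}x")  # prints 2.09x
print(f"total CUDA speedup:  {total_speedup:.2f}x")   # prints 1.51x
```

The roughly 2x figure in the PR description refers to the sort kernel. End to end, the CUDA time for this case drops by about 1.5x because the copies and the elementwise kernel are unaffected.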

Changed files

  • aten/src/ATen/native/cuda/SortUtils.cuh (modified, +1/-1)


🐛 Describe the bug

In sortKeyValueInplace in ATen/native/cuda/Sort.cu, the WarpMergeSort branch is not taken. My CUDA_VERSION is 12020 (CUDA 12.2). I found that the macro HAS_WARP_MERGE_SORT() checks CUDA_VERSION >= 110600, so no matter which CUDA version I have, HAS_WARP_MERGE_SORT() always evaluates to 0. My torch version is 2.7.1.

void sortKeyValueInplace(
    const TensorBase& key,
    const TensorBase& value,
    int64_t dim,
    bool descending,
    bool stable) {
  const auto sort_size = key.size(dim);
  if (sort_size <= 1) {
    return; // Already sorted
  } else if (!stable && sort_size <= 32) {
    // NOTE: Bitonic sort is unstable
    sortCommon(SmallBitonicSort{}, key, value, dim, descending);
#if HAS_WARP_MERGE_SORT()
  } else if (sort_size <= 128) {
    // Debug print: the warp merge sort branch was taken
    printf("HAS_WARP_MERGE_SORT\n");
    sortCommon(WarpMergeSort<128, C10_WARP_SIZE>{}, key, value, dim, descending);
#endif
  } else {
    // Debug prints: show why the warp merge sort branch was skipped
    printf("HAS_WARP_MERGE_SORT():%d\n", HAS_WARP_MERGE_SORT());
    printf("CUDA_VERSION:%d\n", CUDA_VERSION);
    sortCommon(MediumRadixSort{}, key, value, dim, descending);
  }
}

import torch

device = "cuda"
# 5 x 8 = 40 elements once flattened, i.e. between 32 and 128
data = torch.tensor([[0,6,7,8,9,10,11,1],[0,6,7,8,9,10,11,1],[0,6,7,8,9,10,11,1],[0,6,7,8,9,10,11,1],[0,6,7,8,9,10,11,1]],dtype=torch.int64,device=device)
data.view(-1).argsort()

result: HAS_WARP_MERGE_SORT:0 CUDA_VERSION 12020
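The branch selection in sortKeyValueInplace can be modeled in plain Python to see why a 40-element unstable sort lands in the radix sort path. This is a sketch of the control flow shown above, not PyTorch code: `select_sort_path` is a hypothetical name, and the `has_warp_merge_sort` flag stands in for the compile-time HAS_WARP_MERGE_SORT() macro.

```python
def select_sort_path(sort_size: int, stable: bool, has_warp_merge_sort: bool) -> str:
    """Mirror the if/else chain of sortKeyValueInplace (illustrative model)."""
    if sort_size <= 1:
        return "none"  # already sorted
    if not stable and sort_size <= 32:
        return "SmallBitonicSort"  # bitonic sort is unstable
    # This branch only exists when the macro evaluates to 1 at compile time.
    if has_warp_merge_sort and sort_size <= 128:
        return "WarpMergeSort"
    return "MediumRadixSort"


# The reproducer flattens a 5 x 8 tensor, so sort_size is 40.
print(select_sort_path(40, stable=False, has_warp_merge_sort=False))
print(select_sort_path(40, stable=False, has_warp_merge_sort=True))
```

With the flag false, every size above 32 falls through to MediumRadixSort, which matches the reported output. With the flag true, sizes up to 128 take the WarpMergeSort branch.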

Versions

To see the printed output, add a printf around line 383 of aten/src/ATen/native/cuda/Sort.cu, recompile, and run argsort.
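As an alternative to recompiling with printf, the kernel names in a profiler trace also reveal which branch ran. The sketch below is an assumption-laden illustration: `sort_path_from_kernels` and `profile_argsort_kernels` are hypothetical helpers, the two kernel-name substrings are taken from the profiler tables on this page, and the profiling function returns None when torch or a CUDA device is unavailable.

```python
def sort_path_from_kernels(kernel_names):
    """Classify the sort branch from profiler kernel names (hypothetical helper)."""
    for name in kernel_names:
        if "warpMergeSortKVInPlace" in name:
            return "WarpMergeSort"
        if "radixSortKVInPlace" in name:
            return "MediumRadixSort"
    return None  # no recognized sort kernel in the trace


def profile_argsort_kernels(num_elements=40):
    """Profile argsort on CUDA and return event names, or None if unavailable."""
    try:
        import torch
        from torch.profiler import ProfilerActivity, profile
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    # Reversed range, so the input is not already sorted.
    x = torch.arange(num_elements, dtype=torch.int64, device="cuda").flip(0)
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        x.argsort()
        torch.cuda.synchronize()
    return [event.key for event in prof.key_averages()]


# Classifying the kernel names quoted in the before/after tables:
print(sort_path_from_kernels(["void at::native::radixSortKVInPlace<-2, -1, 32, 4, ...>"]))
print(sort_path_from_kernels(["void at::native::warpMergeSortKVInPlace<-2, -1, 128,...>"]))
```

On a machine with a CUDA build of PyTorch, `sort_path_from_kernels(profile_argsort_kernels() or [])` should report which path a 40-element argsort took, assuming the kernel names match those in the tables.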
