pytorch - 💡(How to fix) Fix torch.topk backward failed on DTensor input [2 comments, 3 participants]

GitHub stats
pytorch/pytorch#178582 (fetched 2026-04-08 01:40:41)
Comments: 2, Participants: 3, Timeline events: 38, Reactions: 0

Error Message

[rank0]: RuntimeError: aten.scatter.src: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

Raised from loss.backward() inside DTensor._op_dispatcher.dispatch; the full traceback is in the bug report below.


🐛 Describe the bug

The backward pass of torch.topk fails when the input is a DTensor.

sample code:

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Shard, init_device_mesh

# torchrun --nproc-per-node=2 topk_dtensor.py
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")

mesh = init_device_mesh("cuda", (world_size,))

x = torch.randn(32, 1000, 4, device=device, requires_grad=True)

x = DTensor.from_local(x, mesh, placements=[Shard(0)])

def step():
    topk, topk_idx = torch.topk(x, 2, dim=-1, sorted=True)
    loss = topk.sum()
    loss.backward()

for _ in range(10):
    step()

error:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/data01/wangxudong/workspace/playground/topk_dtensor.py", line 27, in <module>
[rank0]:     step()
[rank0]:   File "/data01/wangxudong/workspace/playground/topk_dtensor.py", line 24, in step
[rank0]:     loss.backward()
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/_tensor.py", line 648, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/autograd/__init__.py", line 353, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/_compile.py", line 51, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 344, in __torch_dispatch__
[rank0]:     return DTensor._op_dispatcher.dispatch(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 180, in dispatch
[rank0]:     op_info = self.unwrap_to_op_info(op_call, args, kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 379, in unwrap_to_op_info
[rank0]:     self._try_replicate_spec_for_scalar_tensor(
[rank0]:   File "/usr/local/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 481, in _try_replicate_spec_for_scalar_tensor
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: aten.scatter.src: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

Versions

Collecting environment information...
PyTorch version: 2.7.0a0+gitbf70a34.aml
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 10 (buster) (x86_64)
GCC version: (Debian 8.3.0-6) 8.3.0
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: glibc-2.28

Python version: 3.11.10 (main, Mar 10 2026, 22:51:26) [GCC 8.3.0] (64-bit runtime)
Python platform: Linux-5.10.135.bsk.6-amd64-x86_64-with-glibc2.28
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H20
GPU 1: NVIDIA H20
GPU 2: NVIDIA H20
GPU 3: NVIDIA H20
GPU 4: NVIDIA H20
GPU 5: NVIDIA H20
GPU 6: NVIDIA H20
GPU 7: NVIDIA H20

Nvidia driver version: 535.161.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.4.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.4.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 52 bits physical, 57 bits virtual
CPU(s): 192
On-line CPU(s) list: 0-191
Thread(s) per core: 2
Core(s) per socket: 48
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 143
Model name: Intel(R) Xeon(R) Platinum 8457C
Stepping: 8
CPU MHz: 3100.000
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
Virtualization: VT-x
L1d cache: 48K
L1i cache: 32K
L2 cache: 2048K
L3 cache: 99840K
NUMA node0 CPU(s): 0-47,96-143
NUMA node1 CPU(s): 48-95,144-191
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] optree==0.19.0
[pip3] torch==2.7.0a0+gitbf70a34
[pip3] triton==3.3.0
[conda] Could not collect

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx

extent analysis

Fix Plan

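The forward torch.topk call on the DTensor succeeds; the failure happens in loss.backward(). According to the traceback, the backward pass dispatches aten.scatter.src with a mix of plain torch.Tensor and DTensor arguments, and the DTensor dispatcher rejects that combination. A workaround is to run topk on the local shard instead, so that both the forward and the backward of topk operate on plain tensors:

  • Convert the DTensor to a local tensor with to_local().
  • Apply torch.topk to the local tensor.
  • Compute the loss from the local result and call backward().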

Code Changes

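import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Shard, init_device_mesh

# torchrun --nproc-per-node=2 topk_dtensor.py
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")

mesh = init_device_mesh("cuda", (world_size,))

x = torch.randn(32, 1000, 4, device=device, requires_grad=True)
x = DTensor.from_local(x, mesh, placements=[Shard(0)])

def step():
    # Workaround: run topk on this rank's shard as a plain torch.Tensor.
    # This is valid here because x is sharded on dim 0 and topk runs over dim -1.
    x_local = x.to_local()
    topk, topk_idx = torch.topk(x_local, 2, dim=-1, sorted=True)

    # Sum over this rank's shard only (not a global sum across ranks).
    loss = topk.sum()
    loss.backward()

for _ in range(10):
    step()

dist.destroy_process_group()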

Verification

To verify the workaround, run the modified script under torchrun and confirm that all 10 iterations complete without the RuntimeError. To check the values, print or inspect x_local, topk, and topk_idx on each rank: with the shapes used here, x_local is (32, 1000, 4) and both topk and topk_idx are (32, 1000, 2).
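The value check can also be automated. The sketch below runs in a single process on CPU and uses a small random tensor as a stand-in for one rank's local shard; the names expected_topk_sum_grad and x_leaf are illustrative and do not come from the issue. It relies on the fact that the gradient of topk(...).sum() is 1 at each selected position and 0 everywhere else.

```python
import torch


def expected_topk_sum_grad(x: torch.Tensor, k: int, dim: int = -1) -> torch.Tensor:
    """Gradient of topk(x, k, dim).values.sum() w.r.t. x: 1 where selected, else 0."""
    _, idx = torch.topk(x, k, dim=dim)
    return torch.zeros_like(x).scatter(dim, idx, 1.0)


# Stand-in for one rank's local shard (smaller than the shard in the report).
x_leaf = torch.randn(8, 10, 4, requires_grad=True)
topk, topk_idx = torch.topk(x_leaf, 2, dim=-1, sorted=True)
topk.sum().backward()

# Compare the autograd result against the hand-built expectation.
matches = torch.equal(x_leaf.grad, expected_topk_sum_grad(x_leaf.detach(), 2))
print("gradient matches expectation:", matches)
```

To apply the same check to the distributed script, keep a reference to the local leaf tensor before wrapping it with DTensor.from_local, then compare its .grad on each rank after backward().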

Extra Tips

  • to_local() returns only the current rank's shard. Running torch.topk on it matches the global result here because x is sharded on dim 0 while topk runs over dim -1; if the topk dimension itself were sharded, a local topk would not be equivalent.
  • The same pattern applies to other operations that fail on DTensor: convert to a local tensor, apply the operation, and, if later code expects a DTensor, wrap the result again with DTensor.from_local using the same mesh and placements.
