pytorch - ✅(Solved) Fix # [DTensor] DTensor subclass has silently broken `__torch_dispatch__` due to c++ fast path. [2 pull requests, 1 participant]

GitHub stats
pytorch/pytorch#177716 (fetched 2026-04-08 00:52:29)
Comments: 0 · Participants: 1 · Timeline events: 60 · Reactions: 2
Timeline (top): subscribed ×20, mentioned ×19, referenced ×10, labeled ×5


PR fix notes

PR #177741: [DTensor] Restore subclass torch_dispatch fallback

Description (problem / solution / changelog)

Fix #177716

Summary

  1. What is the root cause problem? The DTensor C++ fast path started treating DTensor subclasses the same as the exact DTensor type, so subclass __torch_dispatch__ overrides were skipped entirely. At the same time, the Python DTensor fallback and Python-side custom handlers hardcoded the base DTensor, so delegating with super().__torch_dispatch__() or a subclass _op_dispatcher could not preserve subclass behavior.

  2. What is the proposed fix? Keep the C++ fast path limited to the exact DTensor type, restore the Python DTensor.__torch_dispatch__ path for subclasses, thread the runtime subclass through Python rewrapping and custom handlers, and add a regression test that exercises a DTensor subclass delegating with super().__torch_dispatch__().

  3. Why is the proposed fix the right long-term fix? It preserves the C++ optimization for the common base-DTensor case while restoring the expected tensor-subclass dispatch contract for DTensor subclasses, without introducing a second, subclass-specific C++ dispatch path.

Drafted via @codex, published after manual review by @bobrenjc93

Changed files

  • test/distributed/tensor/test_dtensor.py (modified, +840/-0)
  • torch/csrc/autograd/python_variable.cpp (modified, +15/-9)
  • torch/csrc/utils/python_arg_parser.cpp (modified, +4/-2)
  • torch/distributed/tensor/_api.py (modified, +53/-11)
  • torch/distributed/tensor/_decompositions.py (modified, +1/-5)
  • torch/distributed/tensor/_dispatch.py (modified, +392/-40)
  • torch/distributed/tensor/_nonlinear_redux.py (modified, +26/-8)
  • torch/distributed/tensor/_sharding_prop.py (modified, +268/-24)
  • torch/distributed/tensor/_tp_conv.py (modified, +21/-7)
  • torch/distributed/tensor/debug/__init__.py (modified, +1/-3)
  • torch/distributed/tensor/experimental/_context_parallel/_attention.py (modified, +57/-21)
  • torch/distributed/tensor/experimental/_context_parallel/_sharding_rules.py (modified, +1/-3)
  • torch/distributed/tensor/parallel/loss.py (modified, +30/-10)

PR #177878: [DTensor] Fix DTensor subclass torch_dispatch bypass

Description (problem / solution / changelog)

The C++ fast path introduced in #167051 used an isinstance check to detect DTensor, which also matched DTensor subclasses. This silently routed subclass instances through the base DTensor C++ dispatcher, bypassing any custom __torch_dispatch__ on the subclass.

Use an exact type check so that only the base DTensor hits the C++ fast path. Subclasses fall through to the normal Python __torch_dispatch__ protocol.

Fixes #177716

Changed files

  • test/distributed/tensor/test_dtensor.py (modified, +67/-0)
  • torch/csrc/utils/python_arg_parser.cpp (modified, +16/-2)
  • torch/distributed/tensor/_api.py (modified, +20/-8)


🐛 Describe the bug

After the DTensor dispatch logic was moved from Python to C++ in PR #167051, subclasses of DTensor can no longer override __torch_dispatch__ (or _op_dispatcher). The C++ fast path unconditionally intercepts all DTensor subclass instances and dispatches them through the base DTensor's dispatcher, silently ignoring any overridden dispatch logic in the subclass.

This is a regression from the pre-#167051 behavior where __torch_dispatch__ on a DTensor subclass worked as expected through the normal Python dispatch protocol.

Reproduction

import os
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Replicate, init_device_mesh

# Single-process setup for demonstration
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
dist.init_process_group(backend="gloo")

mesh = init_device_mesh("cpu", (1,))


class MyDTensor(DTensor):
    """A DTensor subclass that tries to override dispatch behavior."""

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        print(f"[MyDTensor] dispatching: {func.__name__}")
        # Custom logic here (e.g., logging, modified sharding, etc.)
        return super().__torch_dispatch__(func, types, args, kwargs)


# Create a base DTensor and manually re-wrap as MyDTensor
base = DTensor.from_local(torch.randn(4, 4), mesh, [Replicate()])
my_dt = MyDTensor(base._local_tensor, base._spec, requires_grad=False)

print(f"type(my_dt) = {type(my_dt)}")       # <class '__main__.MyDTensor'>
print(f"isinstance(my_dt, DTensor) = {isinstance(my_dt, DTensor)}")  # True

# This op goes through C++ dispatchDTensorOp() directly.
# MyDTensor.__torch_dispatch__ is NEVER called.
result = my_dt + my_dt
print(f"type(result) = {type(result)}")      # <class 'torch.distributed.tensor.DTensor'> -- NOT MyDTensor!
# "[MyDTensor] dispatching: ..." is never printed.

dist.destroy_process_group()

Expected behavior: MyDTensor.__torch_dispatch__ is called for my_dt + my_dt, and the result is a MyDTensor.

Actual behavior: MyDTensor.__torch_dispatch__ is never called. The C++ fast path handles the op entirely, and the result is a base DTensor.

Root Cause

There are three layers of hardcoding that prevent DTensor subclasses from customizing dispatch:

1. is_dtensor() matches subclasses and routes them to the C++ fast path

In torch/csrc/utils/python_arg_parser.cpp (lines 302-310 on main):

static bool is_dtensor(PyObject* obj) {
#ifdef USE_DISTRIBUTED
  const py::handle dtensor = get_dtensor_class();
  return (PyObject*)Py_TYPE(obj) == dtensor.ptr() ||
      py::isinstance(py::handle(obj), dtensor);  // <-- matches subclasses too!
#else
  return false;
#endif
}

Then in dispatch_on_subclass() (lines 378-410 of the same file on main):

    if (!is_torch_function && is_dtensor(arg)) {
      // ... calls dispatchDTensorOp() directly
      ret = dispatchDTensorOp(
          *opt_op, torch_api_function, args, kwargs, opt_stack);
    } else {
      // Normal path -- calls __torch_dispatch__ via Python
      ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
          torch_function.ptr(),  // <-- subclass's __torch_dispatch__ would be called here
          torch_api_function, py_types.ptr(), args, kwargs, NULL));
    }

When is_dtensor(arg) returns true (which it does for any DTensor subclass), the code calls dispatchDTensorOp() directly, completely bypassing the normal PyObject_CallFunctionObjArgs(torch_function, ...) path that would invoke the subclass's __torch_dispatch__.
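The routing decision can be mimicked in plain Python to show why an isinstance-style check skips a subclass hook while an exact type check does not. This is a minimal sketch under stated assumptions: FakeDTensor, MySubclass, fast_path, and route are hypothetical stand-ins for DTensor, a DTensor subclass, dispatchDTensorOp(), and dispatch_on_subclass(); none of them are PyTorch APIs.

```python
# Hypothetical stand-ins for DTensor routing; not PyTorch code.
calls = []


class FakeDTensor:
    @classmethod
    def torch_dispatch(cls, op):
        # Plays the role of the base DTensor.__torch_dispatch__.
        calls.append(f"base dispatch({op})")
        return "base result"


class MySubclass(FakeDTensor):
    @classmethod
    def torch_dispatch(cls, op):
        # A subclass override that delegates to the base, like MyDTensor.
        calls.append(f"MySubclass hook({op})")
        return super().torch_dispatch(op)


def fast_path(op):
    # Plays the role of dispatchDTensorOp(): base-class logic only.
    calls.append(f"fast_path({op})")
    return "base result"


def route(obj, op, exact):
    # exact=False mirrors the isinstance() check in is_dtensor();
    # exact=True mirrors the Py_TYPE(obj) == DTensor comparison.
    if exact:
        hits_fast_path = type(obj) is FakeDTensor
    else:
        hits_fast_path = isinstance(obj, FakeDTensor)
    if hits_fast_path:
        return fast_path(op)
    # Normal path: call the hook defined on the runtime type.
    return type(obj).torch_dispatch(op)


route(MySubclass(), "add", exact=False)
print(calls)  # ['fast_path(add)']
calls.clear()
route(MySubclass(), "add", exact=True)
print(calls)  # ['MySubclass hook(add)', 'base dispatch(add)']
```

With the exact check, a base instance still takes the fast path, so the optimization is kept for the common case.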

2. dispatchDTensorOp() hardcodes the base DTensor._op_dispatcher

In torch/csrc/autograd/python_variable.cpp (lines 872-899 on main):

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_op_dispatcher,
    py::module::import("torch.distributed.tensor")
        .attr("DTensor")              // <-- hardcoded to base DTensor class
        .attr("_op_dispatcher"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_dispatch,
    py::module::import("torch.distributed.tensor")
        .attr("DTensor")              // <-- hardcoded
        .attr("_op_dispatcher")
        .attr("_dispatch_fast_path_python_tail"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_dispatcher_wrap,
    py::module::import("torch.distributed.tensor")
        .attr("DTensor")              // <-- hardcoded
        .attr("_op_dispatcher")
        .attr("wrap"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_get_local_results_slow_path,
    py::module::import("torch")
        .attr("distributed").attr("tensor")
        .attr("DTensor")              // <-- hardcoded
        .attr("_op_dispatcher")
        .attr("_dispatch_get_local_results_slow_path"))

All four cached import getters resolve to torch.distributed.tensor.DTensor._op_dispatcher or one of its attributes. Even if a subclass defines its own _op_dispatcher class attribute, it is never consulted. get_dtensor_op_dispatcher() is also called at the very top of dispatchDTensorOp() (line 1505 of python_variable.cpp on main):

py::object dispatchDTensorOp(...) {
  ...
  const auto op_dispatcher = get_dtensor_op_dispatcher();  // always base DTensor's
  ...
}
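A Python analogue of the cached getters shows why a subclass-level _op_dispatcher is invisible to this path. This is a sketch under assumptions: functools.cache stands in for DEFINE_CACHING_PYTHON_IMPORT_GETTER (assumed, from its name, to resolve the attribute once and reuse it), and FakeDispatcher, FakeDTensor, MySubclass, and dispatch_op are hypothetical names.

```python
import functools


class FakeDispatcher:
    def __init__(self, name):
        self.name = name


class FakeDTensor:
    _op_dispatcher = FakeDispatcher("base")


class MySubclass(FakeDTensor):
    # A subclass-level override that the cached getter never sees.
    _op_dispatcher = FakeDispatcher("custom")


@functools.cache
def get_dtensor_op_dispatcher():
    # Like the caching import getters: resolve FakeDTensor._op_dispatcher
    # once, then return the same object on every later call.
    return FakeDTensor._op_dispatcher


def dispatch_op(arg):
    # Like dispatchDTensorOp(): the runtime type of `arg` is never consulted.
    return get_dtensor_op_dispatcher().name


print(dispatch_op(MySubclass()))  # base
```

Because the lookup is anchored to the base class rather than to type(arg), overriding _op_dispatcher on the subclass has no effect.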

3. Output wrapping always constructs base DTensor, not the subclass

In torch/distributed/tensor/_dispatch.py (line 731 on main):

@staticmethod
def wrap(res, spec):
    if isinstance(res, torch.Tensor):
        if spec is not None:
            return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)  # <-- always base DTensor

And in __tensor_unflatten__ in torch/distributed/tensor/_api.py (line 384 on main):

@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
    ...
    return DTensor(local_tensor, unflatten_spec, requires_grad=requires_grad)  # <-- always base DTensor

So even if a subclass somehow made it through dispatch, the output would be downcast to DTensor.
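The difference between hardcoded rewrapping and subclass-preserving rewrapping fits in a few lines. This sketch only illustrates the idea that PR #177741 describes as "thread the runtime subclass through Python rewrapping"; it is not that PR's code. FakeSpec, FakeDTensor, MySubclass, wrap_hardcoded, and wrap_with_cls are hypothetical names.

```python
class FakeSpec:
    """Placeholder for a DTensor spec (hypothetical)."""


class FakeDTensor:
    def __init__(self, local, spec):
        self.local = local
        self.spec = spec


class MySubclass(FakeDTensor):
    pass


def wrap_hardcoded(res, spec):
    # Mirrors the current wrap() / __tensor_unflatten__: always the base class.
    return FakeDTensor(res, spec)


def wrap_with_cls(res, spec, out_cls):
    # Subclass-preserving variant: the caller passes the runtime class
    # of the input (for example type(args[0])) down to the rewrap step.
    return out_cls(res, spec)


inp = MySubclass([1.0, 2.0], FakeSpec())
out_a = wrap_hardcoded([2.0, 4.0], inp.spec)
out_b = wrap_with_cls([2.0, 4.0], inp.spec, type(inp))
print(type(out_a).__name__)  # FakeDTensor
print(type(out_b).__name__)  # MySubclass
```

A real implementation also needs a rule for which class wins when an op mixes base DTensors and subclass instances; this sketch assumes a single input and leaves that choice open.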

Suggested fix directions

Option A (minimal, preserve C++ fast path for base DTensor): In is_dtensor(), use an exact type check (Py_TYPE(obj) == dtensor.ptr()) instead of isinstance. DTensor subclasses would then fall through to the normal Python __torch_dispatch__ path. This preserves the C++ performance optimization for the base DTensor class while restoring correct subclass dispatch behavior.

static bool is_dtensor(PyObject* obj) {
#ifdef USE_DISTRIBUTED
  const py::handle dtensor = get_dtensor_class();
  return (PyObject*)Py_TYPE(obj) == dtensor.ptr();  // exact match only
#else
  return false;
#endif
}

Option B (subclass-aware C++ dispatch): Make dispatchDTensorOp resolve the _op_dispatcher from the actual runtime type of the DTensor argument rather than the hardcoded base class. This allows subclasses to participate in the C++ fast path with custom dispatchers, but requires more invasive changes.
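The core of Option B, resolving the dispatcher from the argument's runtime type instead of a cached base-class attribute, can be sketched in Python. The names below (FakeDispatcher, FakeDTensor, MySubclass, PlainSubclass, dispatch_op) are hypothetical, and the C++ side of such a change is not shown.

```python
class FakeDispatcher:
    def __init__(self, name):
        self.name = name

    def dispatch(self, op):
        return f"{self.name}:{op}"


class FakeDTensor:
    _op_dispatcher = FakeDispatcher("base")


class MySubclass(FakeDTensor):
    _op_dispatcher = FakeDispatcher("custom")


class PlainSubclass(FakeDTensor):
    pass  # no override, inherits the base dispatcher


def dispatch_op(op, arg):
    # Option B idea: look up _op_dispatcher on the argument's runtime type.
    # Ordinary attribute lookup walks the MRO, so subclasses without an
    # override still resolve to the base dispatcher.
    return type(arg)._op_dispatcher.dispatch(op)


print(dispatch_op("add", MySubclass()))     # custom:add
print(dispatch_op("add", PlainSubclass()))  # base:add
```

One plausible cost of this approach is a per-call attribute lookup in place of a single cached pointer, which is part of why it touches more code than Option A.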

Option C (opt-out protocol): Add a class-level flag (e.g., _use_cpp_dispatch = True) that subclasses can set to False to opt out of the C++ fast path and fall back to Python __torch_dispatch__.
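An opt-out flag of the kind Option C describes could gate the fast path as in this sketch. The flag name _use_cpp_dispatch is only the issue's example, and FakeDTensor, OptedOut, InheritsDefault, and takes_fast_path are hypothetical names; no such flag is known to exist in PyTorch.

```python
class FakeDTensor:
    _use_cpp_dispatch = True  # hypothetical flag name taken from Option C


class OptedOut(FakeDTensor):
    _use_cpp_dispatch = False  # falls back to Python __torch_dispatch__


class InheritsDefault(FakeDTensor):
    pass  # keeps the fast path by inheritance


def takes_fast_path(obj):
    # Only DTensor-family objects are candidates; the class-level flag,
    # looked up on the runtime type, decides whether they opt out.
    cls = type(obj)
    return issubclass(cls, FakeDTensor) and getattr(cls, "_use_cpp_dispatch", True)


print(takes_fast_path(OptedOut()))         # False
print(takes_fast_path(InheritsDefault()))  # True
```

In this sketch subclasses inherit the fast path by default, so an existing subclass that overrides __torch_dispatch__ stays bypassed until it sets the flag. Option A instead makes subclasses take the Python path by default.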

Impact

Any downstream library or user code that extends DTensor via subclassing -- for purposes such as custom sharding strategies, dispatch logging/tracing, modified op semantics, or integration with other tensor subclass frameworks -- is silently broken. The subclass's __torch_dispatch__ override is never called, and outputs are always base DTensor instances.

Versions

Collecting environment information...
PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: 13.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 12.2.0-14+deb12u1) 12.2.0
Clang version: 14.0.6
CMake version: version 3.31.6
Libc version: glibc-2.36

Python version: 3.14.2 (main, Jan 27 2026, 15:50:48) [GCC 12.2.0] (64-bit runtime)
Python platform: Linux-5.4.143.bsk.8-amd64-x86_64-with-glibc2.36
Is CUDA available: False
CUDA runtime version: 13.1.115
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: N/A
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.17.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.17.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.17.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.17.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.17.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.17.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.17.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.17.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 240
On-line CPU(s) list: 0-239
Vendor ID: GenuineIntel
Model name: INTEL(R) XEON(R) PLATINUM 8582C
CPU family: 6
Model: 207
Thread(s) per core: 2
Core(s) per socket: 60
Socket(s): 2
Stepping: 2
CPU(s) scaling MHz: 75%
CPU max MHz: 4000.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid cldemote movdiri movdir64b md_clear pconfig flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 5.6 MiB (120 instances)
L1i cache: 3.8 MiB (120 instances)
L2 cache: 240 MiB (120 instances)
L3 cache: 600 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-59,120-179
NUMA node1 CPU(s): 60-119,180-239
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

cc @ezyang @gchanan @kadeng @msaroufim @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @dcci @aditvenk @xmfan @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx

extent analysis

Fix Plan

To fix the issue, we can use Option A (minimal, preserve C++ fast path for base DTensor). We need to modify the is_dtensor() function to use an exact type check instead of isinstance.

  • Modify torch/csrc/utils/python_arg_parser.cpp to change the is_dtensor() function:
static bool is_dtensor(PyObject* obj) {
#ifdef USE_DISTRIBUTED
  const py::handle dtensor = get_dtensor_class();
  // Py_TYPE returns a PyTypeObject*, so cast it before comparing with the
  // PyObject* held by the handle. Exact match only: subclasses are excluded.
  return (PyObject*)Py_TYPE(obj) == dtensor.ptr();
#else
  return false;
#endif
}

This change will ensure that only instances of the base DTensor class are routed to the C++ fast path, while subclasses will fall back to the normal Python __torch_dispatch__ path.

Verification

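To verify the fix, run the reproduction code from the issue body. With the exact-type check in place, MyDTensor.__torch_dispatch__ should be invoked for my_dt + my_dt, so the output should include a line like:

[MyDTensor] dispatching: add.Tensor

(func is an OpOverload, and its __name__ includes the overload name, hence add.Tensor rather than add.)

Whether the script then prints type(result) = <class '__main__.MyDTensor'> depends on the output-wrapping layer (layer 3 of the root cause). If wrap() and __tensor_unflatten__ still construct the base DTensor, the result will be a base DTensor even though the subclass hook fired. PR #177741 addresses this by threading the runtime subclass through Python rewrapping.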

Extra Tips

  • When modifying the PyTorch codebase, make sure to follow the contribution guidelines and testing procedures to ensure that the changes do not introduce any regressions.
  • If you maintain a downstream library or user code that subclasses DTensor, check after upgrading that your __torch_dispatch__ override is actually invoked and whether outputs come back as your subclass or as a base DTensor.
  • Consider adding additional tests to cover the scenarios where DTensor subclasses are used, to ensure that the fix is working correctly in all cases.
