pytorch - `torch.compiler.is_compiling` doesn't work inside wrapper subclasses

🐛 Describe the bug

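As the title says: inside the `__torch_dispatch__` of a wrapper subclass, `torch.compiler.is_compiling()` returns `False` even while `torch.compile` is tracing. A minimal repro is below.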

Repro:

```python
from __future__ import annotations
from typing import Any

import torch

class SubTensor(torch.Tensor):
    _data: torch.Tensor

    def __new__(cls, data: torch.Tensor) -> SubTensor:
        ret = torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            device=data.device,
            storage_offset=data.storage_offset(),
            strides=data.stride(),
            pin_memory=data.is_pinned(),
            layout=data.layout,
            requires_grad=data.requires_grad,
        )
        ret._data = data.detach()
        return ret

    def __torch_dispatch__(self, func, types, args, kwargs):
        print(torch.compiler.is_compiling())

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: dict[str, torch.Tensor],
        meta: Any,
        outer_size: tuple[int, ...],
        outer_stride: tuple[int, ...],
    ) -> SubTensor:
        if meta is not None:
            raise AssertionError(f"meta must be None, got {meta}")
        data = inner_tensors["_data"]
        return SubTensor(data)

    def __tensor_flatten__(self) -> tuple[list[str], Any]:
        return ["_data"], None

def f():
    a = SubTensor(torch.zeros(()))
    return torch.add(a, a)

f()
torch.compile(f)()
```

Output:

```
False
False
False
False
```

Expected

The first and last `False` make sense: the first comes from the uncompiled (eager) call, and the last from running the function after compilation. The middle two, however, are printed while compiling/tracing, so they should be `True`.

Versions

Collecting environment information...
PyTorch version: 2.13.0a0+gitfe6a386
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.5 (arm64)
GCC version: Could not collect
Clang version: 19.1.7
CMake version: version 4.3.2
Libc version: N/A

Python version: 3.10.20 | packaged by conda-forge | (main, Mar 5 2026, 17:06:34) [Clang 19.1.7 ] (64-bit runtime)
Python platform: macOS-26.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: False
Caching allocator config: N/A

CPU: Apple M3 Max

Versions of relevant libraries:
[pip3] flake8==7.3.0
[pip3] numpy==2.2.6
[pip3] onnx==1.21.0
[pip3] onnx-ir==0.1.12
[pip3] optree==0.19.1
[pip3] torch==2.13.0a0+gitfe6a386
[conda] Could not collect

cc @Chillee @ezyang @albanD @samdow @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @amjames @jataylo @azahed98
