pytorch - ✅(Solved) Fix Memory leak: DTensorSpec objects from `_try_replicate_spec_for_scalar_tensor` accumulate indefinitely with FSDP2 [1 pull requests, 3 comments, 2 participants]

GitHub stats
pytorch/pytorch#181761 (fetched 2026-04-29 06:11:06)
Comments: 3 · Participants: 2 · Timeline events: 39 · Reactions: 0
Timeline (top): mentioned ×15, subscribed ×15, commented ×3, referenced ×3

Fix Action

PR fix notes

PR #181792: [DTensor] Fix DTensorSpec refcount leak in OpSchema._recompute_comparison_key

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #181792

Fixes #181761.

DTensor_OpSchema_recompute_comparison_key_impl (the no-static_kwargkey branch) was passing args_to_hash_tup.release().ptr() into PyTuple_Pack. That's wrong: .release().ptr() is the idiom for steal-semantics APIs like PyTuple_SET_ITEM. PyTuple_Pack INCREFs each argument instead, so pairing it with .release() strands one reference per call — the wrapper gives up its decref, PyTuple_Pack adds an INCREF, net +1 forever. Every call leaked one ref to args_to_hash_tup, pinning every DTensorSpec in args_schema for the process lifetime.
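
The effect of a single stranded reference can be reproduced from pure Python. The following is a minimal sketch, not the PR's code: `Spec` is a hypothetical stand-in for DTensorSpec, `args_tup` plays the role of args_to_hash_tup, and `ctypes.pythonapi.Py_IncRef` simulates the unmatched INCREF described above. It assumes CPython, where reference counts decide object lifetime.

```python
import ctypes
import gc
import weakref


class Spec:
    """Hypothetical stand-in for DTensorSpec (illustration only)."""


# CPython C-API entry points, reached through ctypes.
_incref = ctypes.pythonapi.Py_IncRef
_incref.argtypes = [ctypes.py_object]
_incref.restype = None
_decref = ctypes.pythonapi.Py_DecRef
_decref.argtypes = [ctypes.py_object]
_decref.restype = None

spec = Spec()
probe = weakref.ref(spec)   # observes whether spec is still alive
args_tup = (spec,)          # plays the role of args_to_hash_tup

_incref(args_tup)           # simulate the unmatched INCREF (net +1)
del spec, args_tup          # drop every Python-level name
gc.collect()
alive_after_gc = probe() is not None   # the orphaned tuple still pins spec

# The spec's only tuple referrer has no GC-visible owner.
obj = probe()
pinned = [r for r in gc.get_referrers(obj) if type(r) is tuple]
found_orphan_tuple = len(pinned) == 1

_decref(pinned[0])          # hand back the stranded reference
del pinned, obj
gc.collect()
alive_after_release = probe() is not None

print(alive_after_gc, found_orphan_tuple, alive_after_release)
```

In this sketch the object survives `gc.collect()` because the collector only breaks reference cycles; an extra reference count with no owner is indistinguishable from a legitimate external holder. That matches the issue's observation of specs sitting inside tuples with no GC-visible parent.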

Introduced in #161695. #178301 and #174879 only sidestepped the buggy path.

Authored with Claude.

Co-authored-by: Lorenzo Porzi

Changed files

  • test/distributed/tensor/test_op_schema.py (modified, +16/-0)
  • torch/csrc/autograd/python_variable.cpp (modified, +1/-1)


🐛 Describe the bug

When using FSDP2 (fully_shard), DTensorSpec objects created by OpDispatcher._try_replicate_spec_for_scalar_tensor accumulate indefinitely during training and are never freed. The function creates a new DTensorSpec on every call, and these objects are retained by an unknown mechanism and not released by gc.collect().

The leak scales linearly with both model size (number of parameters) and training steps, and has caused notable, accumulating CPU memory consumption for us when training large models over a large number of steps.

Affected versions: PyTorch 2.10, 2.11
Not affected: PyTorch 2.8

Reproducer

Minimal self-contained script — requires 1 GPU:

"""torchrun --nproc-per-node=1 repro_dtensor_leak.py"""
import gc, tracemalloc
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

tracemalloc.start()

def main():
    dist.init_process_group(backend="nccl")

    model = nn.Sequential(*[nn.Linear(256, 256) for _ in range(40)]).cuda()
    for layer in model:
        fully_shard(layer)
    fully_shard(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, fused=True)

    for step in range(100):
        x = torch.randn(4, 256, device="cuda")
        loss = model(x).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step == 10:
            gc.collect()
            snap1 = tracemalloc.take_snapshot()
        if step == 60:
            gc.collect()
            snap2 = tracemalloc.take_snapshot()
            print("Memory growth (step 10 → 60):")
            for s in snap2.compare_to(snap1, "lineno")[:6]:
                print(s)
            from torch.distributed.tensor._dtensor_spec import DTensorSpec
            print(f"Live DTensorSpec objects: {sum(1 for o in gc.get_objects() if type(o) is DTensorSpec)}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Output (PyTorch 2.10):

Memory growth (step 10 → 60):
torch/distributed/tensor/_dispatch.py:639: size=648 KiB (+531 KiB), count=9763 (+8000), average=68 B
torch/distributed/tensor/_dispatch.py:637: size=496 KiB (+406 KiB), count=9760 (+8000), average=52 B
<string>:1: size=346 KiB (+281 KiB), count=4904 (+4000), average=72 B
torch/distributed/tensor/_dispatch.py:641: size=267 KiB (+219 KiB), count=4880 (+4000), average=56 B
torch/distributed/tensor/_op_schema.py:405: size=208 KiB (+170 KiB), count=370 (+300), average=576 B
torch/distributed/tensor/_dtensor_spec.py:398: size=153 KiB (+125 KiB), count=4886 (+4000), average=32 B
Live DTensorSpec objects: 4886

DTensorSpec count grows by ~80/step and is never reclaimed.
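
The per-step rates follow directly from the tracemalloc output above. As a worked check, using only the printed figures (the variable names are illustrative):

```python
# Figures copied from the tracemalloc output (snapshots at steps 10 and 60).
steps = 60 - 10

spec_delta = 4000        # "+4000" on the _dtensor_spec.py:398 line
blocks_at_60 = 4886      # "count=4886" on the same line

specs_per_step = spec_delta / steps
blocks_at_10 = blocks_at_60 - spec_delta

# Growth in KiB for the six reported lines, in the order printed.
growth_kib = [531, 406, 281, 219, 170, 125]
total_kib_per_step = sum(growth_kib) / steps

# The _op_schema.py:405 line taken on its own.
op_schema_kib_per_step = 170 / steps

print(specs_per_step, blocks_at_10)
print(round(total_kib_per_step, 2), op_schema_kib_per_step)
```

This covers only the six lines shown by the reproducer, so the total is a lower bound on the traced Python-side growth for this 40-layer toy model.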

Potential root cause and workaround

Note that the following has been derived by an LLM-driven analysis and might be incorrect.

Root cause

_try_replicate_spec_for_scalar_tensor in torch/distributed/tensor/_dispatch.py:635-645 creates a new DTensorSpec object on every invocation. With FSDP2, many ops involve mixing plain tensors with DTensors, which triggers this function repeatedly — often with identical arguments, producing equivalent but distinct DTensorSpec objects.

These objects accumulate over training steps and are never freed. Key observations:

  • gc.collect() does not reclaim them (not cyclic references)
  • Clearing both the Python LRU cache and C++ FastMap sharding propagation caches does not free them either
  • Only 7 OpSchema objects are ever live, so the retention is not from OpSchema accumulation
  • The exact retention mechanism is unclear — gc.get_referrers() shows the specs inside tuples with no GC-visible parent, suggesting they may be held by C++ objects not tracked by Python's GC

Workaround

Monkey-patching _try_replicate_spec_for_scalar_tensor to cache its return value by (shape, stride, dtype, mesh) eliminates most of the observable memory growth. This works because the leaked references all point to the same shared object instead of accumulating distinct copies:

from torch.distributed.tensor._dispatch import OpDispatcher

_orig = OpDispatcher._try_replicate_spec_for_scalar_tensor
_cache = {}
def _cached(self, op_call, tensor_arg, compute_mesh):
    key = (tensor_arg.shape, tensor_arg.stride(), tensor_arg.dtype, id(compute_mesh))
    if key not in _cache:
        _cache[key] = _orig(self, op_call, tensor_arg, compute_mesh)
    return _cache[key]
OpDispatcher._try_replicate_spec_for_scalar_tensor = _cached

Note that this masks the leak rather than fixing it — the underlying references are still not released, but since they all point to the same object, memory consumption stays bounded.
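
Why caching bounds the growth can be illustrated with a toy model. This is a sketch under stated assumptions, not DTensor code: `Spec`, `make_spec`, and the `leaked` list are hypothetical, with `leaked` standing in for whatever holds the references that are never released.

```python
class Spec:
    """Hypothetical stand-in for DTensorSpec."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype


leaked = []  # models references that are never released


def make_spec(shape, dtype):
    """Return a fresh object on every call, like the unpatched function."""
    return Spec(shape, dtype)


_cache = {}


def make_spec_cached(shape, dtype):
    """Return one shared object per distinct key, like the monkey patch."""
    key = (shape, dtype)
    if key not in _cache:
        _cache[key] = make_spec(shape, dtype)
    return _cache[key]


def run(factory, steps):
    """Leak one reference per call; report (references, distinct objects)."""
    leaked.clear()
    for _ in range(steps):
        leaked.append(factory((), "float32"))
    return len(leaked), len({id(s) for s in leaked})


uncached = run(make_spec, 1000)
cached = run(make_spec_cached, 1000)
print(uncached, cached)
```

The number of leaked references is the same in both runs; only the number of distinct objects they keep alive changes. One caveat about the workaround's key: `id(compute_mesh)` is only unique while the mesh object is alive, since CPython may reuse an id after an object is freed, so the patch implicitly assumes the mesh outlives the cache.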

A secondary, smaller leak also persists at _op_schema.py:405 (_DTensor_OpSchema_post_init), which accumulates ~3.4 KiB/step of objects that are also not freed. This may be related or a separate issue.

Versions

PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: CentOS Stream 9 (x86_64)
GCC version: (GCC) 11.5.0 20240719 (Red Hat 11.5.0-14)
Clang version: Could not collect
CMake version: version 4.3.2
Libc version: glibc-2.34

Python version: 3.12.13 | packaged by conda-forge | (main, Mar 5 2026, 16:50:00) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.16.1-0_fbk2_0_gf40efc324cc8-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA PG509-210
Nvidia driver version: 580.126.09
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 22
On-line CPU(s) list: 0-21
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8339HC CPU @ 1.80GHz
CPU family: 6
Model: 85
Thread(s) per core: 1
Core(s) per socket: 22
Socket(s): 1
Stepping: 11
BogoMIPS: 3591.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 arat vnmi umip pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 704 KiB (22 instances)
L1i cache: 704 KiB (22 instances)
L2 cache: 88 MiB (22 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-21
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Ghostwrite: Not affected
Vulnerability Indirect target selection: Vulnerable
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable
Vulnerability Old microcode: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Vulnerable; BHI: Vulnerable
Vulnerability Srbds: Not affected
Vulnerability Tsa: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] flash_attn==2.8.3+cu128torch2.10
[pip3] numpy==2.4.3
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] open_clip_torch==3.3.0
[pip3] torch==2.10.0+cu128
[pip3] torch_fidelity==0.4.0
[pip3] torchaudio==2.11.0+cu128
[pip3] torchmetrics==1.9.0
[pip3] torchvision==0.25.0+cu128
[pip3] triton==3.6.0
[pip3] vmaf-torch==1.1.0
[conda] cuda-cudart 12.8.90 0 nvidia
[conda] cuda-cudart-dev 12.8.90 0 nvidia
[conda] cuda-cudart-dev_linux-64 12.8.90 0 nvidia
[conda] cuda-cudart-static 12.8.90 0 nvidia
[conda] cuda-cudart-static_linux-64 12.8.90 0 nvidia
[conda] cuda-cudart_linux-64 12.8.90 0 nvidia
[conda] cuda-cupti 12.8.90 0 nvidia
[conda] cuda-cupti-dev 12.8.90 0 nvidia
[conda] cuda-libraries 12.8.2 0 nvidia
[conda] cuda-libraries-dev 12.8.2 0 nvidia
[conda] cuda-nvrtc 12.8.93 0 nvidia
[conda] cuda-nvrtc-dev 12.8.93 0 nvidia
[conda] cuda-nvtx 12.8.90 0 nvidia
[conda] cuda-opencl 12.8.90 0 nvidia
[conda] cuda-opencl-dev 12.8.90 0 nvidia
[conda] flash-attn 2.8.3+cu128torch2.10 pypi_0 pypi
[conda] libcublas 12.8.5.5 h2bcd275_0 nvidia
[conda] libcublas-dev 12.8.5.5 h2bcd275_0 nvidia
[conda] libcufft 11.3.3.83 0 nvidia
[conda] libcufft-dev 11.3.3.83 0 nvidia
[conda] libcurand 10.3.9.90 0 nvidia
[conda] libcurand-dev 10.3.9.90 0 nvidia
[conda] libcusolver 11.7.3.90 0 nvidia
[conda] libcusolver-dev 11.7.3.90 0 nvidia
[conda] libcusparse 12.5.8.93 0 nvidia
[conda] libcusparse-dev 12.5.8.93 0 nvidia
[conda] libnvjitlink 12.8.93 1 nvidia
[conda] libnvjitlink-dev 12.8.93 1 nvidia
[conda] libopenvino-pytorch-frontend 2026.0.0 hecca717_1 conda-forge
[conda] numpy 2.4.3 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] open-clip-torch 3.3.0 pypi_0 pypi
[conda] tbb 2022.3.0 hb700be7_2 conda-forge
[conda] torch 2.10.0+cu128 pypi_0 pypi
[conda] torch-fidelity 0.4.0 pypi_0 pypi
[conda] torchaudio 2.11.0+cu128 pypi_0 pypi
[conda] torchmetrics 1.9.0 pypi_0 pypi
[conda] torchvision 0.25.0+cu128 pypi_0 pypi
[conda] triton 3.6.0 pypi_0 pypi
[conda] vmaf-torch 1.1.0 pypi_0 pypi

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @weifengpy

extent analysis

TL;DR

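PR #181792 addresses the root cause, a reference-count leak in OpSchema._recompute_comparison_key. On affected releases that lack that fix, monkey-patching _try_replicate_spec_for_scalar_tensor to cache its return value bounds the memory growth caused by accumulating DTensorSpec objects.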

Guidance

  • Identify the source of the memory leak by analyzing the provided reproducer script and the output of tracemalloc.
  • Apply the suggested monkey patch to cache the return value of _try_replicate_spec_for_scalar_tensor by (shape, stride, dtype, mesh).
  • Verify the effectiveness of the patch by monitoring memory usage during training.
  • Note that this workaround masks the leak rather than fixing it, and a secondary, smaller leak may still persist.
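
For the verification step, one option is to compare traced memory before and after a fixed number of steps, with and without the patch. The helper below is a minimal sketch: `traced_growth`, `leaky_step`, and `clean_step` are hypothetical names, and the two step functions are toy workloads standing in for a real training step. It only sees allocations that tracemalloc tracks, so native allocations are not covered.

```python
import gc
import tracemalloc


def traced_growth(step_fn, warmup=10, steps=50):
    """Return the growth in traced bytes over `steps` calls after `warmup` calls."""
    started_here = not tracemalloc.is_tracing()
    if started_here:
        tracemalloc.start()
    try:
        for _ in range(warmup):
            step_fn()
        gc.collect()
        before, _ = tracemalloc.get_traced_memory()
        for _ in range(steps):
            step_fn()
        gc.collect()
        after, _ = tracemalloc.get_traced_memory()
        return after - before
    finally:
        if started_here:
            tracemalloc.stop()


_held = []


def leaky_step():
    """Toy workload that retains 1 KiB per call."""
    _held.append(bytearray(1024))


def clean_step():
    """Toy workload whose allocation is freed immediately."""
    bytearray(1024)


leaky = traced_growth(leaky_step)
clean = traced_growth(clean_step)
print(leaky, clean)
```

In a real run, `step_fn` would wrap one forward, backward, and optimizer step; growth that stays roughly flat as `steps` increases suggests the patch is holding.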

Notes

The provided workaround may not completely eliminate the memory leak, as a secondary leak persists at _op_schema.py:405. Further investigation is needed to address this issue.

Recommendation

Apply the workaround by monkey-patching _try_replicate_spec_for_scalar_tensor to cache its return value, as this will prevent the majority of the memory growth due to accumulating DTensorSpec objects.
