pytorch - ✅(Solved) Fix [FSDP2] GC-tracked objects grow linearly during training (tuple leak in DTensor dispatch) [1 pull requests, 5 comments, 3 participants]

GitHub stats
pytorch/pytorch#178276 (fetched 2026-04-08 01:20:44)
Comments: 5
Participants: 3
Timeline: 84
Reactions: 1
Timeline (top): referenced ×24, mentioned ×22, subscribed ×22, labeled ×6

Fix Action

Fix / Workaround

When training with FSDP2 (fully_shard), Python GC-tracked objects grow linearly throughout training. The growing objects are tuple instances containing DTensorSpec data from the DTensor dispatch infrastructure. This causes gc.collect() time to increase monotonically, eventually dominating iteration time in long training runs.
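A generic way to observe this symptom is to time every collection with the standard-library `gc.callbacks` hook. The sketch below is a monitoring helper and is not part of PyTorch or the fix. `GCPauseRecorder` and `leaky_cache` are hypothetical names, and the cache is a plain list standing in for a container that never evicts. Each cached tuple holds a list so that CPython keeps it GC-tracked, which is why full collections get slower as the cache grows.

```python
import gc
import time
from typing import Optional


class GCPauseRecorder:
    """Records the wall-clock duration of each collection via gc.callbacks."""

    def __init__(self) -> None:
        self.pauses: list = []  # (generation, seconds) per finished collection
        self._start: Optional[float] = None

    def _on_gc(self, phase: str, info: dict) -> None:
        if phase == "start":
            self._start = time.perf_counter()
        elif phase == "stop" and self._start is not None:
            self.pauses.append((info["generation"], time.perf_counter() - self._start))
            self._start = None

    def __enter__(self) -> "GCPauseRecorder":
        gc.callbacks.append(self._on_gc)
        return self

    def __exit__(self, *exc_info) -> bool:
        gc.callbacks.remove(self._on_gc)
        return False


# Hypothetical never-evicting cache: every simulated "step" adds tuples that
# each hold a list, so they stay GC-tracked and every full collection must
# traverse all of them.
leaky_cache: list = []

with GCPauseRecorder() as recorder:
    for step in range(5):
        leaky_cache.extend((step, i, []) for i in range(20_000))
        gc.collect()  # explicit full (generation 2) collection, as in the repro

full_collections = [secs for gen, secs in recorder.pauses if gen == 2]
for idx, secs in enumerate(full_collections[-5:]):
    print(f"full collection {idx}: {secs * 1e3:.2f} ms")
```

To instrument a training loop, wrap it in `GCPauseRecorder()` in the same way. On a leaking run the generation-2 pauses tend to grow from one measurement to the next, although absolute timings depend on the machine.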


PR fix notes

PR #178301: Fix unbounded DTensor sharding propagation cache growth

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #178301

Fixes https://github.com/pytorch/pytorch/issues/178276

The C++ DTensor dispatch fast path (NativeShardingPropagatorCache) hashes non-tensor items inside list arguments unconditionally into the cache key. For foreach optimizer ops like _foreach_div_.ScalarList and _foreach_addcdiv_.ScalarList, the ScalarList contains step-varying values (AdamW bias corrections change every training step). This causes a new cache entry on every step, leaking OpStrategy, OpSpec, and OutputSharding objects indefinitely.
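The snippet below shows why a key that includes these scalars never repeats. It computes Adam-style bias-correction terms for the first 200 steps, assuming AdamW's default betas of (0.9, 0.999) and the textbook formulas `1 - beta1**t` and `sqrt(1 - beta2**t)`. It illustrates the arithmetic only and does not reproduce the optimizer's exact code path.

```python
import math

# Assumed AdamW defaults: betas=(0.9, 0.999).
BETA1, BETA2 = 0.9, 0.999


def bias_corrections(step: int) -> tuple:
    """Adam-style bias corrections for a 1-based step count."""
    bc1 = 1 - BETA1 ** step
    bc2_sqrt = math.sqrt(1 - BETA2 ** step)
    return bc1, bc2_sqrt


values = [bias_corrections(t) for t in range(1, 201)]
# Every step yields new scalars, so a key that hashes them never hits.
print(len(set(values)))  # 200
```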

The fix: apply the same static_argnum filtering to non-tensor items inside lists as is already applied to top-level non-tensor arguments. Since foreach ops register with static_argnum=100 (default), their scalar list values are excluded from the cache key, and the cache stabilizes after warmup.
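The effect of the fix can be modeled with a toy cache in pure Python. Everything here is hypothetical: `Spec` stands in for `DTensorSpec`, `make_key` is not the real (C++) `NativeShardingPropagatorCache`, and the argument layout only loosely mimics `_foreach_div_.ScalarList`. The sketch assumes the rule described above: a non-tensor value joins the key only when its position is at or beyond `static_argnum`, and the fix extends that rule to scalars nested inside lists.

```python
import math
from collections.abc import Hashable
from dataclasses import dataclass


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for DTensorSpec: just a placement and a shape."""
    placement: str
    shape: tuple


DEFAULT_STATIC_ARGNUM = 100  # default value described in the PR


def make_key(op: str, args: tuple, static_argnum: int = DEFAULT_STATIC_ARGNUM,
             filter_list_scalars: bool = True) -> Hashable:
    """Toy sharding-propagation cache key.

    Specs always count. A non-tensor argument counts only if its position is
    >= static_argnum. With filter_list_scalars=False, scalars nested inside list
    arguments count unconditionally (the pre-fix behavior the PR describes).
    """
    key: list = [op]
    for pos, arg in enumerate(args):
        is_static = pos >= static_argnum
        if isinstance(arg, Spec):
            key.append(arg)
        elif isinstance(arg, list):
            key.append(tuple(
                item for item in arg
                if isinstance(item, Spec) or is_static or not filter_list_scalars
            ))
        elif is_static:
            key.append(arg)
    return tuple(key)


def cache_size(filter_list_scalars: bool, steps: int = 200, n_params: int = 4) -> int:
    """Simulate one foreach op per step and return the number of cache entries."""
    spec = Spec("Shard(0)", (1024, 1024))
    cache: dict = {}
    for step in range(1, steps + 1):
        bc2_sqrt = math.sqrt(1 - 0.999 ** step)  # changes every step
        # Loosely shaped like _foreach_div_.ScalarList(tensors, scalars).
        args = ([spec] * n_params, [bc2_sqrt] * n_params)
        key = make_key("_foreach_div_.ScalarList", args,
                       filter_list_scalars=filter_list_scalars)
        cache.setdefault(key, "cached OutputSharding placeholder")
    return len(cache)


print(cache_size(filter_list_scalars=False))  # 200: one new entry per step
print(cache_size(filter_list_scalars=True))   # 1: stable after the first step
```

With filtering off, the toy cache gains one entry per simulated step, which mirrors the per-step growth reported in the issue. With filtering on, it settles at a single entry after the first step.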

Authored with Claude.

Changed files

  • test/distributed/tensor/test_optimizers.py (modified, +61/-1)
  • torch/csrc/autograd/python_variable.cpp (modified, +34/-4)


🐛 Describe the bug

🐛 The bug

When training with FSDP2 (fully_shard), Python GC-tracked objects grow linearly throughout training. The growing objects are tuple instances containing DTensorSpec data from the DTensor dispatch infrastructure. This causes gc.collect() time to increase monotonically, eventually dominating iteration time in long training runs.

The leak reproduces with a simple FSDP2 training loop; it requires neither torch.compile nor prefetching.

Environment

  • PyTorch: 2.10
  • Python: 3.12
  • CUDA: any
  • OS: Linux

The same repro script does not leak on earlier PyTorch versions.

To Reproduce

Save as repro.py and run with torchrun --nproc_per_node=1 repro.py:

"""
Minimal reproduction script for FSDP2 DTensor GC leak.

Usage:
    torchrun --nproc_per_node=1 repro.py
"""

import gc
import os
from collections import Counter

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy


def setup():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))


def cleanup():
    dist.destroy_process_group()


class SimpleModel(nn.Module):
    def __init__(self, dim=1024, depth=8):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim)) for _ in range(depth)]
        )
        self.head = nn.Linear(dim, dim)

    def forward(self, x):
        for block in self.blocks:
            x = x + block(x)
        return self.head(x).sum()


def main():
    setup()
    rank = dist.get_rank()

    model = SimpleModel(dim=1024, depth=8).cuda()
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for block in model.blocks:
        fully_shard(block, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    gc.disable()
    gc.collect()

    prev_count = len(gc.get_objects())
    prev_type_counts = None

    for step in range(200):
        x = torch.randn(4, 1024, device="cuda")
        loss = model(x)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if (step + 1) % 20 == 0:
            gc.collect()
            objs = gc.get_objects()
            count = len(objs)
            type_counts = Counter(type(o).__qualname__ for o in objs)

            if rank == 0:
                print(f"\nStep {step + 1}: tracked objects = {count} (delta = +{count - prev_count})")
                if prev_type_counts is not None:
                    diff = {
                        k: type_counts[k] - prev_type_counts.get(k, 0)
                        for k in type_counts
                        if type_counts[k] - prev_type_counts.get(k, 0) > 0
                    }
                    top = sorted(diff.items(), key=lambda x: x[1], reverse=True)[:10]
                    print(f"  Top growing types: {top}")

            prev_count = count
            prev_type_counts = type_counts
            del objs

    cleanup()


if __name__ == "__main__":
    main()

Expected behavior

After warmup, tracked object count should stabilize since model architecture, tensor shapes, and sharding specs are fixed.

Actual behavior

Object count grows linearly, +420 per 20 steps (21/iteration), never stabilizing:

Step  20: tracked objects = 260939 (delta = +1270)
Step  40: tracked objects = 261408 (delta = +469)
Step  60: tracked objects = 261829 (delta = +421)
Step  80: tracked objects = 262249 (delta = +420)
Step 100: tracked objects = 262669 (delta = +420)
Step 120: tracked objects = 263089 (delta = +420)
Step 140: tracked objects = 263509 (delta = +420)
Step 160: tracked objects = 263929 (delta = +420)
Step 180: tracked objects = 264349 (delta = +420)
Step 200: tracked objects = 264769 (delta = +420)
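The per-iteration rate quoted above can be recomputed from the logged counts. The snippet below only re-derives numbers already in the log. The first interval or two may still include warmup effects.

```python
# Tracked-object counts copied from the log above (step -> count).
counts = {
    20: 260939, 40: 261408, 60: 261829, 80: 262249, 100: 262669,
    120: 263089, 140: 263509, 160: 263929, 180: 264349, 200: 264769,
}
steps = sorted(counts)
deltas = [counts[b] - counts[a] for a, b in zip(steps, steps[1:])]
print(deltas)  # [469, 421, 420, 420, 420, 420, 420, 420, 420]

# Steady-state growth per iteration: last interval divided by its step span.
per_step = deltas[-1] / (steps[-1] - steps[-2])
print(per_step)  # 21.0
```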

Type breakdown: ('tuple', 420) on every measurement. Tuple content: DTensorSpec objects.
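Once a type breakdown points at a growing type, a common next step is to find which objects keep the instances alive. The sketch below is a generic diagnostic built on the standard-library `gc.get_referrers`. The `cache` dict and `referrer_types` are hypothetical stand-ins and are not PyTorch code. Note that `gc.get_referrers` only reports referrers the collector can traverse, so references held from native code (such as the C++ cache named in the PR) would not appear. An empty result is consistent with that kind of holder.

```python
import gc
from collections import Counter


def referrer_types(objs: list, limit: int = 5) -> Counter:
    """Count the types of objects that refer to the first `limit` items of `objs`."""
    sample = objs[:limit]
    counts: Counter = Counter()
    for obj in sample:
        for ref in gc.get_referrers(obj):
            # Skip the lists this function itself uses to hold the objects.
            if ref is objs or ref is sample:
                continue
            counts[type(ref).__qualname__] += 1
    return counts


# Hypothetical cache that keeps tuples alive, standing in for a leak.
cache = {("op", step): ([step],) for step in range(100)}
leaked = list(cache.values())
print(referrer_types(leaked))
```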

Additional context

In large-scale training (ViT-g, 80+ FSDP-wrapped blocks, 256 GPUs), the leak scales to +1,745 tuples/iteration (about 87,250 per GC cycle, which corresponds to one collection every 50 iterations).

Versions

Collecting environment information...
PyTorch version: 2.10.0+cu129
Is debug build: False
CUDA used to build PyTorch: 12.9
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (conda-forge gcc 12.4.0-2) 12.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.12.12 | packaged by conda-forge | (main, Oct 13 2025, 14:34:15) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.8.12-680-6063-coreweave-amd64-f81899c8-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.5.40
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 580.95.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.8.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8462Y+
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 8
CPU max MHz: 4100.0000
CPU min MHz: 800.0000
BogoMIPS: 5600.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect user_shstk avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hfi vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 3 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 128 MiB (64 instances)
L3 cache: 120 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44,46,48,50,52,54,56,58,60,62,64,66,68,70,72,74,76,78,80,82,84,86,88,90,92,94,96,98,100,102,104,106,108,110,112,114,116,118,120,122,124,126
NUMA node1 CPU(s): 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39,41,43,45,47,49,51,53,55,57,59,61,63,65,67,69,71,73,75,77,79,81,83,85,87,89,91,93,95,97,99,101,103,105,107,109,111,113,115,117,119,121,123,125,127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] lovely-numpy==0.2.18
[pip3] mypy==1.19.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.2.0
[pip3] nvidia-cublas-cu12==12.9.1.4
[pip3] nvidia-cuda-cupti-cu12==12.9.79
[pip3] nvidia-cuda-nvrtc-cu12==12.9.86
[pip3] nvidia-cuda-runtime-cu12==12.9.79
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.4.1.4
[pip3] nvidia-curand-cu12==10.3.10.19
[pip3] nvidia-cusolver-cu12==11.7.5.82
[pip3] nvidia-cusparse-cu12==12.5.10.65
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.9.86
[pip3] nvidia-nvtx-cu12==12.9.79
[pip3] nvtx==0.2.14
[pip3] open_clip_torch==3.2.0
[pip3] optree==0.18.0
[pip3] pytorch-triton==3.5.1
[pip3] tbb==2022.3.1
[pip3] tcmlib==1.4.1
[pip3] torch==2.10.0+cu129
[pip3] torchaudio==2.10.0+cu129
[pip3] torchcodec==0.10.0+cu129
[pip3] torchmetrics==1.8.2
[pip3] torchvision==0.25.0+cu129
[pip3] triton==3.6.0
[conda] cuda-cudart 12.5.39 he02047a_0 conda-forge
[conda] cuda-cudart-dev 12.5.39 he02047a_0 conda-forge
[conda] cuda-cudart-dev_linux-64 12.5.39 h85509e4_0 conda-forge
[conda] cuda-cudart-static 12.5.39 he02047a_0 conda-forge
[conda] cuda-cudart-static_linux-64 12.5.39 h85509e4_0 conda-forge
[conda] cuda-cudart_linux-64 12.5.39 h85509e4_0 conda-forge
[conda] cuda-cupti 12.5.39 he02047a_0 conda-forge
[conda] cuda-cupti-dev 12.5.39 he02047a_0 conda-forge
[conda] cuda-libraries 12.5.0 ha770c72_0 conda-forge
[conda] cuda-libraries-dev 12.5.0 ha770c72_0 conda-forge
[conda] cuda-nvrtc 12.5.40 he02047a_0 conda-forge
[conda] cuda-nvrtc-dev 12.5.40 he02047a_0 conda-forge
[conda] cuda-nvtx 12.5.39 he02047a_0 conda-forge
[conda] cuda-opencl 12.5.39 he02047a_1 conda-forge
[conda] cuda-opencl-dev 12.5.39 he02047a_1 conda-forge
[conda] cuda-runtime 12.5.0 ha804496_0 conda-forge
[conda] libcublas 12.5.2.13 he02047a_0 conda-forge
[conda] libcublas-dev 12.5.2.13 he02047a_0 conda-forge
[conda] libcufft 11.2.3.18 he02047a_0 conda-forge
[conda] libcufft-dev 11.2.3.18 he02047a_0 conda-forge
[conda] libcurand 10.3.6.39 he02047a_0 conda-forge
[conda] libcurand-dev 10.3.6.39 he02047a_0 conda-forge
[conda] libcusolver 11.6.2.40 he02047a_0 conda-forge
[conda] libcusolver-dev 11.6.2.40 he02047a_0 conda-forge
[conda] libcusparse 12.4.1.24 he02047a_0 conda-forge
[conda] libcusparse-dev 12.4.1.24 he02047a_0 conda-forge
[conda] libnvjitlink 12.5.40 he02047a_0 conda-forge
[conda] libnvjitlink-dev 12.5.40 he02047a_0 conda-forge
[conda] libopenvino-pytorch-frontend 2024.1.0 he02047a_7 conda-forge
[conda] lovely-numpy 0.2.18 pypi_0 pypi
[conda] numpy 2.2.0 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.9.1.4 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.9.79 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.9.86 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.9.79 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.4.1.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.10.19 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.5.82 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.10.65 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.9.86 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.9.79 pypi_0 pypi
[conda] nvtx 0.2.14 pypi_0 pypi
[conda] open-clip-torch 3.2.0 pypi_0 pypi
[conda] optree 0.18.0 pypi_0 pypi
[conda] pytorch-triton 3.5.1 pypi_0 pypi
[conda] tbb 2022.3.1 pypi_0 pypi
[conda] tcmlib 1.4.1 pypi_0 pypi
[conda] torch 2.10.0+cu129 pypi_0 pypi
[conda] torchaudio 2.10.0+cu129 pypi_0 pypi
[conda] torchcodec 0.10.0+cu129 pypi_0 pypi
[conda] torchmetrics 1.8.2 pypi_0 pypi
[conda] torchvision 0.25.0+cu129 pypi_0 pypi
[conda] triton 3.6.0 pypi_0 pypi

cc @ezyang @gchanan @kadeng @msaroufim @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @dcci @aditvenk @xmfan @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx

extent analysis

Fix Plan

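To address the growing number of Python GC-tracked objects, we need to find what keeps the leaked tuples alive. The growth consists of tuples containing DTensorSpec objects, and PR #178301 traces them to the DTensor sharding propagation cache.

Here are the steps to fix the issue:

  • Identify the source of the leak: The C++ DTensor dispatch fast path (NativeShardingPropagatorCache) hashes non-tensor items inside list arguments into its cache key. Foreach optimizer ops such as _foreach_div_.ScalarList and _foreach_addcdiv_.ScalarList receive AdamW bias corrections that change every step. As a result, each step adds new cache entries (OpStrategy, OpSpec, and OutputSharding objects) that are never reused or evicted.
  • Fix the cache key, not the caller: The cache still references these objects, so deleting references in user code or wrapping the loop in a context manager cannot release them. PR #178301 applies the existing static_argnum filtering to non-tensor items inside lists. This excludes the scalar-list values of foreach ops (static_argnum=100 by default) from the key, and the cache stabilizes after warmup. In practice, the remedy is to run a PyTorch build that includes this change.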

Code Changes

The modified main function below cannot fix the leak from user code, but it can verify it. At each measurement it also counts tracked tuples that directly contain a DTensorSpec. Without the fix this number keeps rising; with the fix it should level off after warmup.

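# Drop-in replacement for main() in repro.py; it reuses that script's imports,
# setup(), cleanup(), and SimpleModel. DTensorSpec lives in a private module,
# so this import path may change between releases.
from torch.distributed.tensor._dtensor_spec import DTensorSpec


def count_dtensorspec_tuples(objs):
    # Tuples that directly contain a DTensorSpec: the type the issue reports growing.
    return sum(
        1 for obj in objs
        if isinstance(obj, tuple) and any(isinstance(item, DTensorSpec) for item in obj)
    )


def main():
    setup()
    rank = dist.get_rank()

    model = SimpleModel(dim=1024, depth=8).cuda()
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for block in model.blocks:
        fully_shard(block, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    gc.disable()
    gc.collect()

    prev_count = len(gc.get_objects())
    prev_type_counts = None

    for step in range(200):
        x = torch.randn(4, 1024, device="cuda")
        loss = model(x)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if (step + 1) % 20 == 0:
            gc.collect()
            objs = gc.get_objects()
            count = len(objs)
            type_counts = Counter(type(o).__qualname__ for o in objs)
            # The leaked tuples are still referenced by the sharding propagation
            # cache, so user code cannot free them (`del obj` inside a loop only
            # unbinds the loop variable). Count them instead, to confirm the
            # leak before the fix and its absence after it.
            spec_tuples = count_dtensorspec_tuples(objs)

            if rank == 0:
                print(f"\nStep {step + 1}: tracked objects = {count} (delta = +{count - prev_count})")
                print(f"  Tuples holding a DTensorSpec: {spec_tuples}")
                if prev_type_counts is not None:
                    diff = {
                        k: type_counts[k] - prev_type_counts.get(k, 0)
                        for k in type_counts
                        if type_counts[k] - prev_type_counts.get(k, 0) > 0
                    }
                    top = sorted(diff.items(), key=lambda x: x[1], reverse=True)[:10]
                    print(f"  Top growing types: {top}")

            prev_count = count
            prev_type_counts = type_counts
            del objs

    cleanup()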
