pytorch - ✅(Solved) Fix [DTensor] RNG tracker does not advance state for CPU tensors on CUDA mesh, causing trunc_normal_ infinite loop [3 pull requests, 1 comments, 1 participants]

GitHub stats
pytorch/pytorch#180088 · Fetched 2026-04-11 06:08:19
Comments: 1 · Participants: 1 · Timeline events: 42 · Reactions: 0
Timeline (top): mentioned ×17, subscribed ×17, labeled ×4, cross-referenced ×2

Fix Action

Fixed

PR fix notes

PR #180087: [DTensor] Error on random ops with device mismatch

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #180087

Raises an error for the buggy case reported in https://github.com/pytorch/pytorch/issues/180088

A DTensor can have a CUDA mesh and a CPU local tensor, for example:

mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
with torch.device("meta"):
    t = torch.empty(256, 256)
dt = DTensor.from_local(t, mesh, [Shard(0)])
dt = torch.empty_like(dt, device="cpu")
print(f"mesh={dt.device_mesh.device_type}, local={dt._local_tensor.device}", flush=True)

Calling nn.init.trunc_normal_(dt) on it gets stuck in an infinite loop, because the CPU RNG does not advance correctly:

# torch/nn/init.py:117-125
while True:
    mask = (result < lo) | (result > hi)
    if not mask.any():
        break
    result = torch.where(
        mask,
        torch.empty_like(result).normal_(mean, std, generator=generator),
        result,
    )
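
To make the failure mode concrete, here is a minimal scalar sketch of rejection sampling using only Python's standard library. It is not the PyTorch implementation: `truncated_sample`, `advancing_draw`, `stale_draw`, and the `max_redraws` cap are hypothetical names introduced for illustration, and the "stale" generator simply re-seeds on every draw to mimic an RNG whose state never advances.

```python
import random


def truncated_sample(draw, lo, hi, max_redraws=1000):
    """Draw until a value lands in [lo, hi].

    Returns (value, redraws), or (None, max_redraws) if every draw was rejected.
    The cap exists only so this demo terminates; a `while True` loop has none.
    """
    for redraws in range(max_redraws + 1):
        value = draw()
        if lo <= value <= hi:
            return value, redraws
    return None, max_redraws


def advancing_draw(seed):
    """One generator shared across draws: its state moves forward each call."""
    rng = random.Random(seed)
    return lambda: rng.gauss(0.0, 1.0)


def stale_draw(seed):
    """A fresh generator per draw: the state never advances, so draws repeat."""
    return lambda: random.Random(seed).gauss(0.0, 1.0)


# Choose bounds that exclude the first draw, so at least one redraw is needed.
first = stale_draw(42)()
lo, hi = first + 0.5, first + 4.0

ok_value, ok_redraws = truncated_sample(advancing_draw(42), lo, hi)
stuck_value, stuck_redraws = truncated_sample(stale_draw(42), lo, hi)

print("advancing RNG accepted a value:", ok_value is not None)
print("stale RNG accepted a value:", stuck_value is not None)
```

With an advancing generator the redraws eventually land inside the bounds. With the stale one every redraw reproduces the rejected value, so only the artificial cap stops the loop.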

There are multiple aspects to fixing this, but raising an error for the illegal case is a good start.
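
The shape of such a guard can be sketched in a few lines. This is an assumption-laden illustration, not the code added to torch/distributed/tensor/_dispatch.py: the function name `check_random_op_devices`, its string arguments, and the error message are all hypothetical, and which device combinations the real check treats as illegal is not stated here.

```python
def check_random_op_devices(op_name, mesh_device_type, local_device_type):
    """Raise if a random op would run on a local tensor off the mesh's device type.

    Hypothetical helper: device types are plain strings such as "cuda" or "cpu".
    """
    if mesh_device_type != local_device_type:
        raise RuntimeError(
            f"{op_name}: the device mesh is on '{mesh_device_type}' but the local "
            f"tensor is on '{local_device_type}'; random ops need matching devices "
            "so that the RNG state that is advanced is the one that is sampled."
        )


# Matching devices pass silently.
check_random_op_devices("normal_", "cuda", "cuda")

# A CUDA mesh with a CPU local tensor fails fast instead of hanging later.
try:
    check_random_op_devices("normal_", "cuda", "cpu")
except RuntimeError as exc:
    print(f"rejected: {exc}")
```

Failing at dispatch time turns a silent hang inside trunc_normal_ into an immediate, attributable error.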

Repro

  # torchrun --nproc_per_node=1 repro.py
  import torch, torch.distributed as dist
  from torch.distributed.device_mesh import init_device_mesh
  from torch.distributed.tensor import DTensor, Shard

  dist.init_process_group("nccl")
  torch.cuda.set_device(0)

  mesh = init_device_mesh("cuda", (1,))
  with torch.device("meta"):
      t = torch.empty(4)
  dt = DTensor.from_local(t, mesh, [Shard(0)])
  dt = torch.empty_like(dt, device="cpu")

  # RNG doesn't advance: two normal_() calls produce identical values.
  # This makes trunc_normal_ (init.py:123) loop forever.
  torch.manual_seed(42)
  a = dt.normal_(0, 1).to_local().clone()
  b = torch.empty_like(dt).normal_(0, 1).to_local()
  assert not torch.equal(a, b), f"BUG: RNG not advancing, a == b == {a}"

Changed files

  • test/distributed/tensor/test_dtensor.py (modified, +15/-0)
  • torch/distributed/tensor/_dispatch.py (modified, +22/-0)

PR #2928: Skip parallelize_fn for seed checkpoint creation

Description (problem / solution / changelog)

Follow up on CI timeout https://github.com/pytorch/torchtitan/actions/runs/24217297179/job/70700751227 — TestGraphTrainerNumerics hangs during seed checkpoint creation

PR #2900 applies fully_shard when world size = 1. This triggered a PyTorch-side DTensor/nn.init bug: https://github.com/pytorch/pytorch/issues/180088

Beyond the PyTorch-side fix, it also makes sense to skip parallelize_fn for seed checkpoints: nothing it sets up (AC, compile, nD parallelism, mixed precision, etc.) is needed. Seed checkpoints only initialize weights and save.
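
As a rough illustration of that control flow (not the torchtitan code), the sketch below skips the parallelization step when a seed checkpoint is being created. `prepare_model`, `is_seed_checkpoint`, and the stand-in `fake_parallelize` are hypothetical names; only `parallelize_fn` comes from the PR description.

```python
def prepare_model(model, parallelize_fn, is_seed_checkpoint):
    """Return the model to initialize, applying parallelize_fn only when needed."""
    if is_seed_checkpoint:
        # Seed checkpoints only initialize weights and save, so the model is
        # left unparallelized.
        return model
    return parallelize_fn(model)


def fake_parallelize(model):
    """Stand-in for a real parallelize_fn; it just tags the model."""
    return {"wrapped": model}


print(prepare_model("model", fake_parallelize, is_seed_checkpoint=True))
print(prepare_model("model", fake_parallelize, is_seed_checkpoint=False))
```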

Also adds a seed checkpoint test with a 30s timeout
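
A timeout like this turns a hang into a quick, visible failure. The sketch below shows one way to enforce it with the standard library. It is not the torchtitan test harness: `run_with_timeout` is a hypothetical helper and the child commands are placeholders.

```python
import subprocess
import sys


def run_with_timeout(cmd, timeout_s=30.0):
    """Run cmd and return its exit code, or None if it exceeded timeout_s.

    subprocess.run kills the child process when the timeout expires.
    """
    try:
        completed = subprocess.run(cmd, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        return None
    return completed.returncode


# A command that finishes promptly reports its exit code.
print(run_with_timeout([sys.executable, "-c", "pass"]))

# A command that hangs is cut off and reported as None.
print(run_with_timeout([sys.executable, "-c", "import time; time.sleep(10)"], timeout_s=0.5))
```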

Changed files

  • tests/integration_tests/__init__.py (modified, +1/-0)
  • tests/integration_tests/features.py (modified, +12/-0)
  • tests/integration_tests/run_tests.py (modified, +14/-3)
  • torchtitan/trainer.py (modified, +13/-11)


🐛 Describe the bug

Bug

DTensor's RNG tracker does not advance CPU RNG state when a DTensor has a CUDA mesh but its local tensor is on CPU. This causes nn.init.trunc_normal_ to loop forever since its rejection sampling (introduced in #174997) depends on normal_() producing different values each call.

Root cause

The DTensor RNG tracker advances the CUDA RNG but the local tensor is on CPU, so values come from the stale CPU RNG:

CUDA RNG advances: True (1st), True (2nd)
CPU RNG advances: False (1st), False (2nd)

How to get a CPU DTensor on a CUDA mesh

This happens in FSDP seed checkpoint creation:

with torch.device("meta"):
    model = model_config.build()
fully_shard(model, mesh=cuda_mesh)   # DTensors with CUDA mesh
model.to_empty(device="cpu")         # local tensors now on CPU
model.init_weights()                 # trunc_normal_ → infinite loop

Repro

# torchrun --nproc_per_node=1 repro.py
import torch, torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

dist.init_process_group("nccl")
torch.cuda.set_device(0)

mesh = init_device_mesh("cuda", (1,))
with torch.device("meta"):
    t = torch.empty(4)
dt = DTensor.from_local(t, mesh, [Shard(0)])
dt = torch.empty_like(dt, device="cpu")

torch.manual_seed(42)
cpu_before = torch.random.get_rng_state()
cuda_before = torch.cuda.get_rng_state()
dt.normal_(0, 1)
cpu_after = torch.random.get_rng_state()
cuda_after = torch.cuda.get_rng_state()

print(f"CUDA RNG advances: {not torch.equal(cuda_before, cuda_after)}")
print(f"CPU  RNG advances: {not torch.equal(cpu_before, cpu_after)}")

a = dt.normal_(0, 1).to_local().clone()
b = torch.empty_like(dt).normal_(0, 1).to_local()
assert not torch.equal(a, b), f"BUG: RNG not advancing, a == b == {a}"

dist.destroy_process_group()

Versions

pytorch nightly

cc @ezyang @gchanan @kadeng @msaroufim @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @dcci @aditvenk @xmfan

extent analysis

TL;DR

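When a DTensor has a CUDA mesh but its local tensor is on CPU, the DTensor RNG tracker advances only the CUDA RNG, so repeated normal_() calls return identical values and trunc_normal_ never exits its loop. PR #180087 makes random ops raise an error on this device mismatch, and PR #2928 skips parallelize_fn for seed checkpoint creation in torchtitan.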

Guidance

  • The root cause is that the DTensor RNG tracker advances the CUDA RNG but not the CPU RNG when the local tensor is on CPU.
  • To verify the issue, run the provided repro code and check if the CPU RNG state is advanced after calling dt.normal_(0, 1).
  • To mitigate the issue, ensure that the CPU RNG state is advanced when the local tensor is on CPU, potentially by modifying the DTensor RNG tracker to handle this case.
  • To see how random ops on DTensors are handled, start with torch/distributed/tensor/_dispatch.py (the file PR #180087 modifies) and the DTensor RNG tracker in the torch.distributed.tensor module.
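
The verification step in the guidance above amounts to snapshotting an RNG state, running an op, and comparing. A generic sketch of that check follows, shown with Python's built-in `random` module so it runs anywhere. `state_advanced` is a hypothetical helper, not a PyTorch API.

```python
import random


def state_advanced(get_state, op, equal=lambda a, b: a == b):
    """Return True if calling op() changed the state reported by get_state()."""
    before = get_state()
    op()
    after = get_state()
    return not equal(before, after)


random.seed(42)

# Drawing a number moves the generator forward.
print(state_advanced(random.getstate, random.random))

# An op that never touches the generator leaves the state unchanged.
print(state_advanced(random.getstate, lambda: None))
```

For the case in this issue, the same helper could be called with `torch.random.get_rng_state` as `get_state`, `torch.equal` as `equal`, and `lambda: dt.normal_(0, 1)` as `op`, mirroring the before/after comparison in the repro.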

Example

The repro in the issue body is the most direct example of the failure. No fix-level code is shown here because the fix lives in PyTorch's DTensor dispatch code (torch/distributed/tensor/_dispatch.py in PR #180087).

Notes

The issue was reported against PyTorch nightly. The hang depends on the rejection-sampling loop in nn.init.trunc_normal_ (introduced in #174997), so builds without that change may not loop forever even if the RNG mismatch is present.

Recommendation

Avoid running random ops on a DTensor whose mesh is CUDA while its local tensor is on CPU. With PR #180087, this case raises an error instead of hanging. In torchtitan, PR #2928 skips parallelize_fn during seed checkpoint creation.
