pytorch: Inductor does not wait on `AsyncCollectiveTensor` nested inside `DTensor._local_tensor` [3 comments, 2 participants]

pytorch/pytorch#180614, fetched 2026-04-17 08:26:03

Error Message

wait_tensor calls during runtime: 0
ACT.completed after compiled call: False
AssertionError: BUG: Inductor ran the compiled kernel without calling wait_tensor...

Root Cause

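AOT Autograd's process_inputs scans flat_args only for top-level AsyncCollectiveTensor (ACT) inputs, so an ACT nested in DTensor._local_tensor never receives a trigger_wait() before the Inductor kernel runs. With backend='eager' the bug is masked because ACT's __torch_dispatch__ (torch/distributed/_functional_collectives.py:1080) unwraps and waits inline on every op. Inductor bypasses __torch_dispatch__ entirely.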

🐛 Describe the bug

Summary

AOT autograd's process_inputs scans flat_args for top-level AsyncCollectiveTensor (ACT) and records act_input_indices so the compiled graph's inner_fn can emit args[i].trigger_wait() before the kernel runs.

This scan is non-recursive. When a DTensor carries an ACT in its _local_tensor slot (e.g., after DTensor.redistribute(async_op=True), or from FSDP2's async all-gather), the ACT is invisible to process_inputs, no trigger_wait() is emitted, and the Inductor kernel executes on memory whose underlying collective may not have completed.

With backend='eager' the bug is masked because ACT's __torch_dispatch__ (torch/distributed/_functional_collectives.py:1080) unwraps and waits inline on every op. Inductor bypasses __torch_dispatch__ entirely.
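
The non-recursive scan can be illustrated with a small torch-free model. This is only a sketch: `FakeACT`, `FakeDTensor`, and `flat_scan` are hypothetical stand-ins invented for illustration, not PyTorch types or the actual process_inputs code.

```python
# Stand-in classes: none of these are PyTorch types.
class FakeACT:
    """Plays the role of AsyncCollectiveTensor."""

    def __init__(self, name):
        self.name = name


class FakeDTensor:
    """Plays the role of a wrapper subclass that holds an inner tensor."""

    def __init__(self, local_tensor):
        self._local_tensor = local_tensor


def flat_scan(flat_args):
    """Record indices of top-level ACTs only, as the issue describes."""
    return [i for i, a in enumerate(flat_args) if isinstance(a, FakeACT)]


top_level = FakeACT("top")
nested = FakeDTensor(FakeACT("nested"))

act_input_indices = flat_scan([top_level, nested])
print(act_input_indices)  # [0]
```

Index 1 never appears in the result, so in this model nothing would ever schedule a wait for the ACT stored inside the wrapper.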

Minimal deterministic repro (single process)

Shows the missing wait_tensor on a synthesized ACT — easy to run, proves the mechanism.

import os                                                       

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as fc
from torch.distributed._functional_collectives import AsyncCollectiveTensor                                                              
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate                                                                                  
                                                                
                                                                                                                                         
def main() -> None:                                             
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")                                                                                             
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29550")                                                                                        
    dist.init_process_group("gloo", rank=0, world_size=1)       
                                                                                                                                         
    wait_tensor_calls: list[str] = []                           
    orig_wait_tensor = fc.wait_tensor                                                                                                    
                                                                                                                                         
    def counting_wait_tensor(t):
        wait_tensor_calls.append("wait_tensor")                                                                                          
        return orig_wait_tensor(t)                              

    fc.wait_tensor = counting_wait_tensor
                                                                                                                                         
    mesh = init_device_mesh("cpu", (1,))
                                                                                                                                         
    @torch.compile(backend="inductor", fullgraph=True)          
    def fn(x):                                                                                                                           
        return x + 1
                                                                                                                                         
    # Compile + warmup. The compile path goes through Python dispatch, which
    # does trigger_wait; clear the counter afterwards to isolate runtime.                                                                
    dt = DTensor.from_local(torch.randn(4, 4), mesh, [Replicate()])                                                                      
    dt._local_tensor = AsyncCollectiveTensor(dt._local_tensor.clone())                                                                   
    fn(dt)                                                                                                                               
    wait_tensor_calls.clear()                                                                                                            
                                                                                                                                         
    # Pure runtime call with a fresh ACT.                       
    dt2 = DTensor.from_local(torch.randn(4, 4), mesh, [Replicate()])                                                                     
    dt2._local_tensor = AsyncCollectiveTensor(dt2._local_tensor.clone())                                                                 
                                                                                                                                         
    assert isinstance(dt2._local_tensor, AsyncCollectiveTensor)                                                                          
    assert dt2._local_tensor.completed is False                                                                                          
                                                                                                                                         
    _ = fn(dt2)
                                                                                                                                         
    print(f"wait_tensor calls during runtime: {len(wait_tensor_calls)}")
    print(f"ACT.completed after compiled call: {dt2._local_tensor.completed}")                                                           
                                                                                                                                         
    assert len(wait_tensor_calls) >= 1, (                                                                                                
        "BUG: Inductor ran the compiled kernel without calling wait_tensor "                                                             
        "on the nested AsyncCollectiveTensor."                  
    )                                                                                                                                    
                                                                
    dist.destroy_process_group()                                                                                                         
                                                                

if __name__ == "__main__":
    main()

Output:

wait_tensor calls during runtime: 0                                                                                                      
ACT.completed after compiled call: False                        
AssertionError: BUG: Inductor ran the compiled kernel without calling wait_tensor...

Multi-rank repro through public API (2 processes)

Reaches the same state organically via DTensor.redistribute(async_op=True) — no manual _local_tensor assignment. Uses gloo on CPU so no GPU/NCCL required; the missing wait_tensor is stream-agnostic, so CUDA/NCCL behaves the same.

import os

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as fc
import torch.multiprocessing as mp                                                                                                       
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.device_mesh import init_device_mesh                                                                               
from torch.distributed.tensor import DTensor, Replicate, Shard                                                                           

                                                                                                                                         
def worker(rank: int, world_size: int) -> None:                 
    os.environ["MASTER_ADDR"] = "localhost"                                                                                              
    os.environ["MASTER_PORT"] = "29600"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)                                                                    
                                                                                                                                         
    wait_calls: list[int] = []
    orig_wait_tensor = fc.wait_tensor                                                                                                    
                                                                                                                                         
    def counting_wait_tensor(t):
        wait_calls.append(1)                                                                                                             
        return orig_wait_tensor(t)                              

    fc.wait_tensor = counting_wait_tensor
                                                                                                                                         
    mesh = init_device_mesh("cpu", (world_size,))
                                                                                                                                         
    local_shard = torch.arange(8.0).view(8, 1) + rank * 100.0   
    dt_sharded = DTensor.from_local(local_shard, mesh, [Shard(0)])
                                                                                                                                         
    # Shard -> Replicate is an all-gather. async_op=True preserves the ACT
    # on _local_tensor (see torch/distributed/tensor/_redistribute.py:1733).                                                             
    dt_replicated = dt_sharded.redistribute(                                                                                             
        placements=[Replicate()], async_op=True
    )                                                                                                                                    
                                                                
    if rank == 0:                                                                                                                        
        print(                                                  
            f"dt_replicated._local_tensor type: "
            f"{type(dt_replicated._local_tensor).__name__}"
        )                                                                                                                                

    @torch.compile(backend="inductor", fullgraph=True)                                                                                   
    def fn(x):                                                  
        return x + 1.0

    # Warmup/compile — Python dispatch here triggers the wait.
    # Clear the counter to isolate the runtime path.                                                                                     
    _ = fn(dt_replicated)
    wait_calls.clear()                                                                                                                   
                                                                
    dt_sharded_b = DTensor.from_local(local_shard.clone(), mesh, [Shard(0)])
    dt_replicated_b = dt_sharded_b.redistribute(                                                                                         
        placements=[Replicate()], async_op=True
    )                                                                                                                                    
                                                                
    _ = fn(dt_replicated_b)                                                                                                              
                                                                
    if rank == 0:
        print(f"\nRuntime wait_tensor calls (should be >= 1): {len(wait_calls)}")
        if len(wait_calls) == 0:                                                                                                         
            print("BUG confirmed via public API.")
                                                                                                                                         
    dist.destroy_process_group()                                


def main():
    world_size = 2                                                                                                                       
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
                                                                                                                                         
                                                                
if __name__ == "__main__":
    main()

Output:

dt_replicated._local_tensor type: AsyncCollectiveTensor                                                                                  
                                                                
Runtime wait_tensor calls (should be >= 1): 0
BUG confirmed via public API.

Affected files

  • torch/_functorch/_aot_autograd/frontend_utils.py:73-77 — flat scan only.
  • torch/_functorch/_aot_autograd/subclass_codegen.py:269-271 — emits args[i] = args[i].trigger_wait() only for top-level indices.

Likely fix

  1. process_inputs: recursively walk __tensor_flatten__ on wrapper-subclass inputs and record ACT paths (e.g., (i, "_local_tensor")), not just flat indices.
  2. subclass_codegen: emit args[i].<path> = args[i].<path>.trigger_wait() for each recorded path, before subclass unwrapping runs.
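
The two steps above can be sketched without PyTorch. Everything here is an assumption made for illustration: `FakeACT`, `FakeWrapper`, `collect_act_paths`, and `wait_on_paths` are hypothetical names, and the real change would live in process_inputs and subclass_codegen rather than in helpers like these.

```python
class FakeACT:
    """Stand-in for AsyncCollectiveTensor (not the real class)."""

    def __init__(self, value):
        self.value = value
        self.completed = False

    def trigger_wait(self):
        # Mirrors the idea of waiting and returning the unwrapped tensor.
        self.completed = True
        return self.value


class FakeWrapper:
    """Stand-in for a wrapper subclass such as DTensor."""

    def __init__(self, local_tensor):
        self._local_tensor = local_tensor

    def __tensor_flatten__(self):
        # Same shape of result as the real protocol: (attr names, context).
        return ["_local_tensor"], None


def collect_act_paths(obj, prefix=()):
    """Step 1: record a path to every ACT at or below obj."""
    if isinstance(obj, FakeACT):
        return [prefix]
    if not hasattr(obj, "__tensor_flatten__"):
        return []
    attrs, _ctx = obj.__tensor_flatten__()
    paths = []
    for attr in attrs:
        paths.extend(collect_act_paths(getattr(obj, attr), prefix + (attr,)))
    return paths


def wait_on_paths(args, paths):
    """Step 2: replace each recorded ACT with its waited-on value."""
    for index, *attrs in paths:
        if not attrs:
            args[index] = args[index].trigger_wait()
            continue
        owner = args[index]
        for attr in attrs[:-1]:
            owner = getattr(owner, attr)
        setattr(owner, attrs[-1], getattr(owner, attrs[-1]).trigger_wait())


args = [FakeACT(1.0), FakeWrapper(FakeACT(2.0)), 3.0]
paths = [p for i, a in enumerate(args) for p in collect_act_paths(a, (i,))]
print(paths)  # [(0,), (1, '_local_tensor')]
wait_on_paths(args, paths)
print(args[0], args[1]._local_tensor)  # 1.0 2.0
```

Recording paths instead of flat indices lets the wait step reach an ACT at any nesting depth while leaving the outer wrapper object in place.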

Environment

Reproduces on main (bde493aade8), CPU + gloo; stream-agnostic so CUDA/NCCL is equivalent.

Correctness impact

Silent data corruption under real async collectives — the kernel reads a tensor whose collective is still in flight. The race often "wins" on localhost because collectives are fast, which is likely why this has gone uncaught. Most at risk: large tensors, slow interconnects (saturated PCIe / inter-node), contended GPUs.
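
A deterministic toy model of the hazard, assuming a hypothetical `PendingCollective` class whose buffer is only filled when `wait()` is called. A real race is timing-dependent and may not show up on every run; this sketch just makes the "read before wait" case visible.

```python
class PendingCollective:
    """Hypothetical stand-in for the result of an async all-gather."""

    def __init__(self, stale, incoming):
        self.buffer = list(stale)        # memory before the collective lands
        self._incoming = list(incoming)  # data the collective will deliver
        self.completed = False

    def wait(self):
        # Completing the collective is what makes the buffer valid.
        if not self.completed:
            self.buffer[:] = self._incoming
            self.completed = True
        return self.buffer


def kernel(buffer):
    """Stand-in for the compiled x + 1 kernel."""
    return [x + 1 for x in buffer]


racy = PendingCollective(stale=[0.0, 0.0], incoming=[10.0, 20.0])
racy_result = kernel(racy.buffer)   # reads before the wait
print(racy_result)  # [1.0, 1.0]

safe = PendingCollective(stale=[0.0, 0.0], incoming=[10.0, 20.0])
safe_result = kernel(safe.wait())   # waits first, then reads
print(safe_result)  # [11.0, 21.0]
```

Neither call raises an error, which is what makes the failure silent: the racy path simply computes on whatever the buffer held at the time.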

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @weifengpy @chauhang @penguinwu @tianyu-l @XilunWu @SherlockNoMad @ppwwyyxx @bdhirsh @bobrenjc93 @aorenste

extent analysis

TL;DR

The most likely fix involves modifying the process_inputs function to recursively walk __tensor_flatten__ on wrapper-subclass inputs and record AsyncCollectiveTensor paths.

Guidance

  • Modify process_inputs in torch/_functorch/_aot_autograd/frontend_utils.py to recursively scan for AsyncCollectiveTensors in nested tensors.
  • Update subclass_codegen in torch/_functorch/_aot_autograd/subclass_codegen.py to emit trigger_wait() calls for recorded AsyncCollectiveTensor paths.
  • Verify the fix by running the provided repro scripts and checking for the expected wait_tensor calls.
  • Test the fix with different tensor sizes, collective operations, and hardware configurations to ensure it covers various scenarios.
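
When checking a fix, the counting patch used in the repro scripts could be wrapped in a context manager so that the original function is restored afterwards. This is a sketch: `count_calls` is a hypothetical helper, and `fake_fc` stands in for torch.distributed._functional_collectives so the example runs without PyTorch.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def count_calls(module, name):
    """Temporarily wrap module.<name>; yield a list that grows on each call."""
    calls = []
    original = getattr(module, name)

    def counting(*args, **kwargs):
        calls.append(name)
        return original(*args, **kwargs)

    setattr(module, name, counting)
    try:
        yield calls
    finally:
        # Always restore, even if the body raises.
        setattr(module, name, original)


# Stand-in namespace; with PyTorch installed this would be the real module.
fake_fc = SimpleNamespace(wait_tensor=lambda t: t)
original_wait_tensor = fake_fc.wait_tensor

with count_calls(fake_fc, "wait_tensor") as calls:
    fake_fc.wait_tensor("tensor")

print(len(calls))  # 1
print(fake_fc.wait_tensor is original_wait_tensor)  # True
```

In a real test the compiled call would go inside the `with` block, and the assertion would be that `len(calls) >= 1` for an input carrying a nested ACT.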

Example

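# Sketch of a recursive scan for process_inputs. Simplified for illustration:
# the real function takes more arguments, and this is not the upstream patch.
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def find_act_paths(t, prefix):
    # An ACT is itself a wrapper subclass, so check for it before recursing.
    if isinstance(t, AsyncCollectiveTensor):
        return [prefix]
    paths = []
    if is_traceable_wrapper_subclass(t):
        # __tensor_flatten__ returns (inner attribute names, context).
        attrs, _ctx = t.__tensor_flatten__()
        for attr in attrs:
            paths.extend(find_act_paths(getattr(t, attr), prefix + (attr,)))
    return paths

def process_inputs(args):
    act_input_paths = []
    for i, arg in enumerate(args):
        act_input_paths.extend(find_act_paths(arg, (i,)))
    return act_input_paths  # e.g. (0,) or (1, "_local_tensor")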

Notes

Both changes come from the issue's "Likely fix" section and are proposals, not a merged patch. The scan in process_inputs has to record a path to each nested AsyncCollectiveTensor, and subclass_codegen has to emit trigger_wait() for each recorded path before subclass unwrapping runs. Once both are in place, the repro scripts should report at least one runtime wait_tensor call.

Recommendation

Treat this as a change to PyTorch internals rather than a user-side workaround: it requires modifying process_inputs and subclass_codegen as described. Until such a fix lands, passing a DTensor with a nested AsyncCollectiveTensor to an Inductor-compiled function risks the silent data corruption described under "Correctness impact".
