pytorch - 💡(How to fix) Fix torch.export with prefer_deferred_runtime_asserts_over_guards=True crashes on DTensor-parallelized models (IndexError in set_missing_meta_vals)



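Fix / Workaround

The only workaround in this report is the one for #175467 (DTensorSpec not registered as a pytree constant): `torch.utils._pytree.register_constant(torch.distributed.tensor._dtensor_spec.DTensorSpec)`. The repro below already applies it, and the export still raises the IndexError, so it does not resolve this bug.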

### 🐛 Describe the bug

`torch.export._trace._export(model, ..., prefer_deferred_runtime_asserts_over_guards=True)` raises `IndexError: list index out of range` inside `set_missing_meta_vals` when the model's parameters have been wrapped by `parallelize_module(...)` (i.e., they are DTensors).

The control case, in which the same model is exported the same way without `parallelize_module`, succeeds. This pins the failure on the interaction between `parallelize_module`-wrapped parameters and the parameter/buffer accounting in `_produce_aten_artifact`.

Minimal repro

A single `nn.Linear(64, 64, bias=False)` parallelized with `ColwiseParallel()` is enough to trigger it; no attention, reshape, or multi-layer model is needed.

```python
# Run with: torchrun --nproc_per_node=2 repro.py
import os
import torch
import torch.distributed as dist
import torch.distributed.tensor._dtensor_spec
import torch.export._trace  # noqa: F401  (submodule import needed for _export access)
import torch.utils._pytree
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Workaround for #175467 (DTensorSpec not registered as pytree constant).
torch.utils._pytree.register_constant(
    torch.distributed.tensor._dtensor_spec.DTensorSpec
)

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(64, 64, bias=False)

    def forward(self, x):
        return self.lin(x)


def export_with_deferred_asserts(model, x):
    seq_len = torch.export.Dim("seq_len", min=1, max=128)
    return torch.export._trace._export(
        model,
        args=(x,),
        dynamic_shapes=({1: seq_len},),
        strict=False,
        prefer_deferred_runtime_asserts_over_guards=True,
    )


x = torch.randn(1, 7, 64, device="cuda")
rank = dist.get_rank()

# Control: plain model (no DTensor): SUCCEEDS.
plain = Tiny().cuda()
export_with_deferred_asserts(plain, x)
print(f"[Rank {rank}] Control: SUCCESS")

# Bug: same model, parallelize_module applied: RAISES IndexError.
dt = Tiny().cuda()
parallelize_module(dt, mesh, {"lin": ColwiseParallel()})
export_with_deferred_asserts(dt, x)   # IndexError here
print(f"[Rank {rank}] DTensor case: SUCCESS  (should not reach)")
```

Actual output: the control line prints, then the DTensor export raises `IndexError`.
Stack trace

```
Traceback (most recent call last):
  File ".../torch/export/_trace.py", line 1310, in wrapper
    ep = fn(*args, **kwargs)
  File ".../torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
  File ".../torch/_utils_internal.py", line 96, in wrapper_function
    return function(*args, **kwargs)
  File ".../torch/export/_trace.py", line 2537, in _export
    export_artifact = export_func(...)
  File ".../torch/export/_trace.py", line 2229, in _non_strict_export
    aten_export_artifact = _to_aten_func(...)
  File ".../torch/export/_trace.py", line 1084, in _export_to_aten_ir
    return _produce_aten_artifact(...)
  File ".../torch/export/_trace.py", line 631, in _produce_aten_artifact
    set_missing_meta_vals(gm, flat_fake_args, total_non_user_inputs)
  File ".../torch/export/_trace.py", line 2062, in set_missing_meta_vals
    user_arg = flat_args[index - num_params_buffers]
               ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: list index out of range
```

### Versions

- `torch/export/_trace.py:631`: `_produce_aten_artifact` calls `set_missing_meta_vals(gm, flat_fake_args, total_non_user_inputs)`.
- `torch/export/_trace.py:2062`: the crashing line; its indexing math assumes `num_params_buffers` equals the number of lifted parameter/buffer placeholders.
- The bug is upstream of line 2062: `total_non_user_inputs` (passed as `num_params_buffers`) appears to undercount DTensor-wrapped parameters, so the fix likely belongs in whichever pass produces that count for `_produce_aten_artifact`.
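To make the suspected failure mode concrete, here is a pure-Python toy model of the indexing at line 2062. It is not PyTorch's code: `toy_set_missing_meta_vals`, `counts_consistent`, and the placeholder names are hypothetical, and so is the assumption that the DTensor weight contributes two param-like placeholders while the count passed in still says one. It only illustrates why an undercounted `num_params_buffers` pushes `index - num_params_buffers` past the end of `flat_args`.

```python
from typing import List, Sequence


def toy_set_missing_meta_vals(
    placeholders: Sequence[str],
    flat_args: Sequence[str],
    num_params_buffers: int,
) -> List[str]:
    """Hypothetical stand-in for the placeholder loop in set_missing_meta_vals.

    Placeholders [0, num_params_buffers) are treated as lifted params/buffers;
    every later placeholder is mapped to flat_args[index - num_params_buffers].
    """
    resolved = []
    for index, name in enumerate(placeholders):
        if index < num_params_buffers:
            resolved.append(f"{name}<-param")
        else:
            # Mirrors the crashing expression from the stack trace.
            user_arg = flat_args[index - num_params_buffers]
            resolved.append(f"{name}<-{user_arg}")
    return resolved


def counts_consistent(
    placeholders: Sequence[str], flat_args: Sequence[str], num_params_buffers: int
) -> bool:
    """Invariant the indexing relies on: params/buffers + user args == placeholders."""
    return len(placeholders) == num_params_buffers + len(flat_args)


# Control case: one lifted weight, one user input; the counts agree.
plain = ["p_lin_weight", "x"]
print(toy_set_missing_meta_vals(plain, ["x"], num_params_buffers=1))

# Hypothetical DTensor case: two param-like placeholders, but the count says 1,
# so the last placeholder ("x", index 2) is looked up at flat_args[1].
dtensor = ["p_lin_weight_a", "p_lin_weight_b", "x"]
print(counts_consistent(dtensor, ["x"], num_params_buffers=1))
try:
    toy_set_missing_meta_vals(dtensor, ["x"], num_params_buffers=1)
except IndexError as exc:
    print("IndexError:", exc)
```

In this toy, the undercount also silently pairs `p_lin_weight_b` with the user input before the crash, which is one reason to check the count invariant directly rather than rely on the IndexError.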

Versions

```
PyTorch:        2.13.0.dev20260515+cu130  (nightly)
Python:         3.11.15
World size:     2  (any TP > 1 should reproduce)
```
