pytorch - ✅(Solved) Fix [pipelining] PP fails when trying to pipeline any HF transformers model that uses RoPE [3 pull requests, 1 comments, 1 participants]

GitHub stats
pytorch/pytorch#184361 (fetched 2026-05-20 03:39:07)
Comments: 1 · Participants: 1 · Timeline events: 30 · Reactions: 0
Timeline (top): mentioned ×12, subscribed ×12, labeled ×5, commented ×1

Error Message

(pt) fyguan@beast:~/disk20$ torchrun --nproc-per-node 2 pipelining_model_backward.py W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] ***************************************** W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] ***************************************** [Rank 1] Starting model init on cuda:1 [Rank 0] Starting model init on cuda:0 Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:00<00:00, 41777.16it/s] Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:00<00:00, 38255.34it/s] LlamaForCausalLM( (model): LlamaModel( (embed_tokens): Embedding(49152, 576) (layers): ModuleList( (0-29): 30 x LlamaDecoderLayer( (self_attn): LlamaAttention( (q_proj): Linear(in_features=576, out_features=576, bias=False) (k_proj): Linear(in_features=576, out_features=192, bias=False) (v_proj): Linear(in_features=576, out_features=192, bias=False) (o_proj): Linear(in_features=576, out_features=576, bias=False) ) (mlp): LlamaMLP( (gate_proj): Linear(in_features=576, out_features=1536, bias=False) (up_proj): Linear(in_features=576, out_features=1536, bias=False) (down_proj): Linear(in_features=1536, out_features=576, bias=False) (act_fn): SiLUActivation() ) (input_layernorm): LlamaRMSNorm((576,), eps=1e-05) (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05) ) ) (norm): LlamaRMSNorm((576,), eps=1e-05) (rotary_emb): LlamaRotaryEmbedding() ) (lm_head): Linear(in_features=576, out_features=49152, bias=False) ) LlamaForCausalLM( 
(model): LlamaModel( (embed_tokens): Embedding(49152, 576) (layers): ModuleList( (0-29): 30 x LlamaDecoderLayer( (self_attn): LlamaAttention( (q_proj): Linear(in_features=576, out_features=576, bias=False) (k_proj): Linear(in_features=576, out_features=192, bias=False) (v_proj): Linear(in_features=576, out_features=192, bias=False) (o_proj): Linear(in_features=576, out_features=576, bias=False) ) (mlp): LlamaMLP( (gate_proj): Linear(in_features=576, out_features=1536, bias=False) (up_proj): Linear(in_features=576, out_features=1536, bias=False) (down_proj): Linear(in_features=1536, out_features=576, bias=False) (act_fn): SiLUActivation() ) (input_layernorm): LlamaRMSNorm((576,), eps=1e-05) (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05) ) ) (norm): LlamaRMSNorm((576,), eps=1e-05) (rotary_emb): LlamaRotaryEmbedding() ) (lm_head): Linear(in_features=576, out_features=49152, bias=False) ) /mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/copyreg.py:99: FutureWarning: isinstance(treespec, LeafSpec) is deprecated, use isinstance(treespec, TreeSpec) and treespec.is_leaf() instead. return cls.new(cls, *args) /mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/copyreg.py:99: FutureWarning: isinstance(treespec, LeafSpec) is deprecated, use isinstance(treespec, TreeSpec) and treespec.is_leaf() instead. return cls.new(cls, *args) [rank0]: Traceback (most recent call last): [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 328, in placeholder [rank0]: return next(self.args_iter) [rank0]: ^^^^^^^^^^^^^^^^^^^^ [rank0]: StopIteration

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last): [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 892, in forward_one_chunk [rank0]: output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 762, in forward_maybe_with_nosync [rank0]: out_val = self.submod(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/graph_module.py", line 949, in call_wrapped [rank0]: return self._wrapped_call(self, *args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/graph_module.py", line 461, in call [rank0]: raise e [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/graph_module.py", line 447, in call [rank0]: return super(self.cls, obj).call(*args, **kwargs) # type: ignore[misc] [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "<eval_with_key>.474", line 6, in forward [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl [rank0]: return self._call_impl(*args, 
**kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/export/unflatten.py", line 216, in forward [rank0]: return torch.fx.Interpreter(self, graph=self.graph).run( [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run [rank0]: self.env[node] = self.run_node(node) [rank0]: ^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 297, in run_node [rank0]: return getattr(self, n.op)(n.target, args, kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 431, in call_module [rank0]: return submod(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl [rank0]: return self._call_impl(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl [rank0]: return forward_call(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/export/unflatten.py", line 216, in forward [rank0]: return torch.fx.Interpreter(self, graph=self.graph).run( [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File 
"/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run [rank0]: self.env[node] = self.run_node(node) [rank0]: ^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 297, in run_node [rank0]: return getattr(self, n.op)(n.target, args, kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 333, in placeholder [rank0]: raise RuntimeError( [rank0]: RuntimeError: Expected positional argument for parameter unsqueeze_default, but one was not passed in!

[rank0]: While executing %unsqueeze_default : [num_users=1] = placeholder[target=unsqueeze_default] [rank0]: Original traceback: [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 474, in forward [rank0]: outputs: BaseModelOutputWithPast = self.model( [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 397, in forward [rank0]: position_ids = position_ids.unsqueeze(0) [rank0]: Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

[rank0]: While executing %rotary_emb : [num_users=1] = call_module[target=rotary_emb](args = (), kwargs = {}) [rank0]: Original traceback: [rank0]: None [rank0]: Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last): [rank0]: File "/mnt/disk20/user/fyguan/pipelining_model_backward.py", line 79, in <module> [rank0]: schedule.step( [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py", line 604, in step [rank0]: self._step_microbatches( [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py", line 873, in _step_microbatches [rank0]: output = self._stage.forward_one_chunk( [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 900, in forward_one_chunk [rank0]: raise RuntimeError(exc_msg) from e [rank0]: RuntimeError: [rank0]: [Stage 0] failed to run forward: [rank0]: args: () [rank0]: kwargs: {'input_ids': 'Tensor(torch.Size([1, 512]), grad=False, dtype=torch.int64)', 'attention_mask': 'Tensor(torch.Size([1, 512]), grad=False, dtype=torch.bool)'} [rank0]:
[rank0]:[W519 11:42:51.791357236 ProcessGroupNCCL.cpp:1575] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator()) W0519 11:42:54.151000 53989 site-packages/torch/distributed/elastic/multiprocessing/api.py:1012] Sending process 54075 closing signal SIGTERM E0519 11:42:54.469000 53989 site-packages/torch/distributed/elastic/multiprocessing/api.py:986] failed (exitcode: 1) local_rank: 0 (pid: 54074) of binary: /mnt/disk20/user/fyguan/miniconda3/envs/pt/bin/python3.12 Traceback (most recent call last): File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/bin/torchrun", line 6, in <module> sys.exit(main()) ^^^^^^ File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/init.py", line 362, in wrapper return f(*args, **kwargs) ^^^^^^^^^^^^^^^^^^ File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/run.py", line 990, in main run(args) File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/run.py", line 981, in run elastic_launch( File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 170, in call return launch_agent(self._config, self._entrypoint, list(args)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 317, in launch_agent raise ChildFailedError( torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

pipelining_model_backward.py FAILED

Failures: [1]: time : 2026-05-19_11:42:54 host : beast rank : 1 (local_rank: 1) exitcode : -15 (pid: 54075) error_file: <N/A> traceback : Signal 15 (SIGTERM) received by PID 54075

Root Cause (first observed failure): [0]: time : 2026-05-19_11:42:54 host : beast rank : 0 (local_rank: 0) exitcode : 1 (pid: 54074) error_file: <N/A> traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Fix Action

Fix / Workaround

CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Address sizes: 52 bits physical, 57 bits virtual Byte Order: Little Endian CPU(s): 256 On-line CPU(s) list: 0-255 Vendor ID: AuthenticAMD Model name: AMD EPYC 9555 64-Core Processor CPU family: 26 Model: 2 Thread(s) per core: 2 Core(s) per socket: 64 Socket(s): 2 Stepping: 1 Frequency boost: enabled CPU(s) scaling MHz: 84% CPU max MHz: 4410.8110 CPU min MHz: 1210.8110 BogoMIPS: 6391.45 Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpuid_fault cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid cqm rdt_a avx512f avx512dq adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca avx512_vp2intersect flush_l1d debug_swap amd_lbr_pmc_freeze Virtualization: AMD-V L1d cache: 6 MiB (128 instances) L1i cache: 4 MiB (128 instances) L2 cache: 128 MiB (128 instances) L3 cache: 512 MiB (16 instances) NUMA node(s): 2 NUMA node0 CPU(s): 0-63,128-191 NUMA node1 CPU(s): 64-127,192-255 
Vulnerability Gather data sampling: Not affected Vulnerability Ghostwrite: Not affected Vulnerability Indirect target selection: Not affected Vulnerability Itlb multihit: Not affected Vulnerability L1tf: Not affected Vulnerability Mds: Not affected Vulnerability Meltdown: Not affected Vulnerability Mmio stale data: Not affected Vulnerability Old microcode: Not affected Vulnerability Reg file data sampling: Not affected Vulnerability Retbleed: Not affected Vulnerability Spec rstack overflow: Mitigation; Reduced Speculation Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected Vulnerability Srbds: Not affected Vulnerability Tsa: Not affected Vulnerability Tsx async abort: Not affected Vulnerability Vmscape: Mitigation; IBPB before exit to userspace

PR fix notes

PR #182182: [pipelining] Fix None gradient handling in backward send/recv

Description (problem / solution / changelog)

Pipeline backward metadata now allocates gradient recv buffers conservatively for differentiable plain tensors, so later microbatches can send real gradients even when the inference microbatch produced None. DTensor gradient metadata is not invented from activation metadata; if inference records None and runtime later produces a DTensor gradient, runtime errors instead of silently dropping or guessing its placement.

Runtime backward sends now key off metadata allocation: they send real gradients when available, send zeros when metadata captured a grad slot but runtime returns None, and error when a real gradient appears with no recv buffer allocated.
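
The send rule described above can be sketched as a small decision function. This is a toy model under stated assumptions, not the code from the PR: `choose_grad_to_send`, `has_recv_slot`, and the `make_zeros` callback are hypothetical names introduced here, and plain Python lists stand in for gradient tensors.

```python
from typing import Callable, Optional, Sequence

Grad = Sequence[float]  # stand-in for a gradient tensor (assumption)


def choose_grad_to_send(
    has_recv_slot: bool,
    grad: Optional[Grad],
    make_zeros: Callable[[], Grad],
) -> Optional[Grad]:
    """Return what a stage would send upstream for one boundary input.

    has_recv_slot: whether metadata allocated a gradient recv buffer.
    grad: the gradient produced at runtime, or None.
    make_zeros: builds a zero gradient matching the recorded metadata.
    """
    if has_recv_slot:
        # The receiver expects a buffer, so something is always sent:
        # the real gradient when available, otherwise zeros.
        return grad if grad is not None else make_zeros()
    if grad is not None:
        # A real gradient with no recv buffer would otherwise be dropped.
        raise RuntimeError("gradient produced but no recv buffer allocated")
    # Assumption: no slot and no gradient means there is nothing to send.
    return None


zeros = lambda: [0.0, 0.0]
print(choose_grad_to_send(True, [1.5, -2.0], zeros))  # real gradient
print(choose_grad_to_send(True, None, zeros))         # zeros fill the slot
print(choose_grad_to_send(False, None, zeros))        # nothing to send
```

Keying the decision on the allocated slot rather than on the runtime value keeps sender and receiver in agreement even when different microbatches disagree about whether a gradient exists.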

This also aligns split input-backward with full backward by filtering non-differentiable outputs before autograd.grad and allowing unused inputs, which is needed by ZBV-style schedules when a microbatch has no boundary input gradient. The old verbose stage-only tests were replaced with concise schedule coverage for GPipe, Interleaved1F1B, and ZBV over both mismatch directions, plus DTensor unit coverage for the no-implicit-metadata runtime contract.

Fixes https://github.com/pytorch/pytorch/issues/152827 Fixes https://github.com/pytorch/torchtitan/issues/2748

Authored with Codex.

Changed files

  • test/distributed/pipelining/model_registry.py (modified, +32/-0)
  • test/distributed/pipelining/test_dtensor_pp_unit_tests.py (modified, +54/-0)
  • test/distributed/pipelining/test_schedule_multiproc.py (modified, +109/-1)
  • torch/distributed/pipelining/_backward.py (modified, +21/-13)
  • torch/distributed/pipelining/_utils.py (modified, +15/-7)
  • torch/distributed/pipelining/stage.py (modified, +74/-12)

PR #182644: Fix pipelining crash when split_module interleaves get_attr with placeholder

Description (problem / solution / changelog)

Fixes #162898

split_module can produce partition subgraphs where get_attr nodes are interleaved with placeholder nodes. _ModuleFrame.run_outer() assumes placeholders form a contiguous prefix, so it crashes when this invariant is violated.

Added _move_placeholders_to_front to reorder the graph before _outline_submodules runs, along with tests.

Also verified with the actual SmolLM2 model from the issue.
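
A minimal sketch of the reordering idea, assuming a simplified graph in which each node is an `(op, name)` pair in a Python list. This is not the `_move_placeholders_to_front` implementation from the PR, which operates on a torch.fx graph; the helper name `placeholders_first` and the node name `inv_freq` are hypothetical. The point is only that a stable partition restores the "placeholders form a contiguous prefix" invariant while keeping the relative order inside each group.

```python
from typing import List, Tuple

Node = Tuple[str, str]  # (op, name); a toy stand-in for an FX node


def placeholders_first(nodes: List[Node]) -> List[Node]:
    """Stable partition: all placeholder nodes first, then everything else."""
    placeholders = [n for n in nodes if n[0] == "placeholder"]
    others = [n for n in nodes if n[0] != "placeholder"]
    return placeholders + others


# A partition subgraph where a get_attr node sits between two placeholders.
interleaved = [
    ("placeholder", "input_ids"),
    ("get_attr", "inv_freq"),  # hypothetical attribute name
    ("placeholder", "unsqueeze_default"),
    ("call_module", "rotary_emb"),
    ("output", "output"),
]

reordered = placeholders_first(interleaved)
for op, name in reordered:
    print(f"{op:12s} {name}")
```

Moving a placeholder earlier is always safe with respect to definition-before-use, because a placeholder has no input nodes of its own.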

Changed files

  • test/distributed/pipelining/test_unflatten.py (modified, +51/-0)
  • torch/distributed/pipelining/_IR.py (modified, +22/-0)

PR #183582: [pipelining] Add guards for non-float tensors when building pipeline

Description (problem / solution / changelog)

Fixes #183024. Tested working with code below by manually applying the patch over 2.11. Analysis and initial patches found using AI. Manually applied and tested patches.

<details> <summary> Test code, run with torchrun on 2 gpus </summary>
import torch
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, Schedule1F1B
from transformers import AutoModelForCausalLM
import os

# Initialize torchrun's distributed environment
# init_process_group() returns None; group=None selects the default process group
dist.init_process_group(backend="nccl")
pp_group = None
rank = dist.get_rank()
world_size = dist.get_world_size()
local_rank = int(os.environ["LOCAL_RANK"])

# Assign this specific process to its designated GPU
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)

print(f"[Rank {rank}] Starting model init on {device}")

# GPT2
model_id = "openai-community/gpt2"

device_count = torch.cuda.device_count()
model = AutoModelForCausalLM.from_pretrained(model_id)
print(model)

# kv cache is enabled by default and causes other issues with dynamo tracing/pipelining, disabling
model.config.use_cache = False

# a single split point yields two stages, so launch with 2 GPUs
split_spec = {
    "transformer.h.5": SplitPoint.END,
}
model = model.to(device)

pipe_module = pipeline(
    module=model,
    mb_args=(),
    mb_kwargs={
        'input_ids': torch.zeros((1, 512), dtype=torch.long).to(device),
        'attention_mask': torch.ones((1, 512), dtype=torch.bool).to(device)
    },
    split_spec=split_spec,
)

stage = pipe_module.build_stage(rank, device, pp_group)


# placeholder loss function
def loss_fn(outputs, other):
    if isinstance(outputs, tuple):
        return outputs[0]
    elif isinstance(outputs, torch.Tensor):
        logits = outputs
    else:
        logits = outputs['logits']
    loss = logits.sum() - other.sum()
    return loss

schedule = Schedule1F1B(stage, n_microbatches=4, loss_fn=loss_fn)

if rank == 0:
    inputs = {
        'input_ids': torch.zeros((4, 512), dtype=torch.long).to(device),
        'attention_mask': torch.ones((4, 512), dtype=torch.bool).to(device),
    }

    schedule.step(
        **inputs
    )
elif rank == world_size - 1:
    losses = []
    output = schedule.step(target=torch.randn(4, 512).to(device), losses=losses)
    print(f"losses: {losses}")
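else:
    schedule.step()

# Shut down cleanly to avoid the "destroy_process_group() was not called" warning
dist.destroy_process_group()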
</details> <details> <summary> Prompt and Initial AI output </summary>

Prompt:

Analyze, triage, and provide me with a fix for this github issue

https://github.com/pytorch/pytorch/issues/183024

Z.ai agent response:

I'll analyze this GitHub issue for you. Let me start by fetching the issue details and understanding the problem. Let me extract the actual issue content more cleanly. Now let me look at the PyTorch source code for the specific failing function to understand the bug deeply. Now I have a complete picture. Let me also verify the related backward code path to ensure the fix is comprehensive. I now have a comprehensive understanding of the issue. Here's my full analysis, triage, and proposed fix:


🔍 Issue Analysis: PyTorch #183024

Title: [pipelining] Unable to use PP with HF transformers models due to non-float inputs

Bug Summary

When using PyTorch's Pipeline Parallelism (torch.distributed.pipelining) with HuggingFace transformer models (GPT2, LLaMA, etc.), the pipeline crashes with:

RuntimeError: only Tensors of floating point dtype can require gradients

🧬 Root Cause

The root cause is a dtype-agnostic requires_grad assignment in the pipeline stage infrastructure. When has_backward=True (training mode), the code unconditionally sets requires_grad=True on all activation receive/send buffers — including non-float tensors like input_ids (torch.long) and attention_mask (torch.bool). PyTorch only allows floating-point tensors to require gradients (a fundamental constraint since Issue #37680, open since 2020).

The crash occurs at three specific locations:

Location 1: torch/distributed/pipelining/stage.py, create_recv_tensor inside _create_act_recv_info (PRIMARY CRASH)

# Lines ~1361-1377 (broken)
tensor_meta = _TensorMeta(
    shape=example_value.shape,
    stride=example_value.stride(),
    dtype=example_value.dtype,
    requires_grad=self.has_backward,  # ← BUG: True for ALL dtypes
)
buffer = _make_tensor_from_meta(tensor_meta, self.device)
if self.has_backward:
    buffer.requires_grad_(True)  # ← CRASH: non-float tensors can't require grad

This is the exact crash point from the traceback. When the pipeline creates receive buffers for intermediate activations between stages, it encounters input_ids (dtype=torch.long) or attention_mask (dtype=torch.bool) and tries to set requires_grad_(True), which is illegal.

Location 2: torch/distributed/pipelining/stage.py, _create_act_send_info (SECONDARY)

# Lines ~1462-1468 (broken)
output_metas.append(
    _TensorMeta(
        shape=val.shape,
        stride=val.stride(),
        dtype=val.dtype,
        requires_grad=self.has_backward,  # ← Same bug for outputs
    )
)

Location 3: torch/distributed/pipelining/_utils.py, _TensorMeta.to_tensor (LATENT)

# Line ~96 (broken)
t = _make_tensor_from_meta(self, device)
t.requires_grad_(self.requires_grad)  # ← Crashes if requires_grad=True and dtype is non-float
return t

Additional Related Bug: torch/distributed/pipelining/_backward.py, _get_grad_fn_or_grad_acc

Even after fixing the above, the backward pass has a second bug waiting to surface. The function _get_grad_fn_or_grad_acc assumes its argument is always a torch.Tensor:

def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None:
    if t.requires_grad and t.grad_fn is None:  # ← AttributeError if t is not a Tensor
        ...

This is called via map(_get_grad_fn_or_grad_acc, input_values) in stage_backward_input, where input_values can contain non-tensor objects. While the immediate crash is masked by Location 1, once that's fixed, this will be the next crash site.


🔧 Proposed Fix

The fix is to guard requires_grad=True with a tensor.is_floating_point() check at all affected locations. This is semantically correct because:

  • Non-float tensors (integer, boolean) never produce meaningful gradients
  • Their gradient in the backward pass should be None
  • This is consistent with PyTorch's autograd semantics
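
The guard can be illustrated without PyTorch. The sketch below is a toy model under stated assumptions: dtypes are plain strings, and `FLOAT_DTYPES`, `BoundaryValue`, `needs_grad`, and the value name `hidden_states` are hypothetical, not pipelining APIs. It shows the intended outcome for the two inputs named in this issue next to a floating-point activation.

```python
from dataclasses import dataclass

# Assumption: a small illustrative subset of floating-point dtype names.
FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}


@dataclass(frozen=True)
class BoundaryValue:
    name: str
    dtype: str


def needs_grad(value: BoundaryValue, has_backward: bool) -> bool:
    """Toy mirror of `has_backward and tensor.is_floating_point()`."""
    return has_backward and value.dtype in FLOAT_DTYPES


values = [
    BoundaryValue("input_ids", "int64"),
    BoundaryValue("attention_mask", "bool"),
    BoundaryValue("hidden_states", "float32"),  # hypothetical activation
]

for v in values:
    flag = needs_grad(v, has_backward=True)
    print(f"{v.name:15s} {v.dtype:8s} requires_grad={flag}")
```

With `has_backward=False` every value is left without gradients, which matches the unpatched behavior for inference-only pipelines.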

Patch 1: torch/distributed/pipelining/stage.py, create_recv_tensor

def create_recv_tensor(placeholder, arg_node):
    example_value = placeholder.meta["val"]

    # ... (DTensor check unchanged) ...

    if arg_node.op == "placeholder":
        return _RecvInfo(
            input_name=f"root_input_{placeholder.name}",
            source=None,
            buffer=None,
            tensor_meta=_TensorMeta.from_tensor(example_value),
            is_root_arg=True,
        )

    # Received activation from a previous stage.
    while arg_node.target is operator.getitem:
        arg_node = arg_node.args[0]

    if arg_node.op != "call_module":
        raise PipeliningMetadataError(
            f"Expecting call_module, got {arg_node.op}"
        )
    src_stage = self.get_stage_index_of_submod(arg_node.name)

    # FIX: Only set requires_grad for floating-point tensors
    needs_grad = self.has_backward and example_value.is_floating_point()

    tensor_meta = _TensorMeta(
        shape=example_value.shape,
        stride=example_value.stride(),
        dtype=example_value.dtype,
        requires_grad=needs_grad,  # ← was: self.has_backward
    )

    buffer = _make_tensor_from_meta(tensor_meta, self.device)
    if needs_grad:  # ← was: if self.has_backward
        buffer.requires_grad_(True)

    return _RecvInfo(
        arg_node.name,
        src_stage,
        buffer,
        tensor_meta,
    )

Patch 2: torch/distributed/pipelining/stage.py, _create_act_send_info

# Inside _create_act_send_info, where output_metas are built:
output_metas.append(
    _TensorMeta(
        shape=val.shape,
        stride=val.stride(),
        dtype=val.dtype,
        requires_grad=self.has_backward and val.is_floating_point(),  # ← was: self.has_backward
    )
)

Patch 3: torch/distributed/pipelining/_utils.py, _TensorMeta.to_tensor

def to_tensor(self, device: torch.device | str) -> torch.Tensor:
    t = _make_tensor_from_meta(self, device)
    if self.requires_grad and t.is_floating_point():  # ← was: unconditionally
        t.requires_grad_(True)
    return t

Patch 4: torch/distributed/pipelining/_backward.py, _get_grad_fn_or_grad_acc

def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None:
    # FIX: Guard against non-tensor inputs
    if not isinstance(t, torch.Tensor):
        return None
    if t.requires_grad and t.grad_fn is None:
        viewed_t = t.view_as(t)
        grad_fn = viewed_t.grad_fn
        if grad_fn is not None:
            return grad_fn.next_functions[0][0]
        else:
            raise RuntimeError(
                "Attempted to get grad_fn, but got None."
                "Is this being created in a no-grad context?"
            )
    else:
        return t.grad_fn

📊 Triage Assessment

| Aspect | Rating |
| --- | --- |
| Severity | 🔴 High: Completely blocks pipeline parallelism for any model with non-float inputs (all HF transformer models) |
| Impact | 🔴 High: Affects GPT2, LLaMA, BERT, T5, essentially all popular transformer architectures |
| Workaround? | 🟡 Difficult: Would require rewriting models to avoid passing non-float tensors between stages, or monkey-patching the pipelining internals |
| Fix Complexity | 🟢 Low: Four surgical is_floating_point() / isinstance guards, no architectural changes |
| Regression Risk | 🟢 Low: The fix only restricts requires_grad for cases that already crash; float tensors are unaffected |

Why PRs #182182 and #182644 Don't Fix This

  • PR #182182 fixes None gradient handling for float tensors that are unused — it doesn't address the dtype mismatch
  • PR #182644 fixes FX graph node ordering — it doesn't touch the requires_grad logic at all
  • This is a distinct, orthogonal bug that needs its own fix

🚀 Suggested Next Steps

  1. File a PR against pytorch/pytorch with the four patches above
  2. Add regression tests that pipeline an HF model (even a tiny one like GPT2-small) with input_ids and attention_mask
  3. Coordinate with PR #182182 — both fixes are needed for full pipeline parallelism support, and they should be tested together since they both touch the backward path

Prompt:

Check your work

Z.ai agent response:

Good call — let me verify my analysis against the actual PyTorch source code. I'll cross-check each claim. Now I have verified all the source code against my original analysis. Let me also check one more thing — whether the to_tensor method is actually called in the crash path or not. Now let me verify the secondary bug claim about _get_grad_fn_or_grad_acc — whether it would actually be hit in the backward path after fixing the primary bug. Good — my verification is complete. Here's the corrected and refined analysis:


✅ Verification Results

Claims Confirmed ✅

| Claim | Status | Evidence |
| --- | --- | --- |
| Primary crash: create_recv_tensor sets requires_grad=self.has_backward unconditionally | Confirmed | Exact match with source at stage.py lines ~1361-1377 |
| Primary crash: buffer.requires_grad_(True) when self.has_backward | Confirmed | Line ~1377, crashes for torch.long/torch.bool tensors |
| Secondary bug: _create_act_send_info uses same requires_grad=self.has_backward pattern | Confirmed | Line ~1467, same unconditional pattern for outputs |
| Latent bug: _TensorMeta.to_tensor() calls t.requires_grad_(self.requires_grad) unconditionally | Confirmed | _utils.py line ~96, same crash pattern |
| PR #182182 does NOT fix this | Confirmed | It fixes None gradient handling for float tensors, not the dtype mismatch |
| PR #182644 does NOT fix this | Confirmed | It fixes FX graph node ordering, doesn't touch requires_grad logic |
| is_floating_point() guard is the correct fix | Confirmed | Semantically correct: integer and bool tensors cannot require gradients in PyTorch (only floating-point and complex dtypes can) |

Claim Clarified ⚠️

| Claim | Status | Correction |
| --- | --- | --- |
| _get_grad_fn_or_grad_acc will crash next after the primary fix | ⚠️ Partially correct | This IS a real bug, but it applies to a different scenario: when non-Tensor Python objects (int, bool, None) appear in input_values. In the HF GPT2 pipeline case specifically, input_values contains torch.Tensor objects (just non-float ones), which after the primary fix will have requires_grad=False and work fine. The isinstance guard is still a valid defense-in-depth fix but is not directly triggered by this issue. |

Revised Fix Assessment

  • Patch 1 (create_recv_tensor): Essential, directly fixes the crash.
  • Patch 2 (_create_act_send_info): Important, prevents the same crash for non-float outputs crossing stage boundaries. Not triggered in GPT2 specifically, since embedding outputs are float, but needed for models where non-float tensors cross stages.
  • Patch 3 (_TensorMeta.to_tensor): Defensive, prevents a latent crash if this method is called with requires_grad=True on non-float metadata.
  • Patch 4 (_get_grad_fn_or_grad_acc): Separate bug, defense-in-depth ⚠️. Should be framed as a separate fix for the torchtitan #3112 scenario (non-Tensor Python objects in input_values), not as part of the immediate fix for issue #183024.


Bottom Line

My original analysis was substantially correct. The three core patches (1, 2, 3) are accurate and necessary. Patch 4 is a valid bug fix but should be presented as an orthogonal improvement rather than a direct consequence of fixing issue #183024. The root cause, severity assessment, and fix approach all hold up under verification.

</details>

Changed files

  • torch/distributed/pipelining/_utils.py (modified, +2/-1)
  • torch/distributed/pipelining/stage.py (modified, +4/-3)

Code Example

import torch
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, Schedule1F1B
from transformers import AutoModelForCausalLM
import os

# Initialize torchrun's distributed environment
pp_group = dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
local_rank = int(os.environ["LOCAL_RANK"])

# Assign this specific process to its designated GPU
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)

print(f"[Rank {rank}] Starting model init on {device}")

#model_type = 'gpt2'
model_type = 'llama'

if model_type == 'gpt2':
    # GPT2
    model_id = "openai-community/gpt2"

elif model_type == 'llama':
    # SmolLM (llama style, uses modeling_llama.py)
    model_id = "HuggingFaceTB/SmolLM2-135M"

device_count = torch.cuda.device_count()
model = AutoModelForCausalLM.from_pretrained(model_id)
print(model)

# kv cache is enabled by default and causes other issues with dynamo tracing/pipelining, disabling
model.config.use_cache = False

# run on 2 or more gpus
if model_type == 'gpt2':
    # GPT2
    split_spec = {f"transformer.h.5": SplitPoint.END,}

elif model_type == 'llama':
    # SmolLM/llama-style models
    split_spec = {f"model.layers.14": SplitPoint.END,}

model = model.to(device)

pipe_module = pipeline(
    module=model,
    mb_args=(),
    mb_kwargs={
        'input_ids': torch.zeros((1, 512), dtype=torch.long).to(device),
        'attention_mask': torch.ones((1, 512), dtype=torch.bool).to(device),
    },
    split_spec=split_spec,
)

stage = pipe_module.build_stage(rank, device, pp_group)

# placeholder loss function
def loss_fn(outputs, other):
    if isinstance(outputs, tuple):
        return outputs[0]
    elif isinstance(outputs, torch.Tensor):
        logits = outputs
    else:
        logits = outputs['logits']
    loss = logits.sum() - other.sum()
    return loss

schedule = Schedule1F1B(stage, n_microbatches=4, loss_fn=loss_fn)

if rank == 0:
    inputs = {
        'input_ids': torch.zeros((4, 512), dtype=torch.long).to(device),
        'attention_mask': torch.ones((4, 512), dtype=torch.bool).to(device),
    }

    schedule.step(
        **inputs,
    )
elif rank == world_size - 1:
    losses = []
    output = schedule.step(target=torch.randn(4, 512).to(device), losses=losses)
    print(f"losses: {losses}")
else:
    schedule.step()

---

(pt) fyguan@beast:~/disk20$ torchrun --nproc-per-node 2 pipelining_model_backward.py 
W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] 
W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] *****************************************
W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] *****************************************
[Rank 1] Starting model init on cuda:1
[Rank 0] Starting model init on cuda:0
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:00<00:00, 41777.16it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:00<00:00, 38255.34it/s]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 328, in placeholder
[rank0]:     return next(self.args_iter)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]: StopIteration

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 892, in forward_one_chunk
[rank0]:     output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 762, in forward_maybe_with_nosync
[rank0]:     out_val = self.submod(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/graph_module.py", line 949, in call_wrapped
[rank0]:     return self._wrapped_call(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/graph_module.py", line 461, in __call__
[rank0]:     raise e
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/graph_module.py", line 447, in __call__
[rank0]:     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "<eval_with_key>.474", line 6, in forward
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/export/unflatten.py", line 216, in forward
[rank0]:     return torch.fx.Interpreter(self, graph=self.graph).run(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run
[rank0]:     self.env[node] = self.run_node(node)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 297, in run_node
[rank0]:     return getattr(self, n.op)(n.target, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 431, in call_module
[rank0]:     return submod(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/export/unflatten.py", line 216, in forward
[rank0]:     return torch.fx.Interpreter(self, graph=self.graph).run(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run
[rank0]:     self.env[node] = self.run_node(node)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 297, in run_node
[rank0]:     return getattr(self, n.op)(n.target, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 333, in placeholder
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: Expected positional argument for parameter unsqueeze_default, but one was not passed in!

[rank0]: While executing %unsqueeze_default : [num_users=1] = placeholder[target=unsqueeze_default]
[rank0]: Original traceback:
[rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 474, in forward
[rank0]:     outputs: BaseModelOutputWithPast = self.model(
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 397, in forward
[rank0]:     position_ids = position_ids.unsqueeze(0)
[rank0]: Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

[rank0]: While executing %rotary_emb : [num_users=1] = call_module[target=rotary_emb](args = (), kwargs = {})
[rank0]: Original traceback:
[rank0]: None
[rank0]: Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk20/user/fyguan/pipelining_model_backward.py", line 79, in <module>
[rank0]:     schedule.step(
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py", line 604, in step
[rank0]:     self._step_microbatches(
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py", line 873, in _step_microbatches
[rank0]:     output = self._stage.forward_one_chunk(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 900, in forward_one_chunk
[rank0]:     raise RuntimeError(exc_msg) from e
[rank0]: RuntimeError: 
[rank0]:             [Stage 0] failed to run forward:
[rank0]:             args: ()
[rank0]:             kwargs: {'input_ids': 'Tensor(torch.Size([1, 512]), grad=False, dtype=torch.int64)', 'attention_mask': 'Tensor(torch.Size([1, 512]), grad=False, dtype=torch.bool)'}
[rank0]:             
[rank0]:[W519 11:42:51.791357236 ProcessGroupNCCL.cpp:1575] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W0519 11:42:54.151000 53989 site-packages/torch/distributed/elastic/multiprocessing/api.py:1012] Sending process 54075 closing signal SIGTERM
E0519 11:42:54.469000 53989 site-packages/torch/distributed/elastic/multiprocessing/api.py:986] failed (exitcode: 1) local_rank: 0 (pid: 54074) of binary: /mnt/disk20/user/fyguan/miniconda3/envs/pt/bin/python3.12
Traceback (most recent call last):
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/bin/torchrun", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/run.py", line 990, in main
    run(args)
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/run.py", line 981, in run
    elastic_launch(
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 170, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 317, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
pipelining_model_backward.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2026-05-19_11:42:54
  host      : beast
  rank      : 1 (local_rank: 1)
  exitcode  : -15 (pid: 54075)
  error_file: <N/A>
  traceback : Signal 15 (SIGTERM) received by PID 54075
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-05-19_11:42:54
  host      : beast
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 54074)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

---

Collecting environment information...
PyTorch version: 2.11.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 16:50:00) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.18.20-fyguan-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: 
GPU models and configuration: 
GPU 0: NVIDIA H200 NVL
GPU 1: NVIDIA H200 NVL
GPU 2: NVIDIA H200 NVL
GPU 3: NVIDIA H200 NVL

Nvidia driver version: 590.48.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           52 bits physical, 57 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  256
On-line CPU(s) list:                     0-255
Vendor ID:                               AuthenticAMD
Model name:                              AMD EPYC 9555 64-Core Processor
CPU family:                              26
Model:                                   2
Thread(s) per core:                      2
Core(s) per socket:                      64
Socket(s):                               2
Stepping:                                1
Frequency boost:                         enabled
CPU(s) scaling MHz:                      84%
CPU max MHz:                             4410.8110
CPU min MHz:                             1210.8110
BogoMIPS:                                6391.45
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpuid_fault cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid cqm rdt_a avx512f avx512dq adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca avx512_vp2intersect flush_l1d debug_swap amd_lbr_pmc_freeze
Virtualization:                          AMD-V
L1d cache:                               6 MiB (128 instances)
L1i cache:                               4 MiB (128 instances)
L2 cache:                                128 MiB (128 instances)
L3 cache:                                512 MiB (16 instances)
NUMA node(s):                            2
NUMA node0 CPU(s):                       0-63,128-191
NUMA node1 CPU(s):                       64-127,192-255
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Mitigation; Reduced Speculation
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Mitigation; IBPB before exit to userspace

Versions of relevant libraries:
[pip3] numpy==2.4.3
[pip3] nvidia-cublas==13.1.0.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.19.0.56
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-nccl-cu13==2.28.9
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] torch==2.11.0+cu130
[pip3] torchaudio==2.11.0+cu130
[pip3] torchvision==0.26.0+cu130
[pip3] triton==3.6.0
[conda] cuda-nvrtc                 13.2.51          hecca717_0            conda-forge
[conda] libcufft                   12.2.0.37        hecca717_0            conda-forge
[conda] numpy                      2.3.5            pypi_0                pypi
[conda] nvidia-cublas              13.1.0.3         pypi_0                pypi
[conda] nvidia-cuda-cupti          13.0.85          pypi_0                pypi
[conda] nvidia-cuda-nvrtc          13.0.88          pypi_0                pypi
[conda] nvidia-cuda-runtime        13.0.96          pypi_0                pypi
[conda] nvidia-cudnn-cu13          9.19.0.56        pypi_0                pypi
[conda] nvidia-cufft               12.0.0.61        pypi_0                pypi
[conda] nvidia-curand              10.4.0.35        pypi_0                pypi
[conda] nvidia-cusolver            12.0.4.66        pypi_0                pypi
[conda] nvidia-cusparse            12.6.3.3         pypi_0                pypi
[conda] nvidia-cusparselt-cu13     0.8.0            pypi_0                pypi
[conda] nvidia-nccl-cu13           2.28.9           pypi_0                pypi
[conda] nvidia-nvjitlink           13.0.88          pypi_0                pypi
[conda] nvidia-nvtx                13.0.85          pypi_0                pypi
[conda] torch                      2.11.0+cu130     pypi_0                pypi
[conda] torchaudio                 2.11.0+cu130     pypi_0                pypi
[conda] torchvision                0.26.0+cu130     pypi_0                pypi
[conda] triton                     3.6.0            pypi_0                pypi

🐛 Describe the bug

Attempting to pipeline an HF transformers model that uses a llama-style RoPE implementation (basically all post-2022 transformer-based causal LMs) results in `RuntimeError: Expected positional argument for parameter unsqueeze_default, but one was not passed in!`. Reproducing requires applying https://github.com/pytorch/pytorch/pull/182182, https://github.com/pytorch/pytorch/pull/182644, and https://github.com/pytorch/pytorch/pull/183582 over 2.11. The script runs successfully when model_type = 'gpt2' and fails when model_type = 'llama'.
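
For readers unfamiliar with the message: it is raised by torch.fx.Interpreter when a graph has more placeholder nodes than positional arguments supplied. In the traceback, the rotary_emb submodule is called with args = () while its graph still expects the unsqueeze_default placeholder. The sketch below reproduces only that class of error on a toy function (the function f is hypothetical and has nothing to do with HF models or pipelining), so it illustrates the mechanism rather than the root cause.

```python
import torch
import torch.fx


def f(x, y):
    # Toy function: tracing yields a graph with two placeholders, x and y.
    return x + y


gm = torch.fx.symbolic_trace(f)

try:
    # Only one positional argument is supplied for two placeholders.
    torch.fx.Interpreter(gm).run(torch.ones(2))
    msg = None
except RuntimeError as exc:
    msg = str(exc)

print(msg)
```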

<details> <summary> Full traceback </summary>
(pt) fyguan@beast:~/disk20$ torchrun --nproc-per-node 2 pipelining_model_backward.py 
W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] 
W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] *****************************************
W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0519 11:42:34.313000 53989 site-packages/torch/distributed/run.py:851] *****************************************
[Rank 1] Starting model init on cuda:1
[Rank 0] Starting model init on cuda:0
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:00<00:00, 41777.16it/s]
Loading weights: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:00<00:00, 38255.34it/s]
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=576, out_features=49152, bias=False)
)
/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 328, in placeholder
[rank0]:     return next(self.args_iter)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]: StopIteration

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 892, in forward_one_chunk
[rank0]:     output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 762, in forward_maybe_with_nosync
[rank0]:     out_val = self.submod(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/graph_module.py", line 949, in call_wrapped
[rank0]:     return self._wrapped_call(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/graph_module.py", line 461, in __call__
[rank0]:     raise e
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/graph_module.py", line 447, in __call__
[rank0]:     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "<eval_with_key>.474", line 6, in forward
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/export/unflatten.py", line 216, in forward
[rank0]:     return torch.fx.Interpreter(self, graph=self.graph).run(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run
[rank0]:     self.env[node] = self.run_node(node)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 297, in run_node
[rank0]:     return getattr(self, n.op)(n.target, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 431, in call_module
[rank0]:     return submod(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/export/unflatten.py", line 216, in forward
[rank0]:     return torch.fx.Interpreter(self, graph=self.graph).run(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 200, in run
[rank0]:     self.env[node] = self.run_node(node)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 297, in run_node
[rank0]:     return getattr(self, n.op)(n.target, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/fx/interpreter.py", line 333, in placeholder
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: Expected positional argument for parameter unsqueeze_default, but one was not passed in!

[rank0]: While executing %unsqueeze_default : [num_users=1] = placeholder[target=unsqueeze_default]
[rank0]: Original traceback:
[rank0]: File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 474, in forward
[rank0]:     outputs: BaseModelOutputWithPast = self.model(
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 397, in forward
[rank0]:     position_ids = position_ids.unsqueeze(0)
[rank0]: Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

[rank0]: While executing %rotary_emb : [num_users=1] = call_module[target=rotary_emb](args = (), kwargs = {})
[rank0]: Original traceback:
[rank0]: None
[rank0]: Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/disk20/user/fyguan/pipelining_model_backward.py", line 79, in <module>
[rank0]:     schedule.step(
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py", line 604, in step
[rank0]:     self._step_microbatches(
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/schedules.py", line 873, in _step_microbatches
[rank0]:     output = self._stage.forward_one_chunk(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/pipelining/stage.py", line 900, in forward_one_chunk
[rank0]:     raise RuntimeError(exc_msg) from e
[rank0]: RuntimeError: 
[rank0]:             [Stage 0] failed to run forward:
[rank0]:             args: ()
[rank0]:             kwargs: {'input_ids': 'Tensor(torch.Size([1, 512]), grad=False, dtype=torch.int64)', 'attention_mask': 'Tensor(torch.Size([1, 512]), grad=False, dtype=torch.bool)'}
[rank0]:             
[rank0]:[W519 11:42:51.791357236 ProcessGroupNCCL.cpp:1575] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W0519 11:42:54.151000 53989 site-packages/torch/distributed/elastic/multiprocessing/api.py:1012] Sending process 54075 closing signal SIGTERM
E0519 11:42:54.469000 53989 site-packages/torch/distributed/elastic/multiprocessing/api.py:986] failed (exitcode: 1) local_rank: 0 (pid: 54074) of binary: /mnt/disk20/user/fyguan/miniconda3/envs/pt/bin/python3.12
Traceback (most recent call last):
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/bin/torchrun", line 6, in <module>
    sys.exit(main())
             ^^^^^^
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/run.py", line 990, in main
    run(args)
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/run.py", line 981, in run
    elastic_launch(
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 170, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/disk20/user/fyguan/miniconda3/envs/pt/lib/python3.12/site-packages/torch/distributed/launcher/api.py", line 317, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
pipelining_model_backward.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2026-05-19_11:42:54
  host      : beast
  rank      : 1 (local_rank: 1)
  exitcode  : -15 (pid: 54075)
  error_file: <N/A>
  traceback : Signal 15 (SIGTERM) received by PID 54075
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-05-19_11:42:54
  host      : beast
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 54074)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
</details> <details> <summary> Versions </summary>
Collecting environment information...
PyTorch version: 2.11.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.13 | packaged by conda-forge | (main, Mar  5 2026, 16:50:00) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.18.20-fyguan-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: 
GPU models and configuration: 
GPU 0: NVIDIA H200 NVL
GPU 1: NVIDIA H200 NVL
GPU 2: NVIDIA H200 NVL
GPU 3: NVIDIA H200 NVL

Nvidia driver version: 590.48.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           52 bits physical, 57 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  256
On-line CPU(s) list:                     0-255
Vendor ID:                               AuthenticAMD
Model name:                              AMD EPYC 9555 64-Core Processor
CPU family:                              26
Model:                                   2
Thread(s) per core:                      2
Core(s) per socket:                      64
Socket(s):                               2
Stepping:                                1
Frequency boost:                         enabled
CPU(s) scaling MHz:                      84%
CPU max MHz:                             4410.8110
CPU min MHz:                             1210.8110
BogoMIPS:                                6391.45
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpuid_fault cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid cqm rdt_a avx512f avx512dq adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca avx512_vp2intersect flush_l1d debug_swap amd_lbr_pmc_freeze
Virtualization:                          AMD-V
L1d cache:                               6 MiB (128 instances)
L1i cache:                               4 MiB (128 instances)
L2 cache:                                128 MiB (128 instances)
L3 cache:                                512 MiB (16 instances)
NUMA node(s):                            2
NUMA node0 CPU(s):                       0-63,128-191
NUMA node1 CPU(s):                       64-127,192-255
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Mitigation; Reduced Speculation
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Mitigation; IBPB before exit to userspace

Versions of relevant libraries:
[pip3] numpy==2.4.3
[pip3] nvidia-cublas==13.1.0.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.19.0.56
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-nccl-cu13==2.28.9
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] torch==2.11.0+cu130
[pip3] torchaudio==2.11.0+cu130
[pip3] torchvision==0.26.0+cu130
[pip3] triton==3.6.0
[conda] cuda-nvrtc                 13.2.51          hecca717_0            conda-forge
[conda] libcufft                   12.2.0.37        hecca717_0            conda-forge
[conda] numpy                      2.3.5            pypi_0                pypi
[conda] nvidia-cublas              13.1.0.3         pypi_0                pypi
[conda] nvidia-cuda-cupti          13.0.85          pypi_0                pypi
[conda] nvidia-cuda-nvrtc          13.0.88          pypi_0                pypi
[conda] nvidia-cuda-runtime        13.0.96          pypi_0                pypi
[conda] nvidia-cudnn-cu13          9.19.0.56        pypi_0                pypi
[conda] nvidia-cufft               12.0.0.61        pypi_0                pypi
[conda] nvidia-curand              10.4.0.35        pypi_0                pypi
[conda] nvidia-cusolver            12.0.4.66        pypi_0                pypi
[conda] nvidia-cusparse            12.6.3.3         pypi_0                pypi
[conda] nvidia-cusparselt-cu13     0.8.0            pypi_0                pypi
[conda] nvidia-nccl-cu13           2.28.9           pypi_0                pypi
[conda] nvidia-nvjitlink           13.0.88          pypi_0                pypi
[conda] nvidia-nvtx                13.0.85          pypi_0                pypi
[conda] torch                      2.11.0+cu130     pypi_0                pypi
[conda] torchaudio                 2.11.0+cu130     pypi_0                pypi
[conda] torchvision                0.26.0+cu130     pypi_0                pypi
[conda] triton                     3.6.0            pypi_0                pypi
</details>
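
For readers skimming the traceback: the inner failure is a graph `placeholder` (`unsqueeze_default`) that finds no positional argument left to consume, because the `rotary_emb` submodule is invoked with `args = ()` and `kwargs = {}`. The sketch below is a simplified pure-Python model of that pattern, not torch.fx's actual implementation; `TinyInterpreter` and `rotary_like` are hypothetical names introduced only for illustration.

```python
# Simplified stand-in for the placeholder handling seen in the traceback.
# This is NOT torch.fx's real code, only a model of the same pattern.
class TinyInterpreter:
    def __init__(self, placeholder_names):
        # Names of the graph inputs, in the order the graph expects them.
        self.placeholder_names = placeholder_names

    def run(self, *args):
        args_iter = iter(args)
        env = {}
        for name in self.placeholder_names:
            try:
                # Each placeholder consumes the next positional argument.
                env[name] = next(args_iter)
            except StopIteration as si:
                # Running out of arguments is re-raised as a RuntimeError,
                # matching the chained exceptions in the log.
                raise RuntimeError(
                    f"Expected positional argument for parameter {name}, "
                    "but one was not passed in!"
                ) from si
        return env


# Hypothetical submodule graph that expects a single input.
rotary_like = TinyInterpreter(["unsqueeze_default"])

# Supplying the input positionally succeeds.
ok = rotary_like.run("position_ids tensor")

# Calling with no arguments mirrors
# `call_module[target=rotary_emb](args = (), kwargs = {})`.
try:
    rotary_like.run()
    error_message = None
except RuntimeError as exc:
    error_message = str(exc)

print(error_message)
```

If this model matches the real behavior, the thing to inspect is why the split graph calls `rotary_emb` without forwarding the value derived from `position_ids`, rather than anything in the schedule or the loss function.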

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @weifengpy
