transformers - ✅ (Solved) Expose static_graph DDP flag via TrainingArguments [1 pull request, 1 participant]

GitHub stats
huggingface/transformers#45518 · Fetched 2026-04-20 11:58:50
Comments: 0 · Participants: 1 · Timeline events: 2 · Reactions: 0
Timeline (top): cross-referenced ×1, referenced ×1

Error Message

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. … Parameter indices which did not receive grad for rank 0: 0 1

Root Cause

With static_graph=True passed to DDP this model would train cleanly. Today that requires monkey-patching DistributedDataParallel.__init__ because TrainingArguments has no way to request it.
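For reference, passing the flag to DDP directly (outside Trainer) looks roughly like the sketch below. It is a minimal single-process CPU example on the gloo backend; the model mirrors the reproducer's MiniModel, while `run_static_graph_steps` and the file-based rendezvous are illustrative choices that do not come from the issue.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class MiniModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.orphan = nn.Linear(32, 32)  # trainable, never used in forward
        self.head = nn.Linear(32, 32)

    def forward(self, x):
        return self.head(x)


def run_static_graph_steps(steps=3):
    """Hypothetical helper: run a few DDP steps with static_graph=True."""
    losses = []
    with tempfile.TemporaryDirectory() as tmp:
        # Single-process "distributed" setup so the sketch runs without torchrun.
        dist.init_process_group(
            backend="gloo",
            init_method=f"file://{os.path.join(tmp, 'store')}",
            rank=0,
            world_size=1,
        )
        try:
            model = DDP(MiniModel(), static_graph=True)
            opt = torch.optim.SGD(model.parameters(), lr=0.01)
            for _ in range(steps):
                x, y = torch.randn(2, 32), torch.randn(2, 32)
                loss = ((model(x) - y) ** 2).mean()
                opt.zero_grad()
                loss.backward()
                opt.step()
                losses.append(loss.item())
        finally:
            dist.destroy_process_group()
    return losses
```

Calling `run_static_graph_steps()` returns one loss value per step. The sketch only exercises the static-graph path and does not attempt to reproduce the error above.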

Fix Action

Fix / Workaround

The common workaround is ddp_find_unused_parameters=True, but that forces DDP to traverse the autograd graph on every iteration to find unused params — a measurable per-step cost. static_graph=True is the performance-optimal alternative (PyTorch's own note: "potentially improve performance when there are unused parameters"): DDP records the participating-parameter set on iter 1 and assumes it's stable thereafter, so it neither errors nor pays the find-unused traversal cost on subsequent iters. It also safely handles models with re-used modules in forward (e.g. diffusion / flow-matching heads that iterate over the same layers across integration steps — PyTorch's "Reentrant backwards" supported-use case).

Today users who want static_graph=True must monkey-patch torch.nn.parallel.DistributedDataParallel.__init__ (e.g. via a sitecustomize.py shim) or subclass Trainer and override _wrap_model — neither portable.
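To make the portability concern concrete, the monkey-patching workaround amounts to wrapping a class's `__init__` so that one keyword is always set. The sketch below shows the technique on a stand-in class so that it runs without PyTorch. `force_init_kwarg` and `FakeDDP` are hypothetical names; in a real shim the target would be `torch.nn.parallel.DistributedDataParallel`.

```python
import functools


def force_init_kwarg(cls, name, value):
    """Patch cls.__init__ in place so `name` defaults to `value` (hypothetical helper)."""
    original_init = cls.__init__

    @functools.wraps(original_init)
    def patched_init(self, *args, **kwargs):
        # Only fill in the kwarg when the caller did not pass it by keyword.
        kwargs.setdefault(name, value)
        original_init(self, *args, **kwargs)

    cls.__init__ = patched_init
    return original_init  # returned so the patch can be undone


class FakeDDP:
    """Stand-in for DistributedDataParallel; it only records the flag."""

    def __init__(self, module, static_graph=False):
        self.module = module
        self.static_graph = static_graph


original = force_init_kwarg(FakeDDP, "static_graph", True)
wrapped = FakeDDP(module=object())
print(wrapped.static_graph)  # True
```

A patch like this is process-global and has to run in every worker before the wrapper is constructed, which is the kind of fragility the issue describes as not portable.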

PR fix notes

PR #45519: [Trainer] Add ddp_static_graph option

Description (problem / solution / changelog)

What does this PR do?

Exposes PyTorch DDP's static_graph flag via a new ddp_static_graph: Optional[bool] field on TrainingArguments, forwarded through Trainer._build_accelerator_args into Accelerate's DistributedDataParallelKwargs (which already supports it; only the Transformers-side plumbing was missing).

This completes the set of DDP flags already partially exposed on TrainingArguments (ddp_find_unused_parameters, ddp_bucket_cap_mb, ddp_broadcast_buffers). A user can currently configure nearly everything about DDP except static_graph; the only workarounds are monkey-patching DistributedDataParallel.__init__ via a sitecustomize.py shim or subclassing Trainer to override _wrap_model, and neither is portable.

Fixes #45518

Why

Per PyTorch's DDP docs, static_graph=True relaxes several DDP reducer constraints for users who can guarantee a stable graph across iterations: "Reentrant backwards", "Activation checkpointing when model has unused parameters", "There are model parameters that are outside of forward function", and "Potentially improve performance when there are unused parameters."

A common HF Trainer scenario where this matters: a model with trainable parameters that don't contribute to loss on every iteration (e.g. frozen submodules, or multi-head models where only one head is trained). Under DDP with ddp_find_unused_parameters=False (the Trainer default), such a model fails at iter 1 with:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss.

The common workaround is ddp_find_unused_parameters=True, but that forces DDP to traverse the autograd graph on every iteration to find unused params — a measurable per-step cost. static_graph=True is the performance-optimal alternative (PyTorch's own note: "potentially improve performance when there are unused parameters"): DDP records the participating-parameter set on iter 1 and assumes it's stable thereafter.
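The trade-off between the three modes can be pictured with a toy bookkeeping model. This is a deliberately simplified illustration and not how PyTorch's reducer is implemented: `ToyReducer`, its mode names, and the scan counter are all invented for the sketch, and real DDP reports the default-mode failure at the start of the next iteration rather than immediately.

```python
class ToyReducer:
    """Toy stand-in for DDP's gradient bookkeeping (not the real reducer)."""

    def __init__(self, all_params, mode):
        assert mode in {"default", "find_unused", "static_graph"}
        self.all_params = set(all_params)
        self.mode = mode
        self.scans = 0             # how many bookkeeping passes were needed
        self.recorded_used = None  # filled on iteration 1 in static_graph mode

    def step(self, used_params):
        used = set(used_params)
        if self.mode == "default":
            # Every tracked parameter is expected to receive a gradient.
            missing = self.all_params - used
            if missing:
                raise RuntimeError(f"params without grad: {sorted(missing)}")
        elif self.mode == "find_unused":
            # Look for unused parameters again on every iteration.
            self.scans += 1
        else:
            if self.recorded_used is None:
                # Record the participating set once, on the first iteration.
                self.scans += 1
                self.recorded_used = used
            elif used != self.recorded_used:
                raise RuntimeError("graph changed; static_graph contract broken")
        return self.all_params - used  # parameters skipped this step


params = ["orphan.weight", "orphan.bias", "head.weight", "head.bias"]
used = ["head.weight", "head.bias"]
for mode in ("find_unused", "static_graph"):
    reducer = ToyReducer(params, mode)
    for _ in range(3):
        reducer.step(used)
    print(mode, reducer.scans)  # find_unused 3, then static_graph 1
```

In this toy model the per-iteration cost of `find_unused` grows with the number of steps, while `static_graph` pays once and then relies on the graph staying the same.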

The earlier blocker that once made static_graph=True unsafe with HF models (ModelOutput subclasses not registered as pytree nodes) was fixed in #25358 (closed #25357, merged 2023-08) and is still live in src/transformers/utils/generic.py — where the repo itself documents static_graph=True safety with ModelOutput in a docstring.

Changes

  • src/transformers/training_args.py — new ddp_static_graph: bool | None = field(default=None, …), mirroring the ddp_broadcast_buffers pattern. Docstring entry explains the supported-use-cases surface and caveats (DDP-only; incompatible with re-entrant activation checkpointing; requires stable graph).
  • src/transformers/trainer.py (_build_accelerator_args) — one new conditional following the existing if self.args.ddp_* is not None: pattern for bucket_cap_mb and broadcast_buffers. When the flag is None (default), the kwarg is never added to ddp_kwargs, so DistributedDataParallelKwargs' own default (False) applies. Strictly additive; no existing behavior changes.
  • tests/trainer/test_trainer.py — new TrainerDDPKwargsTest class with three tests:
    • ddp_static_graph=True → handler has static_graph=True (positive).
    • ddp_static_graph=False → handler has static_graph=False (positive).
    • ddp_static_graph=None (default) → handler preserves Accelerate's default False (regression guard — the conditional in _build_accelerator_args must NOT leak the kwarg when unset).

All three pass locally.
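The conditional-forwarding pattern and the three test cases above can be sketched with stand-in dataclasses. `FakeArgs`, `FakeDDPKwargs`, and `build_ddp_kwargs` are hypothetical simplifications for illustration, not the actual Transformers or Accelerate classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeDDPKwargs:
    """Stand-in for Accelerate's DistributedDataParallelKwargs."""
    broadcast_buffers: bool = True
    static_graph: bool = False


@dataclass
class FakeArgs:
    """Stand-in for the relevant TrainingArguments fields."""
    ddp_broadcast_buffers: Optional[bool] = None
    ddp_static_graph: Optional[bool] = None


def build_ddp_kwargs(args: FakeArgs) -> FakeDDPKwargs:
    ddp_kwargs = {}
    # Only forward values the user set; None means "use the handler default".
    if args.ddp_broadcast_buffers is not None:
        ddp_kwargs["broadcast_buffers"] = args.ddp_broadcast_buffers
    if args.ddp_static_graph is not None:
        ddp_kwargs["static_graph"] = args.ddp_static_graph
    return FakeDDPKwargs(**ddp_kwargs)


print(build_ddp_kwargs(FakeArgs(ddp_static_graph=True)).static_graph)   # True
print(build_ddp_kwargs(FakeArgs(ddp_static_graph=False)).static_graph)  # False
print(build_ddp_kwargs(FakeArgs()).static_graph)                        # False
```

The third call is the regression guard: with the field left at None, the key is never added, so the handler's own default decides the value.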

Interaction caveats (documented in help text)

  • DDP-only. The field has no effect under FSDP or DeepSpeed (the conditional only fires on the DDP path, matching the other ddp_* fields).
  • Gradient checkpointing. Using static_graph=True alongside re-entrant activation checkpointing (use_reentrant=True) is unsafe per PyTorch; the docstring warns. Non-reentrant checkpointing (use_reentrant=False) is fine.
  • Requires stable graph. Modules with data-dependent control flow that changes which parameters are touched per iteration are incompatible with static_graph=True by PyTorch's own contract.
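The first two caveats could in principle be checked before training starts. The helper below is a hypothetical sketch only: as described, the PR documents the caveats in help text, and nothing in the description says it adds validation like this. The function name and its parameters are invented for illustration.

```python
from typing import List, Optional


def static_graph_warnings(
    ddp_static_graph: Optional[bool],
    uses_plain_ddp: bool,
    gradient_checkpointing: bool,
    use_reentrant: bool,
) -> List[str]:
    """Return human-readable warnings for risky static_graph combinations."""
    warnings: List[str] = []
    if not ddp_static_graph:
        return warnings  # None or False: nothing to check
    if not uses_plain_ddp:
        warnings.append("ddp_static_graph has no effect under FSDP or DeepSpeed")
    if gradient_checkpointing and use_reentrant:
        warnings.append(
            "static_graph=True with use_reentrant=True is flagged as unsafe in the PR caveats"
        )
    return warnings


print(static_graph_warnings(True, uses_plain_ddp=True,
                            gradient_checkpointing=True, use_reentrant=True))
```

The third caveat (a stable graph across iterations) depends on runtime control flow and cannot be verified from configuration alone, so it is left out of the sketch.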

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case. — #45518
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc — Trainer / Accelerate integration.

Changed files

  • src/transformers/trainer.py (modified, +2/-0)
  • src/transformers/training_args.py (modified, +19/-0)
  • tests/trainer/test_trainer.py (modified, +41/-0)

Code Example

# repro_static_graph.py
"""Minimal reproducer for DDP 'params not used' failure under HF Trainer with the
default ddp_find_unused_parameters=False. static_graph=True would fix this
without the extra per-iteration cost of find_unused=True."""
import os
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments
from transformers.utils import ModelOutput


@dataclass
class MiniOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None


class MiniModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Trainable but never used in forward. DDP tracks it and expects a
        # gradient; under ddp_find_unused_parameters=False this errors on iter 1.
        self.orphan = nn.Linear(32, 32)
        self.head = nn.Linear(32, 32)

    def forward(self, inputs=None, labels=None, **kwargs):
        h = self.head(inputs)
        loss = ((h - labels) ** 2).mean() if labels is not None else h.sum()
        return MiniOutput(loss=loss, logits=h)


class MiniDataset(Dataset):
    def __len__(self): return 16
    def __getitem__(self, idx):
        return {"inputs": torch.randn(32), "labels": torch.randn(32)}


def main():
    args = TrainingArguments(
        output_dir="/tmp/repro_out",
        per_device_train_batch_size=2,
        max_steps=3,
        save_strategy="no",
        report_to=[],
        ddp_backend="gloo",
        ddp_find_unused_parameters=False,  # HF default; what this issue is about
        use_cpu=True,
    )
    Trainer(model=MiniModel(), args=args, train_dataset=MiniDataset()).train()


if __name__ == "__main__":
    main()

---

docker run --rm -v "$PWD/repro_static_graph.py:/w/r.py" python:3.11-slim bash -c '
  pip install --quiet transformers accelerate torch
  torchrun --nproc-per-node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 /w/r.py
'

---

RuntimeError: Expected to have finished reduction in the prior iteration
before starting a new one. This error indicates that your module has
parameters that were not used in producing loss. …
Parameter indices which did not receive grad for rank 0: 0 1

Feature request

Add a ddp_static_graph: Optional[bool] field to TrainingArguments (mirroring the existing ddp_broadcast_buffers pattern) and forward it through Trainer._build_accelerator_args into Accelerate's DistributedDataParallelKwargs.static_graph. Defaults to None; when unset, Accelerate's own default (False) applies — strictly additive, no existing behavior changes.

Motivation

TrainingArguments currently exposes ddp_find_unused_parameters, ddp_bucket_cap_mb, and ddp_broadcast_buffers — but not static_graph, even though Accelerate's DistributedDataParallelKwargs has static_graph: bool = False already, and _build_accelerator_args has a straightforward conditional-set pattern that extends to one more kwarg.
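For context, outside Trainer the same flag can already be requested through Accelerate's kwargs handlers. The following is a minimal configuration sketch, assuming `accelerate` is installed; the flag only matters when the run actually wraps the model in DDP (for example, when launched with multiple processes).

```python
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# static_graph is passed on to DistributedDataParallel when Accelerate wraps the model.
ddp_kwargs = DistributedDataParallelKwargs(static_graph=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```

This is the handler that the requested TrainingArguments field would populate on the user's behalf.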

Per PyTorch's DistributedDataParallel docs, static_graph=True tells DDP the trained graph is static: "The set of used and unused parameters will not change during the whole training loop", and "how the graph is trained will not change during the whole training loop (meaning there is no control flow depending on iterations)." The same section lists the feature as supporting, among other things, "Reentrant backwards", "Activation checkpointing when model has unused parameters", "There are model parameters that are outside of forward function", and "Potentially improve performance when there are unused parameters."

That matches a common HF Trainer scenario: a model with trainable parameters that don't contribute to loss on every iteration (e.g. frozen submodules, or multi-head models where only one head is trained). Under DDP with ddp_find_unused_parameters=False (the HF Trainer default), such a model fails at iter 1 with:

RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss.

The common workaround is ddp_find_unused_parameters=True, but that forces DDP to traverse the autograd graph on every iteration to find unused params — a measurable per-step cost. static_graph=True is the performance-optimal alternative (PyTorch's own note: "potentially improve performance when there are unused parameters"): DDP records the participating-parameter set on iter 1 and assumes it's stable thereafter, so it neither errors nor pays the find-unused traversal cost on subsequent iters. It also safely handles models with re-used modules in forward (e.g. diffusion / flow-matching heads that iterate over the same layers across integration steps — PyTorch's "Reentrant backwards" supported-use case).

Today users who want static_graph=True must monkey-patch torch.nn.parallel.DistributedDataParallel.__init__ (e.g. via a sitecustomize.py shim) or subclass Trainer and override _wrap_model — neither portable.

The original blocker that once made static_graph=True unsafe with HF models — ModelOutput subclasses not registered as pytree nodes — was fixed in #25358 (closed #25357, merged 2023-08). The registration is still live, and the repo itself explicitly documents static_graph=True safety at utils/generic.py:383-384:

"This is necessary to synchronize gradients when using torch.nn.parallel.DistributedDataParallel with static_graph=True with modules that output ModelOutput subclasses."

So the feature is already advertised as supported in the repo; this issue is only about plumbing the flag through TrainingArguments and _build_accelerator_args.

Caveats to document in the help text: DDP-only (no effect under FSDP/DeepSpeed); incompatible with re-entrant activation checkpointing (use_reentrant=True); requires a stable computation graph across iterations per PyTorch's own static_graph contract.

Reproducer

A minimal CPU-only reproducer (no GPU required, runs in any Python 3.11 image):

# repro_static_graph.py
"""Minimal reproducer for DDP 'params not used' failure under HF Trainer with the
default ddp_find_unused_parameters=False. static_graph=True would fix this
without the extra per-iteration cost of find_unused=True."""
import os
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments
from transformers.utils import ModelOutput


@dataclass
class MiniOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None


class MiniModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Trainable but never used in forward. DDP tracks it and expects a
        # gradient; under ddp_find_unused_parameters=False this errors on iter 1.
        self.orphan = nn.Linear(32, 32)
        self.head = nn.Linear(32, 32)

    def forward(self, inputs=None, labels=None, **kwargs):
        h = self.head(inputs)
        loss = ((h - labels) ** 2).mean() if labels is not None else h.sum()
        return MiniOutput(loss=loss, logits=h)


class MiniDataset(Dataset):
    def __len__(self): return 16
    def __getitem__(self, idx):
        return {"inputs": torch.randn(32), "labels": torch.randn(32)}


def main():
    args = TrainingArguments(
        output_dir="/tmp/repro_out",
        per_device_train_batch_size=2,
        max_steps=3,
        save_strategy="no",
        report_to=[],
        ddp_backend="gloo",
        ddp_find_unused_parameters=False,  # HF default; what this issue is about
        use_cpu=True,
    )
    Trainer(model=MiniModel(), args=args, train_dataset=MiniDataset()).train()


if __name__ == "__main__":
    main()

Run (no GPU needed):

docker run --rm -v "$PWD/repro_static_graph.py:/w/r.py" python:3.11-slim bash -c '
  pip install --quiet transformers accelerate torch
  torchrun --nproc-per-node=2 --rdzv-backend=c10d --rdzv-endpoint=localhost:0 /w/r.py
'

Expected output (verified on transformers==4.57.1, torch==2.8.0):

RuntimeError: Expected to have finished reduction in the prior iteration
before starting a new one. This error indicates that your module has
parameters that were not used in producing loss. …
Parameter indices which did not receive grad for rank 0: 0 1

With static_graph=True passed to DDP this model would train cleanly. Today that requires monkey-patching DistributedDataParallel.__init__ because TrainingArguments has no way to request it.

Your contribution

Yes — I plan to open a PR immediately. The change is small: one field in training_args.py mirroring ddp_broadcast_buffers, one conditional in _build_accelerator_args following the existing if self.args.ddp_* is not None: ddp_kwargs[...] = ... pattern, a docs entry covering the caveats above, and positive + negative regression unit tests in tests/trainer/test_trainer.py. I've read CONTRIBUTING.md.

extent analysis

TL;DR

Add a ddp_static_graph field to TrainingArguments and forward it to Accelerate's DistributedDataParallelKwargs to enable static graph optimization in DDP.

Guidance

  • Add a ddp_static_graph: Optional[bool] field to TrainingArguments with a default value of None.
  • Update _build_accelerator_args in Trainer to forward ddp_static_graph to DistributedDataParallelKwargs if it is not None.
  • Document the caveats of using static_graph=True, including its incompatibility with re-entrant activation checkpointing and requirement for a stable computation graph.
  • Add regression unit tests to tests/trainer/test_trainer.py to verify the correctness of the new feature.

Example

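# Sketch of the two touch points; "# ..." marks elided existing code.
@dataclass
class TrainingArguments:
    # ...
    ddp_static_graph: Optional[bool] = field(default=None)

class Trainer:
    def _build_accelerator_args(self):
        # ... (ddp_kwargs is populated by the existing ddp_* conditionals)
        # Forward the flag only when the user set it, so Accelerate's default applies otherwise.
        if self.args.ddp_static_graph is not None:
            ddp_kwargs["static_graph"] = self.args.ddp_static_graph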

Notes

The proposed change is a straightforward extension of the existing ddp_broadcast_buffers pattern in TrainingArguments and _build_accelerator_args. However, it is essential to thoroughly test the new feature to ensure its correctness and compatibility with different use cases.

Recommendation

Apply the proposed change: add the ddp_static_graph field to TrainingArguments and forward it to DistributedDataParallelKwargs. This enables DDP's static-graph mode, which can improve performance when the model has unused parameters.
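From the user's side, the result would look roughly like the configuration below. This is a sketch that assumes a Transformers build containing PR #45519; on builds without it, `ddp_static_graph` is not a recognized argument. The other values are copied from the reproducer.

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="/tmp/repro_out",
    per_device_train_batch_size=2,
    max_steps=3,
    save_strategy="no",
    report_to=[],
    ddp_backend="gloo",
    ddp_find_unused_parameters=False,
    ddp_static_graph=True,  # field added by PR #45519 (assumed to be available)
    use_cpu=True,
)
```

Leaving `ddp_static_graph` unset keeps today's behavior, since the kwarg is only forwarded when it is not None.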
