pytorch - 💡(How to fix) FSDP2 with ignored_params does not work with foreach optimizers [1 participant]

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…
GitHub stats
pytorch/pytorch#183207 · Fetched 2026-05-11 03:12:43
View on GitHub
Comments
0
Participants
1
Timeline
62
Reactions
0
Author
Participants
Timeline (top)
mentioned ×28 · subscribed ×28 · labeled ×6

Code Example

#!/usr/bin/env python3

import os

import torch
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DeviceMesh, DTensor


class Model(nn.Module):
    """Minimal module for the reproduction: one 4-to-4 linear layer.

    The layer is stored as ``self.l`` so callers can reach its parameters
    (``l.weight`` and ``l.bias``) individually.
    """

    def __init__(self):
        super().__init__()
        self.l = nn.Linear(in_features=4, out_features=4)

    def forward(self, x):
        """Apply the linear layer and reduce the result to a scalar sum."""
        projected = self.l(x)
        return projected.sum()


def main():
    """Run one FSDP2 training step to reproduce the foreach-optimizer failure.

    Joins the default process group, picks this rank's device (the GPU indexed
    by ``LOCAL_RANK`` when CUDA is available, otherwise CPU), and applies
    ``fully_shard`` to ``Model`` while passing the linear bias in
    ``ignored_params``. It then performs a single forward/backward pass and
    one SGD-with-momentum step, and finally tears the process group down.

    The ``FOREACH`` environment variable (default ``'1'``) is parsed as an
    integer and selects the optimizer's ``foreach`` flag. The accompanying
    report states that the step fails with ``RuntimeError:
    aten._foreach_add_.List: got mixed torch.Tensor and DTensor``.
    """
    dist = torch.distributed
    dist.init_process_group()
    rank = dist.get_rank()
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    # Bind this process to its own GPU when CUDA is present; else stay on CPU.
    if torch.cuda.is_available():
        device_type = 'cuda'
        device = torch.device(f'cuda:{local_rank}')
        torch.cuda.set_device(device)
    else:
        device_type = 'cpu'
        device = torch.device('cpu')

    def report(message):
        # Only rank 0 prints, so the message appears once per run.
        if rank == 0:
            print(message)

    model = Model().to(device)
    # The bias is excluded from sharding. NOTE(review): presumably it stays a
    # plain tensor while the weight becomes a DTensor, which would match the
    # reported "mixed torch.Tensor and DTensor" error -- confirm.
    unsharded_params = {model.l.bias}

    world_mesh = DeviceMesh.from_group(dist.GroupMember.WORLD, device_type)
    fully_shard(
        model,
        mesh=world_mesh,
        ignored_params=unsharded_params,
        reshard_after_forward=False,
    )

    use_foreach = int(os.environ.get('FOREACH', '1')) != 0
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, foreach=use_foreach
    )
    batch = torch.randn(2, 4, device=device)
    model(batch).backward()

    report(f'Calling optimizer.step() with foreach={use_foreach} ...')
    optimizer.step()
    report('optimizer.step() succeeded')

    dist.destroy_process_group()


# Run the reproduction only when this file is executed as a script
# (for example under torchrun), not when it is imported.
if __name__ == '__main__':
    main()
RAW_BUFFER · Click to expand / collapse

🐛 Describe the bug

#!/usr/bin/env python3

import os

import torch
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DeviceMesh, DTensor


class Model(nn.Module):
    """Minimal module for the reproduction: one 4-to-4 linear layer.

    The layer is stored as ``self.l`` so callers can reach its parameters
    (``l.weight`` and ``l.bias``) individually.
    """

    def __init__(self):
        super().__init__()
        self.l = nn.Linear(in_features=4, out_features=4)

    def forward(self, x):
        """Apply the linear layer and reduce the result to a scalar sum."""
        projected = self.l(x)
        return projected.sum()


def main():
    """Run one FSDP2 training step to reproduce the foreach-optimizer failure.

    Joins the default process group, picks this rank's device (the GPU indexed
    by ``LOCAL_RANK`` when CUDA is available, otherwise CPU), and applies
    ``fully_shard`` to ``Model`` while passing the linear bias in
    ``ignored_params``. It then performs a single forward/backward pass and
    one SGD-with-momentum step, and finally tears the process group down.

    The ``FOREACH`` environment variable (default ``'1'``) is parsed as an
    integer and selects the optimizer's ``foreach`` flag. The accompanying
    report states that the step fails with ``RuntimeError:
    aten._foreach_add_.List: got mixed torch.Tensor and DTensor``.
    """
    dist = torch.distributed
    dist.init_process_group()
    rank = dist.get_rank()
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    # Bind this process to its own GPU when CUDA is present; else stay on CPU.
    if torch.cuda.is_available():
        device_type = 'cuda'
        device = torch.device(f'cuda:{local_rank}')
        torch.cuda.set_device(device)
    else:
        device_type = 'cpu'
        device = torch.device('cpu')

    def report(message):
        # Only rank 0 prints, so the message appears once per run.
        if rank == 0:
            print(message)

    model = Model().to(device)
    # The bias is excluded from sharding. NOTE(review): presumably it stays a
    # plain tensor while the weight becomes a DTensor, which would match the
    # reported "mixed torch.Tensor and DTensor" error -- confirm.
    unsharded_params = {model.l.bias}

    world_mesh = DeviceMesh.from_group(dist.GroupMember.WORLD, device_type)
    fully_shard(
        model,
        mesh=world_mesh,
        ignored_params=unsharded_params,
        reshard_after_forward=False,
    )

    use_foreach = int(os.environ.get('FOREACH', '1')) != 0
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.1, momentum=0.9, foreach=use_foreach
    )
    batch = torch.randn(2, 4, device=device)
    model(batch).backward()

    report(f'Calling optimizer.step() with foreach={use_foreach} ...')
    optimizer.step()
    report('optimizer.step() succeeded')

    dist.destroy_process_group()


# Run the reproduction only when this file is executed as a script
# (for example under torchrun), not when it is imported.
if __name__ == '__main__':
    main()

Run:

torchrun --standalone --nproc_per_node=2 repro_fsdp2_ignored_params_foreach.py

Got: RuntimeError: aten._foreach_add_.List: got mixed torch.Tensor and DTensor

Expected fix: group plain tensors and DTensors separately in the foreach optimizers.

Versions

2.11

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @weifengpy @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360 @tianyu-l @XilunWu @SherlockNoMad

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING

pytorch - 💡(How to fix) FSDP2 with ignored_params does not work with foreach optimizers [1 participant]