pytorch - 💡(How to fix) Fix `grouped fully_shard(list) fails when a grouped module is reused later in the same iteration`

Summary

I found a case where native composable FSDP2 `fully_shard(list)` fails when one of the grouped modules is reused later in the same iteration.

The pattern is:

```text
norm -> mtp(inner uses shared head) -> shared head
```

Here the shared head is the same module instance (`model.head`). It is called once inside `mtp` and once more outside it.
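To make the call order concrete, the pure-Python sketch below mimics how forward pre-hooks and forward hooks fire around nested module calls. It uses no PyTorch, and `module_call` and `forward_trace` are hypothetical helpers. A pre-hook fires on entry and a post-hook fires after `forward` returns. `input_proj` and the submodules inside `mtp` are left out, because only `norm`, `mtp` and `head` differ between the two wrapping modes.

```python
from contextlib import contextmanager


def forward_trace():
    """Return (hook, module) events in the order forward hooks would fire."""
    events = []

    @contextmanager
    def module_call(name):
        events.append(("pre", name))   # like a forward pre-hook
        yield
        events.append(("post", name))  # like a forward (post) hook

    def head():
        with module_call("head"):
            pass  # the linear projection would run here

    def mtp():
        with module_call("mtp"):
            head()  # first use: shared head nested inside mtp

    def norm():
        with module_call("norm"):
            pass

    # Mirrors Model.forward in the repro: norm -> mtp(uses head) -> head
    norm()
    mtp()
    head()  # second use: the same head, called again outside mtp
    return events


for hook, name in forward_trace():
    print(hook, name)
```

The trace shows that `head` gets two pre/post pairs per iteration. The first pair is nested inside `mtp`'s pair, and the second comes after `mtp` has finished.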

If I wrap the tail as one grouped FSDP unit:

```python
fully_shard([model.norm, model.mtp, model.head], mesh=mesh)
```

backward fails.

If I wrap the same modules separately:

```python
fully_shard(model.norm, mesh=mesh)
fully_shard(model.mtp, mesh=mesh)
fully_shard(model.head, mesh=mesh)
```

it succeeds.

Minimal repro

```python
import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp._fully_shard._fully_shard import fully_shard

class TinyRMSNorm(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.norm = nn.RMSNorm(hidden_size)

    def forward(self, x):
        return self.norm(x)

class SharedHead(nn.Module):
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        # weight shape is [vocab_size, hidden_size] = [128, 512]
        self.linear = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x):
        return self.linear(x)

class TinyMTP(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.final_layernorm = TinyRMSNorm(hidden_size)
        self.latest_aux_logits = None

    def forward(self, x, shared_head):
        x = self.proj(x)
        x = self.final_layernorm(x)
        # First call of the shared head, nested inside mtp's forward
        self.latest_aux_logits = shared_head(x)
        return x

class Model(nn.Module):
    def __init__(self, hidden_size=512, vocab_size=128):
        super().__init__()
        self.input_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.norm = TinyRMSNorm(hidden_size)
        self.mtp = TinyMTP(hidden_size)
        self.head = SharedHead(hidden_size, vocab_size)

    def forward(self, x):
        x = self.input_proj(x)
        x = self.norm(x)
        x = self.mtp(x, self.head)
        # Second call of the same head instance, outside mtp
        main_logits = self.head(x)
        aux_logits = self.mtp.latest_aux_logits
        return main_logits.square().mean() + 0.3 * aux_logits.square().mean()

def apply_fsdp(model, mesh, mode):
    fully_shard(model.input_proj, mesh=mesh)
    if mode == "group":
        # One grouped FSDP unit for the tail: fails in backward
        fully_shard([model.norm, model.mtp, model.head], mesh=mesh)
    else:
        # One FSDP unit per module ("split"): succeeds
        fully_shard(model.norm, mesh=mesh)
        fully_shard(model.mtp, mesh=mesh)
        fully_shard(model.head, mesh=mesh)
    fully_shard(model, mesh=mesh)

def run(mode):
    dist.init_process_group("gloo")  # CPU backend; torchrun supplies rank/world size via env vars
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    mesh = init_device_mesh("cpu", (world_size,))
    model = Model().cpu()
    apply_fsdp(model, mesh, mode)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3, foreach=True)
    x = torch.randn(1, 512)

    # One training step
    optim.zero_grad(set_to_none=True)
    loss = model(x)
    loss.backward()  # raises the setStorage RuntimeError in "group" mode
    optim.step()

    print(f"rank={rank}, mode={mode}, loss={loss.item()}")
    dist.barrier()
    dist.destroy_process_group()

if __name__ == "__main__":
    run(sys.argv[1])  # "group" or "split"
```

Repro commands

Fails:

```bash
torchrun --standalone --nproc_per_node=2 repro.py group
```

Succeeds:

```bash
torchrun --standalone --nproc_per_node=2 repro.py split
```

Error

Grouped mode fails in backward with:

```text
RuntimeError: setStorage: sizes [128, 512], strides [512, 1], storage offset 0, and itemsize 4 requiring a storage size of 262144 are out of bounds for storage of size 0
```
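The numbers in the error message can be checked by hand. The pure-Python sketch below uses a hypothetical helper, `required_storage_bytes`. It computes the bytes a strided view needs as (offset + sum of (size_i - 1) * stride_i + 1) * itemsize. For sizes [128, 512], strides [512, 1], offset 0 and itemsize 4 (float32), that comes to 65536 elements, or 262144 bytes.

The shape [128, 512] is the full, unsharded shape of `head.linear.weight`, because `nn.Linear(512, 128)` stores its weight as [out_features, in_features]. The other parameters in the repro are [512, 512] or [512]. A storage of size 0 is consistent with that weight's unsharded storage having already been freed when backward tried to use it. The message alone does not show which code path freed it.

```python
import math


def required_storage_bytes(sizes, strides, offset, itemsize):
    """Bytes needed for a strided view: (offset + last reachable index + 1) * itemsize."""
    if any(size == 0 for size in sizes):
        return 0  # an empty view addresses no elements
    last = offset + sum((size - 1) * stride for size, stride in zip(sizes, strides))
    return (last + 1) * itemsize


# Values copied from the error message
sizes, strides, offset, itemsize = [128, 512], [512, 1], 0, 4
print(required_storage_bytes(sizes, strides, offset, itemsize))  # 262144
print(math.prod(sizes) * itemsize)  # 262144: contiguous, so elements * itemsize
```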

Why this seems like a bug

The docs for grouped `fully_shard([a, b, ...])` explicitly mention that grouped modules may run only partially in the main forward and be called later in the same iteration (chunked-loss style usage).

I could not find a documented restriction saying that a grouped module cannot be reused/re-entered later in the same iteration.

Since:

  • grouped mode fails,
  • split mode succeeds,
  • and the usage pattern seems close to the documented grouped-forward semantics,

this looks like a bug (or at least an undocumented limitation) in grouped fully_shard(list).
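One way to reason about why grouping matters is offered here as a hypothesis, not a confirmed root cause. Assume that a grouped unit's forward hooks act as a countdown over its member modules:

- The first pre-forward of any member unshards the group and marks every member as pending.
- Each member's post-forward removes that member from the pending set.
- The group reshards only when nothing is pending.

A single-module unit simply unshards and reshards around every call. The pure-Python toy model below replays the forward order from the repro under that assumption. `TRACE`, `unit_events` and `is_balanced` are hypothetical names, not FSDP2 code.

```python
# Forward order from the repro (input_proj omitted: it is its own unit in both modes)
TRACE = [
    ("pre", "norm"), ("post", "norm"),
    ("pre", "mtp"),
    ("pre", "head"), ("post", "head"),  # head inside mtp
    ("post", "mtp"),
    ("pre", "head"), ("post", "head"),  # head again, outside mtp
]


def unit_events(trace, groups):
    """Toy model (an assumption, not FSDP2 code) of per-unit unshard/reshard.

    Single-module unit: unshard on every pre-forward, reshard on every post-forward.
    Grouped unit: on a pre-forward with nothing pending, mark all members pending
    and unshard; each post-forward clears its module; reshard once none are pending.
    """
    pending = [set() for _ in groups]
    events = []
    for hook, name in trace:
        for i, group in enumerate(groups):
            if name not in group:
                continue
            if len(group) == 1:
                events.append((i, "unshard" if hook == "pre" else "reshard"))
            elif hook == "pre":
                if not pending[i]:
                    pending[i] = set(group)
                    events.append((i, "unshard"))
            else:
                pending[i].discard(name)
                if not pending[i]:
                    events.append((i, "reshard"))
    return events


def is_balanced(events, n_units):
    """True if every unit reshards as many times as it unshards during forward."""
    return all(
        events.count((i, "unshard")) == events.count((i, "reshard"))
        for i in range(n_units)
    )


for label, groups in [
    ("group", [{"norm", "mtp", "head"}]),
    ("split", [{"norm"}, {"mtp"}, {"head"}]),
]:
    events = unit_events(TRACE, groups)
    print(label, is_balanced(events, len(groups)), events)
```

In this toy model, "group" reports False: the unit unshards twice but reshards once. The second `head` call restarts the countdown and leaves `norm` and `mtp` pending, so the group's post-forward never runs for that call. "split" reports True. The toy model does not show how that imbalance becomes the size-0 storage error; it only makes the asymmetry visible.

Under the same assumption, grouping only `[norm, mtp]` and giving `head` its own unit also comes out balanced (`unit_events(TRACE, [{"norm", "mtp"}, {"head"}])`). That is a prediction of the toy model and has not been tested with FSDP2.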


### Versions

Collecting environment information...
PyTorch version: 2.9.0+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: openEuler 22.03 (LTS-SP1) (aarch64)
GCC version: (GCC) 10.3.1
Clang version: 12.0.1 (openEuler 12.0.1-4.oe2203sp1 42e43d83d1d5f202dbb00b5df9d0a23d6c326edc)
CMake version: version 3.25.3
Libc version: glibc-2.34

Python version: 3.10.0 | packaged by conda-forge | (default, Nov 20 2021, 02:50:31) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.0-136.108.0.188.oe2203sp1.aarch64-aarch64-with-glibc2.34
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture:                       aarch64
CPU op-mode(s):                     64-bit
Byte Order:                         Little Endian
CPU(s):                             256
On-line CPU(s) list:                0-255
Vendor ID:                          HiSilicon
Model name:                         Kunpeng-920
Model:                              0
Thread(s) per core:                 1
Core(s) per socket:                 64
Socket(s):                          4
Stepping:                           0x1
BogoMIPS:                           200.00
Flags:                              fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma dcpop asimddp asimdfhm ssbs
L1d cache:                          16 MiB (256 instances)
L1i cache:                          16 MiB (256 instances)
L2 cache:                           128 MiB (256 instances)
L3 cache:                           256 MiB (8 instances)
NUMA node(s):                       8
NUMA node0 CPU(s):                  0-31
NUMA node1 CPU(s):                  32-63
NUMA node2 CPU(s):                  64-95
NUMA node3 CPU(s):                  96-127
NUMA node4 CPU(s):                  128-159
NUMA node5 CPU(s):                  160-191
NUMA node6 CPU(s):                  192-223
NUMA node7 CPU(s):                  224-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; __user pointer sanitization
Vulnerability Spectre v2:           Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] torch==2.9.0+cpu
[conda] numpy                                       2.2.6            pypi_0                pypi
[conda] torch                                       2.9.0+cpu        pypi_0                pypi


cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @weifengpy @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360 @ppwwyyxx
