pytorch - ✅(Solved) Fix [FSDP2] Inplace operation on a view tensor can drop the fsdp pre-backward hook on it. [1 pull request, 1 comment, 2 participants]

GitHub stats
pytorch/pytorch#181832 (fetched 2026-04-30 06:18:17)
Comments: 1 · Participants: 2 · Reactions: 1
Timeline: 58 events (mentioned ×25, subscribed ×25, labeled ×6, commented ×1)

Error Message

RuntimeError: setStorage: ... are out of bounds for storage of size 0

This is raised during backward when an in-place op on a view of an FSDP2 module's output has dropped FSDP2's pre-backward hook, so the parameter is never all-gathered again, and reshard_after_forward=True has already freed the parameter's storage with resize_(0). The full reproduction script is in the Code Example section below.

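Fix Action

Fix / Workaround

The linked PR #181850 only adds a warning; the actual fix is left to a follow-up. Until then, avoid applying an in-place op to a view returned by an FSDP2-wrapped module. Per the toggles in the reproduction, either of these avoids the bug:

  • Use an out-of-place op in the consumer, e.g. expert_output = expert_output + shared_output (USE_INPLACE = False).
  • Return a non-view tensor from the wrapped module's forward, e.g. out = out.clone() (USE_VIEW = False).

Setting reshard_after_forward=False keeps the unsharded parameter's storage allocated and so avoids the setStorage error, but the pre-backward hook is still dropped.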
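Both workarounds can be sanity-checked with plain autograd, without FSDP or GPUs. The sketch below is a stand-in under stated assumptions: a hook added with Tensor.register_hook plays the role of FSDP2's pre-backward hook, and a small matmul followed by the reproduction's unsqueeze(1).squeeze(1) view chain plays the role of InnerModule's output. The helper name hook_fires and its variant strings are hypothetical, chosen for illustration.

```python
import torch


def hook_fires(variant: str) -> bool:
    """Report whether a hook on the 'module output' fires during backward.

    Hypothetical helper. variant is one of:
      "inplace_on_view": in-place += on a view (the bug)
      "out_of_place":    out = out + shared     (USE_INPLACE = False)
      "clone":           clone before returning (USE_VIEW = False)
    """
    w = torch.randn(4, 4, requires_grad=True)
    # View output, like InnerModule with USE_VIEW = True.
    out = (torch.randn(2, 4) @ w).unsqueeze(1).squeeze(1)
    if variant == "clone":
        out = out.clone()  # break the view chain inside "forward"
    fired = []
    # Stand-in for the pre-backward hook FSDP2 registers on the output.
    out.register_hook(lambda grad: fired.append(True))
    shared = torch.randn(2, 4)
    if variant == "out_of_place":
        out = out + shared  # new tensor; the hooked node stays in the graph
    else:
        out += shared  # in-place; only harmful when out is a view
    out.sum().backward()
    return bool(fired)


for variant in ("inplace_on_view", "out_of_place", "clone"):
    print(f"{variant}: hook fired = {hook_fires(variant)}")
```

Only the in-place-on-view variant should lose the hook; the other two keep the node that carries it reachable from the loss.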

PR fix notes

PR #181850: [FSDP2] warn when forward output is a view tensor

Description (problem / solution / changelog)


fixes #181832

In-place ops on a view of an FSDP2 forward output silently drop the pre-backward hook (autograd view-rewrite orphans the Node), causing backward to skip all-gather and fail. See #181832.

Adds a UserWarning when any grad-requiring forward output is a view, naming the wrapped module's class. The default warnings filter deduplicates it, so it is emitted once per rank. The actual fix (wrapping via autograd.Function + clone) is left to a follow-up.
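The check described in the PR can be approximated outside FSDP2 with an ordinary forward hook. This is a minimal sketch of the idea, not the PR's code: warn_if_view_output and ViewReturning are hypothetical names, the output walk only handles tensors, lists, tuples and dicts, and `Tensor._base` (set only on tensors produced by view ops) is used as the view test, as the reproduction's probe does.

```python
import warnings

import torch
import torch.nn as nn


def warn_if_view_output(module: nn.Module, args, output) -> None:
    """Hypothetical forward hook: warn if a grad-requiring output is a view."""
    stack = [output]
    while stack:
        obj = stack.pop()
        if isinstance(obj, torch.Tensor):
            # `_base` is non-None only for tensors created by view ops.
            if obj.requires_grad and obj._base is not None:
                warnings.warn(
                    f"{type(module).__name__} returned a view tensor that requires "
                    "grad; an in-place op on it can drop hooks registered on it.",
                    UserWarning,
                    stacklevel=2,
                )
                return
        elif isinstance(obj, (list, tuple)):
            stack.extend(obj)
        elif isinstance(obj, dict):
            stack.extend(obj.values())


class ViewReturning(nn.Module):
    """Toy module whose forward returns a view, like InnerModule."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x):
        return self.lin(x).unsqueeze(1).squeeze(1)


m = ViewReturning(4)
m.register_forward_hook(warn_if_view_output)  # returning None keeps the output
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # disable dedup so the test sees it
    m(torch.randn(2, 4))
print([str(w.message) for w in caught])
```

Under the default filter the same warning location is reported only once per process, which matches the "once per rank" behavior described above.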

Co-authored-by: ljy-gh

Changed files

  • torch/distributed/fsdp/_fully_shard/_fsdp_state.py (modified, +12/-0)

Code Example

"""Minimal reproduction of the FSDP2 pre-backward hook bug:
When a downstream module performs an in-place op on a view of an FSDP2
module's output, the pre-backward hook registered by FSDP2 is silently
dropped. As a result:
  1. all-gather (unshard) is never triggered before backward.
  2. With reshard_after_forward=True, saved tensors point to a storage
     that has been resize_(0)'d, leading to:
        RuntimeError: setStorage: ... are out of bounds for storage of size 0

Run with (single node, 2 GPUs):
    torchrun --nproc_per_node=2 repro_fsdp2_inplace_hook_bug.py
Or single GPU:
    torchrun --nproc_per_node=1 repro_fsdp2_inplace_hook_bug.py

Toggle USE_INPLACE / USE_VIEW / RESHARD_AFTER_FORWARD at the top to
explore the four cases.
"""
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, FSDPModule

# ---------------------------------------------------------------------------
# Toggles to explore behavior:
USE_INPLACE = True            # True: out += y (triggers bug);  False: out = out + y
USE_VIEW    = True            # True: squeeze (view, triggers CopySlices); False: clone (no view)
RESHARD_AFTER_FORWARD = True  # True: storage gets resize_(0)'d; False: param stays unsharded
# ---------------------------------------------------------------------------


class _MatmulSaveView(torch.autograd.Function):
    """Custom matmul whose backward needs the EXACT saved view of weight.

    Mimics `torch._grouped_mm`: in backward, reconstruct a strided view from
    the saved weight's storage. If the storage was resize_(0)'d by FSDP
    reshard, this fails with:
        RuntimeError: setStorage: ... out of bounds for storage of size 0
    """
    @staticmethod
    def forward(ctx, x, w_view):
        ctx.save_for_backward(x, w_view)
        return x @ w_view

    @staticmethod
    def backward(ctx, grad_out):
        x, w_view = ctx.saved_tensors
        rank = int(os.environ.get("RANK", "0"))
        storage_nbytes = w_view.untyped_storage().nbytes()
        expected_nbytes = w_view.numel() * w_view.element_size()
        if rank == 0:
            print(f"[BWD] _MatmulSaveView.backward: w_view.shape={tuple(w_view.shape)}, "
                  f"storage.nbytes={storage_nbytes}, expected={expected_nbytes}, "
                  f"storage_freed={storage_nbytes == 0}", flush=True)
        # Mimic the strict TORCH_CHECK inside `_grouped_mm` backward: if FSDP
        # already resharded the parameter (storage resize_(0)), abort loudly.
        # PyTorch's bare `Tensor.set_` does NOT validate this and would silently
        # build a dangling view, hiding the bug behind garbage gradients.
        assert storage_nbytes >= expected_nbytes, (
            f"setStorage: sizes {tuple(w_view.shape)}, strides {w_view.stride()}, "
            f"requiring a storage size of {expected_nbytes} are out of bounds "
            f"for storage of size {storage_nbytes}"
        )
        # Force re-materialization via storage (same pattern as _grouped_mm)
        w_strided = torch.empty(0, dtype=w_view.dtype, device=w_view.device)
        w_strided.set_(w_view.untyped_storage(), 0, w_view.shape, w_view.stride())
        grad_x = grad_out @ w_strided.t()
        grad_w = x.t() @ grad_out
        return grad_x, grad_w


class InnerModule(nn.Module):
    """FSDP-wrapped leaf module. Its forward output is what the bug hits."""
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim) / (dim ** 0.5))
        self.dim = dim

    def forward(self, x):
        out = _MatmulSaveView.apply(x, self.weight)
        if USE_VIEW:
            out = out.unsqueeze(1).squeeze(1)   # view → triggers CopySlices when in-place hits
        else:
            out = out.clone()                    # break view chain
        return out


class OuterModule(nn.Module):
    """Mimics Qwen3_5MoeSparseMoeBlock.forward: in-place += on FSDP output."""
    def __init__(self, dim):
        super().__init__()
        self.inner = InnerModule(dim)
        self.shared = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        expert_output = self.inner(x)          # FSDP-wrapped output, possibly a view
        shared_output = self.shared(x)
        if USE_INPLACE:
            expert_output += shared_output     # ⚠️ in-place on view → CopySlices
        else:
            expert_output = expert_output + shared_output
        return expert_output.sum()


def main():
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
    else:
        device = torch.device("cpu")

    dim = 4096
    model = OuterModule(dim).to(device)

    # Apply FSDP2 only to the inner module (the one whose output gets += into).
    fully_shard(model.inner, reshard_after_forward=RESHARD_AFTER_FORWARD)

    # ---------- Probe: register a tensor hook on inner output to see if it fires ----------
    # We patch InnerModule.forward to register a hook on its return value.
    orig_inner_forward = InnerModule.forward

    def probed_inner_forward(self, x):
        out = orig_inner_forward(self, x)
        if rank == 0:
            print(f"[probe] inner output: requires_grad={out.requires_grad}, "
                  f"is_view={out.is_view() if hasattr(out, 'is_view') else out._base is not None}, "
                  f"grad_fn={out.grad_fn}", flush=True)

        def _user_hook(grad):
            if rank == 0:
                print("[probe] >>> tensor.register_hook on inner output FIRED <<<", flush=True)
            return grad
        if out.requires_grad:
            out.register_hook(_user_hook)
        return out

    InnerModule.forward = probed_inner_forward

    # ---------- Run forward + backward ----------
    x = torch.randn(8, dim, device=device, requires_grad=False)
    if rank == 0:
        print(f"\n=== Config ===")
        print(f"  USE_INPLACE           = {USE_INPLACE}")
        print(f"  USE_VIEW              = {USE_VIEW}")
        print(f"  RESHARD_AFTER_FORWARD = {RESHARD_AFTER_FORWARD}")
        print(f"==============\n", flush=True)

    try:
        loss = model(x)
        if rank == 0:
            print(f"[main] forward done, loss={loss.item():.4f}", flush=True)
        loss.backward()
        if rank == 0:
            print(f"[main] backward SUCCEEDED", flush=True)
    except RuntimeError as e:
        if rank == 0:
            print(f"[main] backward FAILED: {e}", flush=True)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

🐛 Describe the bug

Hi, I found a possible issue in FSDP2. The problem occurs when all of the following conditions are met:

  1. An FSDP pre-backward hook is registered on a view tensor.
  2. An in-place operation is performed on that view tensor.
  3. reshard_after_forward is enabled.

In this case, the hook registered on the view tensor appears to be lost, so the pre-backward function is not called during backward. After forward, the storage size of an all-gathered parameter is set to 0. During backward, the sharded parameter is not all-gathered again, which may eventually lead to a RuntimeError related to setStorage.
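The hook loss is an autograd-level effect and can be observed without FSDP. In the minimal sketch below, which rests on stand-in assumptions, a hook added with Tensor.register_hook plays the role of FSDP2's pre-backward hook and `a * 2` plays the role of the wrapped module's matmul output. It applies the reproduction's unsqueeze(1).squeeze(1) view chain and in-place `+=`, then reports whether the hook fired and which node now sits on the view's base. The function name is hypothetical.

```python
import torch


def inplace_on_view_drops_hook():
    """Hypothetical helper: return (hook_fired, base_grad_fn_type_name)."""
    a = torch.randn(4, 4, requires_grad=True)
    base = a * 2                        # non-leaf, stand-in for the module output
    out = base.unsqueeze(1).squeeze(1)  # view chain, as in InnerModule (USE_VIEW)
    fired = []
    # Stand-in for the pre-backward hook FSDP2 registers on the output.
    out.register_hook(lambda grad: fired.append(True))
    out += torch.ones(4, 4)             # in-place on the view, as in OuterModule
    out.sum().backward()
    # The in-place op rewrote the base's history: its grad_fn is now a
    # CopySlices node, and the view's old node (holding the hook) is no
    # longer reachable from the loss.
    return bool(fired), type(base.grad_fn).__name__


fired, base_fn = inplace_on_view_drops_hook()
print(f"hook fired: {fired}, base grad_fn: {base_fn}")
```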

It may make sense to explicitly remove or invalidate hooks on the view tensor in this case. At a minimum, I would expect PyTorch to either forbid this usage or emit a warning, since it can take a while to realize that the pre-backward hook was silently dropped.

The script in the Code Example section above reproduces this issue.

Versions

Collecting environment information...
PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Alibaba Cloud Linux 3.2104 U10 (OpenAnolis Edition) (x86_64)
GCC version: (GCC) 10.2.1 20200825 (Alibaba 10.2.1-3.8 2.32)
Clang version: Could not collect
CMake version: version 3.26.5
Libc version: glibc-2.32

Python version: 3.10.13+gc (heads/release/3.10.13-inc_gc:866f61ca61, Oct 21 2025, 10:00:00) [GCC 13.3.1 20240611 (Red Hat 13.3.1-2)] (64-bit runtime)
Python platform: Linux-5.10.134-18.al8.x86_64-x86_64-with-glibc2.32
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA L20C
GPU 1: NVIDIA L20C
GPU 2: NVIDIA L20C
GPU 3: NVIDIA L20C
GPU 4: NVIDIA L20C
GPU 5: NVIDIA L20C
GPU 6: NVIDIA L20C
GPU 7: NVIDIA L20C

Nvidia driver version: 580.82.07
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.9.12.0
/usr/lib64/libcudnn_adv.so.9.12.0
/usr/lib64/libcudnn_cnn.so.9.12.0
/usr/lib64/libcudnn_engines_precompiled.so.9.12.0
/usr/lib64/libcudnn_engines_runtime_compiled.so.9.12.0
/usr/lib64/libcudnn_graph.so.9.12.0
/usr/lib64/libcudnn_heuristic.so.9.12.0
/usr/lib64/libcudnn_ops.so.9.12.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-79,112-191
Off-line CPU(s) list: 80-111,192-223
Thread(s) per core: 1
Core(s) per socket: 56
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 207
Model name: INTEL(R) XEON(R) PLATINUM 8581C CPU @ 2.10GHz
Stepping: 2
CPU MHz: 2100.000
BogoMIPS: 4200.00
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 48K
L1i cache: 32K
L2 cache: 2048K
L3 cache: 266240K
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities

Versions of relevant libraries:
[pip3] intel-cmplr-lib-ur==2025.2.1
[pip3] intel-openmp==2025.2.1
[pip3] mkl==2025.0.1
[pip3] mkl-include==2025.0.1
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.15.0
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] onnx==1.19.1
[pip3] onnx-ir==0.1.11
[pip3] onnxruntime==1.23.2
[pip3] onnxscript==0.3.1
[pip3] tbb==2022.2.0
[pip3] tcmlib==1.4.0
[pip3] torch==2.10.0
[pip3] torchaudio==2.10.0
[pip3] torchdata==0.11.0
[pip3] torchmetrics==1.8.2
[pip3] torchvision==0.25.0
[pip3] transformer_engine_mdl==2.6.0+torch2.8.bf78e5d1
[pip3] triton==3.6.0
[pip3] umf==0.11.0
[conda] intel-cmplr-lib-ur 2025.2.1 pypi_0 pypi
[conda] intel-openmp 2025.2.1 pypi_0 pypi
[conda] magma-cuda126 2.6.1 1 pytorch
[conda] mkl 2025.0.1 pypi_0 pypi
[conda] mkl-include 2025.0.1 pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cudnn-frontend 1.15.0 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] tbb 2022.2.0 pypi_0 pypi
[conda] tbb-devel 2022.0.0 hdb19cb5_0
[conda] tcmlib 1.4.0 pypi_0 pypi
[conda] torch 2.10.0 pypi_0 pypi
[conda] torchaudio 2.10.0 pypi_0 pypi
[conda] torchdata 0.11.0 pypi_0 pypi
[conda] torchmetrics 1.8.2 pypi_0 pypi
[conda] torchvision 0.25.0 pypi_0 pypi
[conda] transformer-engine-mdl 2.6.0+torch2.8.bf78e5d1 pypi_0 pypi
[conda] triton 3.6.0 pypi_0 pypi
[conda] umf 0.11.0 pypi_0 pypi

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @weifengpy @zhaojuanmao @mrshenli @rohan-varma @chauhang @mori360 @ppwwyyxx

extent analysis

TL;DR

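The hook is dropped by autograd itself: an in-place op on a view rewrites the view's autograd history (CopySlices), orphaning the node on which FSDP2 registered its pre-backward hook. The practical mitigation is to avoid in-place ops on views of FSDP2 module outputs, either by using an out-of-place op or by returning a clone from the wrapped module. PR #181850 adds a warning for this case; the actual fix is left to a follow-up.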

Guidance

  • Identify the conditions under which the hook is lost: FSDP2 registers its pre-backward hook on a forward output that is a view, and a downstream module applies an in-place op to that view. With reshard_after_forward enabled, the dropped hook surfaces as the setStorage error in backward.
  • Avoid in-place ops on views of FSDP2 module outputs: use an out-of-place op in the consumer (expert_output = expert_output + shared_output) or return a clone from the wrapped module's forward.
  • Verify whether a hook on the module output fires during backward, for example with a tensor hook like the probe in the reproduction.
  • Setting reshard_after_forward=False avoids the freed-storage error but does not restore the dropped hook; treat it as a diagnostic, not a fix.
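To tell the freed-storage symptom apart from other backward errors, one can run the same size check that the reproduction's _MatmulSaveView.backward performs before touching a saved tensor. The sketch below is a CPU-only illustration under assumptions: storage_freed is a hypothetical helper, it inherits the reproduction's simplification that the tensor is contiguous at storage offset 0, and calling untyped_storage().resize_(0) by hand only imitates what, per the issue, FSDP2 does to the unsharded parameter when resharding after forward.

```python
import torch


def storage_freed(t: torch.Tensor) -> bool:
    """Hypothetical helper mirroring the reproduction's check.

    Assumes t is contiguous and starts at storage offset 0, so it needs
    numel * element_size bytes of backing storage.
    """
    needed = t.numel() * t.element_size()
    return t.untyped_storage().nbytes() < needed


param = torch.randn(4, 4)
saved = param.view(4, 4)  # shares storage, like a tensor saved for backward
print(storage_freed(saved))  # storage still intact
param.untyped_storage().resize_(0)  # imitate the reshard freeing the storage
print(storage_freed(saved))  # the saved view now points at 0 bytes
```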

Example

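def probed_inner_forward(self, x):
    out = orig_inner_forward(self, x)
    # Workaround: return a non-view tensor from the FSDP-wrapped module.
    # A later in-place op (e.g. `expert_output += shared_output`) then no
    # longer rewrites the view's autograd history, so the pre-backward hook
    # that FSDP2 registers on this output stays in the graph.
    if out._base is not None:
        out = out.clone()
    # ... (probe printing and hook registration as in the reproduction)
    return out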

Notes

The provided code is a minimal reproduction, and the right workaround may depend on the model. The behavior comes from the interaction between the pre-backward hook that fully_shard registers on module outputs and autograd's handling of in-place ops on views. Note that PR #181850 only adds a UserWarning when a grad-requiring forward output is a view, so upgrading to a version that includes it surfaces the problem as a warning rather than fixing it.

Recommendation

Avoid in-place ops on views returned by FSDP2-wrapped modules: use an out-of-place op in the consumer (expert_output = expert_output + shared_output) or return a clone from the wrapped module's forward. Do not clear the output tensor's hooks as a workaround; removing FSDP2's pre-backward hook guarantees that the all-gather before backward is skipped, which is exactly the failure described here.
