pytorch - 💡(How to fix) Fix Incompatibility between FSDP and nn.utils.parametrize [9 comments, 5 participants]

GitHub issue: pytorch/pytorch#177017 (fetched 2026-04-08 00:22:48)

Error Message

[rank1]:     result = x @ self.my_param
[rank1]:              ~~^~~~~~~~~~~~~~~
[rank1]: RuntimeError: Output 0 of ViewBackward0 is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

Root Cause

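The issue author needs to apply nn.utils.parametrize dynamically during the forward call. Under FSDP this fails a check in parametrize (assert not hasattr(module, tensor_name) in _inject_property), because the parameter is still present as an attribute on the module when it should have been removed. The author's reading is that FSDP sets the attribute in the module's __dict__ and register_parametrization does not clean it up.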

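Fix / Workaround

No complete fix is reported in the issue. The author's partial workaround was to add module.__dict__.pop(tensor_name, None) inside register_parametrization, which makes some progress but then fails with the RuntimeError shown under Error Message. For reference, this is the pattern from the reproducer that triggers the failure under FSDP:

def patch_forward(original_forward):
    # Monkey-patch SubModule.forward to call register_parametrization during forward

    def patched_forward(self, x):
        nn.utils.parametrize.register_parametrization(self, "my_param", IdentityParametrization())
        result = x @ self.my_param
        nn.utils.parametrize.remove_parametrizations(self, "my_param", leave_parametrized=False)
        return result

    SubModule.forward = patched_forward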


🐛 Describe the bug

I have a use case that requires using nn.utils.parametrize dynamically during the forward call. When combining this with FSDP, it leads to an error where a parametrize check fails because the parameter is present on the module when it shouldn't be, triggered here:

https://github.com/pytorch/pytorch/blob/e45dfba36d6b8b5efdf5a99f943a655be07991c2/torch/nn/utils/parametrize.py#L406-L407
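For context, here is a minimal single-process sketch (no FSDP, CPU only) of what register_parametrization normally does to the attribute. The class names Sub and Identity are illustrative and not taken from the issue. It shows the state the check expects: once the parametrization is registered, the name is no longer in _parameters or in the instance __dict__.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class Identity(nn.Module):
    # Illustrative parametrization that returns the tensor unchanged.
    def forward(self, x):
        return x


class Sub(nn.Module):
    def __init__(self):
        super().__init__()
        self.my_param = nn.Parameter(torch.randn(4, 4))


sub = Sub()
in_params_before = "my_param" in sub._parameters

parametrize.register_parametrization(sub, "my_param", Identity())
# While parametrized, the original tensor is stored under
# sub.parametrizations.my_param.original and the name is served by a property.
in_params_during = "my_param" in sub._parameters
in_dict_during = "my_param" in sub.__dict__
is_parametrized = parametrize.is_parametrized(sub, "my_param")

parametrize.remove_parametrizations(sub, "my_param", leave_parametrized=False)
in_params_after = "my_param" in sub._parameters
```

Without FSDP the register/remove round trip leaves the module as it started, which matches the report that the reproducer passes under plain python.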

Below is a reproducer. It passes when run with python test.py but fails with torchrun --nproc_per_node=2 test.py:

import os
import torch
import torch.nn as nn


class SubModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.my_param = nn.Parameter(torch.randn(1024, 1024))

    def forward(self, x):
        return x @ self.my_param


class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = SubModule()
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        return self.linear(self.sub(x))


class Model(nn.Module):
    def __init__(self, num_layers=4):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock() for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class IdentityParametrization(nn.Module):
    def forward(self, x):
        return x


def patch_forward(original_forward):
    # Monkey-patch SubModule.forward to call register_parametrization during forward

    def patched_forward(self, x):
        nn.utils.parametrize.register_parametrization(self, "my_param", IdentityParametrization())
        result = x @ self.my_param
        nn.utils.parametrize.remove_parametrizations(self, "my_param", leave_parametrized=False)
        return result

    SubModule.forward = patched_forward


def main():
    use_fsdp = "WORLD_SIZE" in os.environ
    if use_fsdp:
        import torch.distributed as dist

        dist.init_process_group("nccl")
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
        from functools import partial

        rank = dist.get_rank()
        torch.cuda.set_device(rank)
        device = f"cuda:{rank}"
    else:
        device = "cuda:0"

    model = Model().to(device)

    if use_fsdp:
        auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={TransformerBlock},
        )
        model = FSDP(model, auto_wrap_policy=auto_wrap_policy, use_orig_params=True)

    x = torch.randn(2, 1024, device=device)
    model(x)

    original_forward = SubModule.forward
    patch_forward(original_forward)

    # With FSDP this fails with: AssertionError
    #   File ".../torch/nn/utils/parametrize.py", line 387, in _inject_property
    #     assert not hasattr(module, tensor_name)
    model(x)
    print("passes")
    if use_fsdp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

The problem seems to be that FSDP sets the attribute on the module's __dict__ but register_parametrization doesn't clean that up. I could make some progress by adding module.__dict__.pop(tensor_name, None) here:

https://github.com/pytorch/pytorch/blob/e45dfba36d6b8b5efdf5a99f943a655be07991c2/torch/nn/utils/parametrize.py#L649

But even after that, I run into an error:

[rank1]:     result = x @ self.my_param
[rank1]:              ~~^~~~~~~~~~~~~~~
[rank1]: RuntimeError: Output 0 of ViewBackward0 is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

So I think it's a more fundamental problem between FSDP and nn.utils.parametrize.
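The stale __dict__ entry described above can be illustrated without torch. The sketch below uses a toy class, MiniModule, which is a hypothetical stand-in that only mimics how nn.Module keeps parameters in a _parameters registry and deletes them from there. It assumes, as the report suggests, that the tensor has also been written directly into the instance __dict__. It covers only the AssertionError, not the later ViewBackward0 error.

```python
class MiniModule:
    """Toy stand-in for nn.Module attribute handling (not the real class)."""

    def __init__(self):
        self.__dict__["_parameters"] = {}

    def __getattr__(self, name):
        # Only reached when normal lookup (instance __dict__, class) fails.
        params = self.__dict__["_parameters"]
        if name in params:
            return params[name]
        raise AttributeError(name)

    def __delattr__(self, name):
        # A registered parameter is removed from the registry only;
        # the instance __dict__ is not touched in this branch.
        if name in self._parameters:
            del self._parameters[name]
        else:
            super().__delattr__(name)


m = MiniModule()
m._parameters["my_param"] = "registered tensor"
m.__dict__["my_param"] = "tensor set directly in __dict__"  # assumed FSDP-like state

del m.my_param                         # clears the registry entry only
still_there = hasattr(m, "my_param")   # the __dict__ entry is still visible

m.__dict__.pop("my_param", None)       # the cleanup tried in the issue
after_pop = hasattr(m, "my_param")
```

In this toy model the delete leaves the attribute visible until the __dict__ entry is popped, which is the condition that assert not hasattr(module, tensor_name) rejects.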

Versions

Tested with torch 2.9.1 and 2.10.0.

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan

extent analysis

Fix Plan

No complete fix is reported in the issue. The steps below only restate the author's partial workaround (clearing the stale __dict__ entry), which per the issue still ends in the ViewBackward0 RuntimeError. The helper names are hypothetical, and the wrappers are untested sketches that call the public nn.utils.parametrize API instead of editing torch/nn/utils/parametrize.py.

1. Clear the stale attribute before register_parametrization

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


def register_parametrization_clean(module, tensor_name, parametrization):
    # Hypothetical wrapper: drop the entry that FSDP leaves in module.__dict__,
    # which otherwise trips `assert not hasattr(module, tensor_name)`.
    module.__dict__.pop(tensor_name, None)
    # Register the parametrization as usual
    return parametrize.register_parametrization(module, tensor_name, parametrization)

2. Apply the same cleanup before remove_parametrizations

The issue does not establish that this step is needed; it is included only for symmetry.

def remove_parametrizations_clean(module, tensor_name, leave_parametrized=True):
    # Hypothetical wrapper: same __dict__ cleanup as in step 1
    module.__dict__.pop(tensor_name, None)
    # Remove the parametrization as usual
    return parametrize.remove_parametrizations(module, tensor_name, leave_parametrized=leave_parametrized)

3. Update the reproducer to use the wrappers

Update the patch_forward function in the reproducer to call the wrappers from steps 1 and 2:

def patch_forward(original_forward):
    # Monkey-patch SubModule.forward to call register_parametrization during forward

    def patched_forward(self, x):
        register_parametrization_clean(self, "my_param", IdentityParametrization())
        result = x @ self.my_param
        remove_parametrizations_clean(self, "my_param", leave_parametrized=False)
        return result

    SubModule.forward = patched_forward
