pytorch: [pipelining] dW split path double-counts grad for shared weights


🐛 Describe the bug

torch.distributed.pipelining._backward.stage_backward_input / stage_backward_weight (the dW/dX-split path used by pipeline schedules that separate input-gradient and weight-gradient passes, e.g. ZeroBubble / 1F1B-V) produces incorrect weight gradients when a parameter is used at multiple sites in the forward graph and those use sites lie on a parent-child path, i.e. the output of one use feeds into another use of the same parameter.

Minimal reproducer

import torch
import torch.nn as nn
from torch.distributed.pipelining._backward import (
    stage_backward_input,
    stage_backward_weight,
)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(
            torch.tensor([[0.4, -0.2], [0.7, 0.3]], dtype=torch.float32)
        )

    def forward(self, x):
        return (x @ self.w @ self.w).sum()   # w used at TWO matmul sites


# baseline
net_ref = Net()
x_ref = torch.tensor([[1.1, -0.7]], dtype=torch.float32, requires_grad=True)
out_ref = net_ref(x_ref)
dx_ref, dw_ref = torch.autograd.grad(out_ref, [x_ref, net_ref.w])

# split path
net_split = Net()
x_split = torch.tensor([[1.1, -0.7]], dtype=torch.float32, requires_grad=True)
out_split = net_split(x_split)
weights = [net_split.w]
dinputs, param_groups = stage_backward_input(
    stage_outputs_or_loss=[out_split],
    output_grads=None,
    input_values=[x_split],
    weights=iter(weights),
)
stage_backward_weight(weights=iter(weights), param_groups=param_groups)

print("dx_ref  :", dx_ref)
print("dx_split:", dinputs[0])
print("dw_ref  :", dw_ref)
print("dw_split:", net_split.w.grad)

Observed output

dx_ref  : tensor([[-0.1200,  0.4400]])
dx_split: tensor([[-0.1200,  0.4400]])
dw_ref  : tensor([[ 0.1700,  1.0500],
                  [-0.5700, -1.1300]])
dw_split: tensor([[ 0.3900,  2.1500],
                  [-0.7100, -1.8300]])            <-- WRONG

dx matches. dw is wrong with max |diff| = 1.10.

Expected behavior

dw_split should equal dw_ref (what torch.autograd.grad computes in one shot). The whole point of stage_backward_input + stage_backward_weight is to produce the same gradients as a regular backward, just split across two calls.

Root cause analysis

At the forward level, w appears at two matmul sites, which in the autograd graph become two separate MmBackward nodes that share the same AccumulateGrad node for w:

  • MmBackward_1 — backward for x @ w (forward output: h0)
  • MmBackward_2 — backward for h0 @ w (forward output: h1)

get_param_groups correctly collects both matmul nodes into param_group["intermediates"] for w, and stage_backward_input registers a prehook at each to save the incoming gradient (dL/dh0 and dL/dh1 respectively).

The problem is in stage_backward_weight at the call:

https://github.com/pytorch/pytorch/blob/0cfb7e25c47f2d63a71c475b594b5252be5b0ddb/torch/distributed/pipelining/_backward.py#L294-L299

It invokes:

torch.autograd.grad(
    outputs=valid_edges,             # GradientEdge(MmBackward_1), GradientEdge(MmBackward_2)
    inputs=(GradientEdge(AccGrad_w, 0),),
    grad_outputs=valid_grad_outputs, # saved dL/dh0, saved dL/dh1
)

Because MmBackward_2's forward output h1 is a descendant of MmBackward_1's forward output h0, starting backprop from MmBackward_2 with grad dL/dh1 reaches w via two routes:

  1. Directly through MmBackward_2's weight input → contributes h0ᵀ @ dL/dh1 to dw. ✓
  2. Through MmBackward_2's input h0 → then through MmBackward_1 → contributes xᵀ @ (dL/dh1 @ wᵀ) to dw.

But the saved dL/dh0 (captured by the hook at MmBackward_1 during stage_backward_input) already equals dL/dh1 @ wᵀ in this graph, so starting backprop from MmBackward_1 with this saved grad adds the same contribution xᵀ @ dL/dh0 again.

Net effect: the contribution from MmBackward_1 is counted twice. Numerically:

matmul_1_contrib = xᵀ @ dL/dh0 = [[0.22, 1.10], [-0.14, -0.70]]
matmul_2_contrib = h0ᵀ @ dL/dh1 = [[-0.05, -0.05], [-0.43, -0.43]]

baseline        = matmul_1_contrib + matmul_2_contrib        = [[0.17, 1.05], [-0.57, -1.13]]
split (buggy)   = 2 * matmul_1_contrib + matmul_2_contrib    = [[0.39, 2.15], [-0.71, -1.83]]

Both match the printed output.
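The arithmetic above can be re-derived by hand. The following is a minimal pure-Python sketch (no PyTorch required) that recomputes both contributions from the reproducer's x and w; the helper names matmul, transpose, add and rounded are illustrative and not part of any library.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(a):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*a)]


def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


def rounded(a, ndigits=2):
    """Round every entry for display."""
    return [[round(v, ndigits) for v in row] for row in a]


w = [[0.4, -0.2], [0.7, 0.3]]
x = [[1.1, -0.7]]

h0 = matmul(x, w)                        # first use of w
dL_dh1 = [[1.0, 1.0]]                    # loss = sum(h1), so dL/dh1 is all ones
dL_dh0 = matmul(dL_dh1, transpose(w))    # gradient arriving at h0

matmul_1_contrib = matmul(transpose(x), dL_dh0)   # xᵀ @ dL/dh0
matmul_2_contrib = matmul(transpose(h0), dL_dh1)  # h0ᵀ @ dL/dh1

baseline = add(matmul_1_contrib, matmul_2_contrib)
# The split path adds MmBackward_1's term a second time.
buggy = add(matmul_1_contrib, baseline)

print("baseline:", rounded(baseline))
print("buggy   :", rounded(buggy))
```

The difference between the two results is exactly matmul_1_contrib, which is why the largest absolute error equals its largest entry.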

Trigger conditions

Double-counting occurs iff both hold for some weight w:

  1. param_group["intermediates"] for w contains two nodes I_parent and I_child where I_child's forward output is a descendant of I_parent's forward output; and
  2. The backward path from I_child to AccGrad_w passes through I_parent (i.e. the same weight w participates in the subgraph between I_parent and I_child).

Condition (2) is what makes this a shared-weight issue rather than a general nested-intermediate issue.
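The two conditions amount to a reachability question on the backward graph. The sketch below checks it on a hand-written model of the reproducer's graph, stored as a plain dict from node name to successor names. In real autograd the same edges are exposed through grad_fn.next_functions; the dict, the node labels and the helper names reachable and nested_pairs are assumptions made for illustration.

```python
from collections import deque


def reachable(src, dst, next_functions):
    """Return True if dst can be reached from src along backward edges."""
    seen, queue = {src}, deque([src])
    while queue:
        node = queue.popleft()
        if node == dst:
            return True
        for nxt in next_functions.get(node, ()):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return False


def nested_pairs(intermediates, acc_grad, next_functions):
    """List (child, parent) pairs where backprop from child passes through
    parent on its way to the weight's AccumulateGrad node."""
    pairs = []
    for child in intermediates:
        for parent in intermediates:
            if (
                child != parent
                and reachable(child, parent, next_functions)
                and reachable(parent, acc_grad, next_functions)
            ):
                pairs.append((child, parent))
    return pairs


# Hand-written model of the backward graph for (x @ w @ w).sum().
chain_graph = {
    "SumBackward": ["MmBackward_2"],
    "MmBackward_2": ["MmBackward_1", "AccGrad_w"],
    "MmBackward_1": ["AccGrad_x", "AccGrad_w"],
}

# Model of (x @ w + y @ w).sum(): w is shared, but the use sites are not nested.
parallel_graph = {
    "SumBackward": ["AddBackward"],
    "AddBackward": ["MmBackward_a", "MmBackward_b"],
    "MmBackward_a": ["AccGrad_x", "AccGrad_w"],
    "MmBackward_b": ["AccGrad_y", "AccGrad_w"],
}

print(nested_pairs(["MmBackward_1", "MmBackward_2"], "AccGrad_w", chain_graph))
print(nested_pairs(["MmBackward_a", "MmBackward_b"], "AccGrad_w", parallel_graph))
```

Only the chained graph reports a nested pair. The parallel graph shares w as well, but neither use site lies on the other's backward path, so it does not meet the trigger conditions.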

Impact

Any pipeline schedule that uses the dW/dX-split backward (e.g. ZeroBubble variants and any custom schedule calling these helpers directly) will silently produce wrong weight gradients for models with tied/shared weights where the shared parameter is used at multiple forward sites on nested paths. Common real-world occurrences include:

  • Weight-tying between encoder input embedding and decoder output projection when both appear on the same micro-batch compute path.
  • Repeatedly-applied layers (weight sharing across iterations within one forward call).

Versions

PyTorch version: 2.10.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
Python version: 3.10.12 (64-bit runtime)
Python platform: Linux-6.8.0-106-generic-x86_64-with-glibc2.35
Is CUDA available: False
Is XPU available: False

[pip3] numpy==2.2.6
[pip3] torch==2.10.0+cpu
[pip3] triton==3.6.0

Note: bug reproduces on CPU-only and is independent of device. Reproduced against
current main (SHA 0cfb7e25c47f2d63a71c475b594b5252be5b0ddb).

cc @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @aditvenk @xmfan @weifengpy

extent analysis

TL;DR

stage_backward_weight double-counts the weight gradient when a shared weight is used at nested sites in the forward graph. To fix it, change stage_backward_weight so that each intermediate node contributes to dw exactly once.

Guidance

  • Use param_group["intermediates"] to find intermediates that share a weight and lie on a nested path, i.e. one is reachable from the other in the backward graph.
  • Make each intermediate contribute to dw exactly once. The gradient saved at a parent intermediate already includes everything that flows back from its children, so gradient seeded at a child must not also propagate through the parent during the weight pass.
  • Do not rely on a visited-node set alone: the parent and child are distinct nodes, so such a set does not stop the child's gradient from flowing through the parent.
  • Verify the fix by comparing the weight gradients computed by stage_backward_weight with those computed by torch.autograd.grad in one shot.
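The "exactly once" rule and the final comparison can be illustrated with public autograd APIs only. The sketch below is not the pipelining code path and not a patch for it. It cuts the graph between the two matmuls with detach() so that each use site's contribution is computed in isolation, then checks that the sum matches a one-shot torch.autograd.grad. The variable names are illustrative.

```python
import torch

w = torch.tensor([[0.4, -0.2], [0.7, 0.3]], requires_grad=True)
x = torch.tensor([[1.1, -0.7]])

# Reference: one-shot backward through both uses of w.
(dw_ref,) = torch.autograd.grad((x @ w @ w).sum(), w)

# Per-site computation: cut the graph between the two matmuls.
h0 = x @ w
h0_cut = h0.detach().requires_grad_(True)  # gradient cannot flow past this point
h1 = h0_cut @ w

# Second site: yields dL/dh0 and the local term h0ᵀ @ dL/dh1.
dL_dh0, site2 = torch.autograd.grad(h1.sum(), [h0_cut, w])

# First site: seed with the saved dL/dh0, giving the local term xᵀ @ dL/dh0.
(site1,) = torch.autograd.grad(h0, w, grad_outputs=dL_dh0)

# Each site is counted exactly once, so the sum equals the reference.
torch.testing.assert_close(site1 + site2, dw_ref)
print(site1 + site2)
```

Because of the cut, seeding the second site cannot reach the first matmul, which is the property the split path lacks when both intermediates are passed to a single torch.autograd.grad call.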

Example

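# Illustrative pseudocode only; this is not the upstream implementation.
# local_weight_grad is a hypothetical helper: it must return only the direct
# contribution of `node` to the weight gradient (for a matmul, inputᵀ @ saved_grad)
# and must not propagate into other intermediates of the same group.
def stage_backward_weight(weights, param_groups):
    # ...
    for param_group in param_groups:
        # Assumption: the gradients saved by the prehooks during
        # stage_backward_input are stored under the "grads" key.
        saved_grads = param_group["grads"]
        dw = None
        for node, saved_grad in zip(param_group["intermediates"], saved_grads):
            # Each use site contributes exactly once.
            contribution = local_weight_grad(node, saved_grad)
            dw = contribution if dw is None else dw + contribution
        # ... accumulate dw into the .grad of the group's weight
    # ...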

Notes

Any fix has to ensure that gradient seeded at a child intermediate does not flow through a parent intermediate whose saved gradient is also seeded. Tracking visited nodes is not enough on its own, since the two intermediates are distinct nodes. The fix also assumes that param_group["intermediates"] lists every use site of the weight, which get_param_groups does correctly in the reproducer.

Recommendation

Until the issue is resolved in PyTorch, either patch stage_backward_weight locally so that nested intermediates are not double-counted, or avoid the dW/dX-split backward for stages whose shared weights meet the trigger conditions. A regular one-shot backward computes the correct gradient in the reproducer.
