pytorch - ✅(Solved) Fix torch.cuda.make_graphed_callables can silently corrupt parameter gradient accumulation by returning static grad buffers [1 pull requests, 1 participants]

GitHub stats
pytorch/pytorch#181723 (fetched 2026-04-29 06:11:14)
Comments: 0 · Participants: 1 · Timeline: 57 · Reactions: 0
Timeline (top): mentioned ×26, subscribed ×26, labeled ×4, cross-referenced ×1

Fix action: Fixed

PR fix notes

PR #2937: Fix CUDA graph parameter grad lifetime (changes Transformer Engine's graph capture code; see the changed files below)

Description (problem / solution / changelog)

Summary

Fix CUDA graph replay so parameter gradients returned from Graphed.backward do not expose CUDA graph static buffers to downstream autograd users.

The fix clones returned parameter gradients before handing them back to autograd, while preserving the existing aliasing behavior for delayed-wgrad parameters marked with skip_backward_post_hook.

Root Cause

When CUDA graph replay returns parameter grad slots directly from static graph buffers, downstream autograd users can retain references to buffers that are overwritten by later graph replays. This can corrupt retained grads or break gradient accumulation semantics.
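The aliasing can be observed without CUDA graphs. This CPU-only sketch assumes that a custom `torch.autograd.Function` whose `backward` overwrites a long-lived buffer in place and returns it is a fair stand-in for a graphed backward. `STATIC_GRAD`, `StaticBufferScale`, and `first_grad_aliases_buffer` are hypothetical names, and the input is the first microbatch from the issue's repro. Like the issue's pointer printout, it checks whether `.grad` shares memory with the static buffer after one backward that starts from `grad=None`. Whether `AccumulateGrad` steals the returned alias depends on its internal reference-count check, so the aliased case may differ across PyTorch versions. The cloned case can never alias.

```python
import torch

# Long-lived buffer standing in for a CUDA graph's static parameter-grad slot.
STATIC_GRAD = torch.zeros(4)


class StaticBufferScale(torch.autograd.Function):
    """y = x * weight; backward writes the weight grad into STATIC_GRAD in place."""

    @staticmethod
    def forward(ctx, x, weight, clone_grad):
        ctx.save_for_backward(x)
        ctx.clone_grad = clone_grad
        return x * weight

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        STATIC_GRAD.copy_((grad_out * x).sum(dim=0))  # what a graph replay does
        # Old behavior: hand back an alias. Fixed behavior: hand back an owned copy.
        weight_grad = STATIC_GRAD.clone() if ctx.clone_grad else STATIC_GRAD.detach()
        return None, weight_grad, None  # no grads for x or the bool flag


def first_grad_aliases_buffer(clone_grad: bool) -> bool:
    """One backward from grad=None; report whether .grad shares STATIC_GRAD's memory."""
    weight = torch.nn.Parameter(torch.arange(1.0, 5.0))
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]])
    StaticBufferScale.apply(x, weight, clone_grad).sum().backward()
    return weight.grad.data_ptr() == STATIC_GRAD.data_ptr()


print("alias returned -> .grad aliases buffer:", first_grad_aliases_buffer(False))
print("clone returned -> .grad aliases buffer:", first_grad_aliases_buffer(True))
```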

This is related to PyTorch issue https://github.com/pytorch/pytorch/issues/181723.

Changes

  • Detect parameter grad slots in the graphed autograd input surface.
  • Clone returned non-delayed-wgrad parameter grads before returning from Graphed.backward.
  • Allow reused graph input/output buffer mode to weak-ref current parameter grad static buffers after capture because returned grads are now cloned.
  • Add CUDA graph tests for owned returned parameter grads, accumulation, delayed-wgrad alias preservation, and reused buffer interleaved pipeline replay.
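An accumulation check in the spirit of the last bullet can also be run against `torch.cuda.make_graphed_callables` itself. This is a minimal sketch, assuming a CUDA device and a plain `torch.nn.Linear`; `graphed_matches_eager` is a hypothetical helper, not one of the tests the PR adds. It graphs one copy of the module, keeps an identical eager copy, accumulates gradients over several microbatches starting from `grad=None`, and compares the results. On builds affected by this issue it would be expected to return `False`.

```python
import torch


def graphed_matches_eager(num_microbatches: int = 3, width: int = 4) -> bool:
    """Accumulate grads over several microbatches, graphed vs eager, then compare."""
    torch.manual_seed(0)
    eager = torch.nn.Linear(width, width).cuda()
    graphed = torch.nn.Linear(width, width).cuda()
    graphed.load_state_dict(eager.state_dict())  # identical starting weights

    # The sample input does not require grad, matching the real microbatches below.
    sample = torch.randn(2, width, device="cuda")
    graphed = torch.cuda.make_graphed_callables(graphed, (sample,))

    batches = [torch.randn(2, width, device="cuda") for _ in range(num_microbatches)]
    eager.zero_grad(set_to_none=True)  # start from grad=None: the "steal" case
    graphed.zero_grad(set_to_none=True)
    for x in batches:  # no zero_grad between microbatches: gradient accumulation
        eager(x).sum().backward()
        graphed(x).sum().backward()
    torch.cuda.synchronize()

    return all(
        torch.allclose(p_eager.grad, p_graphed.grad)
        for p_eager, p_graphed in zip(eager.parameters(), graphed.parameters())
    )
```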

Changed files

  • tests/pytorch/test_cuda_graphs.py (modified, +208/-6)
  • transformer_engine/pytorch/graph.py (modified, +55/-4)

Original issue

Title

torch.cuda.make_graphed_callables can silently corrupt parameter gradient accumulation by returning static grad buffers

Body

Bug description

torch.cuda.make_graphed_callables can silently produce incorrect parameter gradients when a graphed module is used with gradient accumulation over multiple microbatches.

The backward autograd function returns detached views of captured static_grad_inputs. For module parameters, those returned tensors flow into AccumulateGrad. If param.grad is None, AccumulateGrad may install/steal the incoming grad tensor as param.grad. In that case, param.grad aliases the CUDA graph's static parameter grad buffer. The next backward graph replay overwrites the same static buffer before AccumulateGrad accumulates the next microbatch's gradient.

This makes the accumulated gradient become g2 + g2 instead of g1 + g2.

This is silent numerical corruption.
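The arithmetic of that failure can be checked with a small pure-Python model. It is a simplification rather than PyTorch's implementation. Lists stand in for tensors, a slice assignment plays the role of the backward replay overwriting the static buffer, and `accumulate` mimics `AccumulateGrad` by installing the incoming object when `.grad` is empty and adding in place otherwise. `G1` and `G2` are `x1.sum(dim=0)` and `x2.sum(dim=0)` from the repro below; all names are hypothetical.

```python
from typing import List, Optional

G1 = [3.0, 5.0, 7.0, 9.0]      # x1.sum(dim=0) from the repro
G2 = [11.0, 22.0, 33.0, 44.0]  # x2.sum(dim=0) from the repro


def accumulate(grad: Optional[List[float]], incoming: List[float]) -> List[float]:
    """Simplified AccumulateGrad: steal when empty, otherwise add in place."""
    if grad is None:
        return incoming  # installs the incoming object itself (aliasing!)
    for i, value in enumerate(incoming):
        grad[i] += value
    return grad


def simulate(clone_returned_grad: bool) -> List[float]:
    static_buffer = [0.0] * 4  # the graph's static parameter-grad slot
    param_grad: Optional[List[float]] = None
    for microbatch_grad in (G1, G2):
        static_buffer[:] = microbatch_grad  # backward replay overwrites in place
        returned = list(static_buffer) if clone_returned_grad else static_buffer
        param_grad = accumulate(param_grad, returned)
    return param_grad


print(simulate(clone_returned_grad=False))  # [22.0, 44.0, 66.0, 88.0]  (g2 + g2)
print(simulate(clone_returned_grad=True))   # [14.0, 27.0, 40.0, 53.0]  (g1 + g2)
```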

Minimal repro

import torch


class ScaleByWeight(torch.nn.Module):
    def __init__(self, width: int) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.arange(1, width + 1, device="cuda", dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight


torch.cuda.set_device(0)
torch.manual_seed(1234)

width = 4
model = ScaleByWeight(width).cuda()
sample_x = torch.ones(2, width, device="cuda", dtype=torch.float32)
graphed_model = torch.cuda.make_graphed_callables(model, (sample_x,))

x1 = torch.tensor(
    [[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]],
    device="cuda",
)
x2 = torch.tensor(
    [[10.0, 20.0, 30.0, 40.0], [1.0, 2.0, 3.0, 4.0]],
    device="cuda",
)

expected_g1 = x1.sum(dim=0)
expected_g2 = x2.sum(dim=0)
expected_total = expected_g1 + expected_g2
bug_total = expected_g2 + expected_g2

model.weight.grad = None

graphed_model(x1).sum().backward()
torch.cuda.synchronize()

first_grad_ref = model.weight.grad
first_grad_ptr = first_grad_ref.data_ptr()
first_grad_snapshot = first_grad_ref.detach().clone()

observations = {}


def capture_pre_accumulate(incoming_grad: torch.Tensor) -> torch.Tensor:
    existing_grad = model.weight.grad
    observations["incoming_ptr"] = incoming_grad.data_ptr()
    observations["incoming_value"] = incoming_grad.detach().clone()
    observations["existing_ptr"] = None if existing_grad is None else existing_grad.data_ptr()
    observations["existing_value"] = (
        None if existing_grad is None else existing_grad.detach().clone()
    )
    return incoming_grad


hook_handle = model.weight.register_hook(capture_pre_accumulate)
try:
    graphed_model(x2).sum().backward()
    torch.cuda.synchronize()
finally:
    hook_handle.remove()

final_grad = model.weight.grad.detach().clone()

print("expected_g1:", expected_g1.cpu().tolist())
print("expected_g2:", expected_g2.cpu().tolist())
print("expected_total:", expected_total.cpu().tolist())
print("bug_total:", bug_total.cpu().tolist())

print("first_grad_ptr:", first_grad_ptr)
print("first_grad_snapshot:", first_grad_snapshot.cpu().tolist())
print("pre_accum_existing_ptr:", observations["existing_ptr"])
print("pre_accum_incoming_ptr:", observations["incoming_ptr"])
print("pre_accum_existing_value:", observations["existing_value"].cpu().tolist())
print("pre_accum_incoming_value:", observations["incoming_value"].cpu().tolist())
print("final_grad:", final_grad.cpu().tolist())

print("first grad == g1:", torch.equal(first_grad_snapshot, expected_g1))
print("pre-accumulate existing grad == g1:", torch.equal(observations["existing_value"], expected_g1))
print("pre-accumulate existing grad == g2:", torch.equal(observations["existing_value"], expected_g2))
print("incoming grad == g2:", torch.equal(observations["incoming_value"], expected_g2))
print("final grad == expected g1+g2:", torch.equal(final_grad, expected_total))
print("final grad == bug g2+g2:", torch.equal(final_grad, bug_total))

Actual output

expected_g1: [3.0, 5.0, 7.0, 9.0]
expected_g2: [11.0, 22.0, 33.0, 44.0]
expected_total: [14.0, 27.0, 40.0, 53.0]
bug_total: [22.0, 44.0, 66.0, 88.0]

first_grad_ptr: 133201995498496
first_grad_snapshot: [3.0, 5.0, 7.0, 9.0]
pre_accum_existing_ptr: 133201995498496
pre_accum_incoming_ptr: 133201995498496
pre_accum_existing_value: [11.0, 22.0, 33.0, 44.0]
pre_accum_incoming_value: [11.0, 22.0, 33.0, 44.0]
final_grad: [22.0, 44.0, 66.0, 88.0]

first grad == g1: True
pre-accumulate existing grad == g1: False
pre-accumulate existing grad == g2: True
incoming grad == g2: True
final grad == expected g1+g2: False
final grad == bug g2+g2: True

Expected behavior

The second backward should accumulate into the first microbatch gradient:

final_grad == [14.0, 27.0, 40.0, 53.0]

Instead, the first gradient is overwritten by the second graph replay before accumulation, and the final result is:

final_grad == [22.0, 44.0, 66.0, 88.0]

Environment

Observed on a CUDA-enabled PyTorch build:

PyTorch Version: 2.11.0a0+a6c236b9fd
CUDA available: True
Torch CUDA version: 13.2

Possible fix direction

For module parameter gradient slots, make_graphed_callables should not return detached aliases of captured static grad buffers to autograd. Returning an owned tensor, for example cloning parameter grad slots in Graphed.backward(), would avoid letting AccumulateGrad install a static CUDA graph buffer as param.grad.

The current behavior appears inconsistent with the documented drop-in replacement expectation for autograd-enabled training loops.

cc @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia

Extent analysis

TL;DR

The most likely fix for the silent corruption of parameter gradients when torch.cuda.make_graphed_callables is used with gradient accumulation is to make it return owned tensors (for example, clones) for parameter gradient slots instead of detached aliases of the captured static grad buffers.

Guidance

  • Identify the root cause: The issue arises from make_graphed_callables returning detached views of captured static_grad_inputs, which can lead to AccumulateGrad installing a static CUDA graph buffer as param.grad.
  • Verify the issue: Run the provided minimal repro code to observe the incorrect gradient accumulation.
  • Mitigate the issue: Consider cloning parameter grad slots in Graphed.backward() to return owned tensors instead of detached aliases.
  • Investigate alternatives: Explore other possible fixes, such as making AccumulateGrad copy, rather than steal, an incoming gradient that aliases a CUDA graph static buffer.
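For the "Mitigate the issue" step, one user-side stopgap (an assumption, not something proposed in the issue or the PR) is to register a tensor hook on each parameter that returns a clone of the incoming gradient. Hooks registered on a leaf tensor run before `AccumulateGrad`, and the value they return replaces the gradient, so `.grad` can only ever be an owned tensor. This sketch is CPU-only: the hypothetical `AliasingScale` function stands in for a graphed backward by returning a detached alias of a static buffer, and `clone_incoming_param_grads` and `Toy` are likewise hypothetical. With a real graphed module, the hooks would presumably be registered after `make_graphed_callables` returns so that capture is unaffected; that ordering is an untested assumption.

```python
from typing import List

import torch
from torch.utils.hooks import RemovableHandle


def clone_incoming_param_grads(module: torch.nn.Module) -> List[RemovableHandle]:
    """Hypothetical helper: hand AccumulateGrad an owned copy of every param grad."""
    handles = []
    for param in module.parameters():
        if param.requires_grad:
            # A tensor hook may return a replacement grad; AccumulateGrad sees the clone.
            handles.append(param.register_hook(lambda grad: grad.clone()))
    return handles


# CPU stand-in for a graphed backward that returns an alias of a static buffer.
STATIC_GRAD = torch.zeros(4)


class AliasingScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x)
        return x * weight

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        STATIC_GRAD.copy_((grad_out * x).sum(dim=0))  # "replay" overwrites in place
        return None, STATIC_GRAD.detach()  # buggy: alias of the static buffer


class Toy(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.arange(1.0, 5.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return AliasingScale.apply(x, self.weight)


model = Toy()
handles = clone_incoming_param_grads(model)
microbatches = (
    torch.tensor([[1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0]]),
    torch.tensor([[10.0, 20.0, 30.0, 40.0], [1.0, 2.0, 3.0, 4.0]]),
)
for x in microbatches:  # gradient accumulation, starting from grad=None
    model(x).sum().backward()
for handle in handles:
    handle.remove()

print(model.weight.grad.tolist())  # [14.0, 27.0, 40.0, 53.0], i.e. g1 + g2
```

The trade-off is the same one the library-side fix makes: one extra copy per parameter gradient on every backward.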

Example

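# Sketch of the tail of Graphed.backward(), after the backward graph replay.
# Assumption: static_grad_inputs is ordered as (user-arg grads..., param grads...),
# so every slot at index >= num_user_args is a parameter gradient.
def graphed_backward_outputs(static_grad_inputs, num_user_args):
    outputs = []
    for i, static_grad in enumerate(static_grad_inputs):
        if static_grad is None:
            # Inputs that did not require grad must receive None, not zeros.
            outputs.append(None)
        elif i >= num_user_args:
            # Parameter grad: return an owned copy so AccumulateGrad can never
            # install the static graph buffer itself as param.grad.
            outputs.append(static_grad.clone())
        else:
            # User-input grad: keep the existing detached-alias behavior.
            outputs.append(static_grad.detach())
    return tuple(outputs)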

Notes

The suggested fix assumes that returning owned tensors is enough to prevent the corruption. Cloning costs one extra allocation and one device-to-device copy per parameter gradient on every backward replay, so the change still needs testing for both correctness and overhead.

Recommendation

Clone parameter grad slots in Graphed.backward() so that it returns owned tensors. This is the most direct fix for the identified issue, and it is the approach PR #2937 takes for non-delayed-wgrad parameters.
