pytorch - ✅(Solved) Fix Interaction between Activation Checkpointing + make_fx based tracer [1 pull request, 4 comments, 3 participants]

GitHub stats: pytorch/pytorch#178935 (fetched 2026-04-08 01:56:58)
Comments: 4 · Participants: 3 · Timeline events: 51 · Reactions: 0
Timeline (top): mentioned ×20 · subscribed ×20 · labeled ×5 · commented ×4

Error Message

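No error is raised. Under make_fx, SAC silently skips recomputing ops marked PREFER_RECOMPUTE, so the traced backward reuses the saved forward activations. The issue proposes that, in the non-strict case, PyTorch warn that SAC is not guaranteed to work and recommend graph-based AC.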

Root Cause

In gelu_backward, vanilla AC consumes a recomputed mm. SAC instead consumes the mm saved during the forward pass, even though the SAC policy marks mm as PREFER_RECOMPUTE. This is a bug: during torch.compile/make_fx, SAC falls into the branch at https://github.com/pytorch/pytorch/blob/6c67f765a2e3c6711720c74e9fb3d34575fcec48/torch/utils/checkpoint.py#L1381, so no recomputation happens. That branch exists because RNG states are hard to reconstruct for SAC, where many forward nodes can have their own RNG. Vanilla AC only has to track one.

Fix Action

Fixed

PR fix notes

PR #2766: [graph_trainer] Add remat pass and torch.no_grad() execution to minimal_fx_tracer

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #2766
  • #2753

Changes in this PR:

  • Annotate backward FX nodes with {"remat_pass_tag": "is_backward"} during _patch_engine_run_backward so remat_using_tags_for_fwd_loss_bwd_graph can identify the forward/backward boundary.
  • Apply remat_using_tags_for_fwd_loss_bwd_graph as a default post-trace pass. Nodes tagged PREFER_RECOMPUTE (from selective AC) are duplicated before backward and the forward copies are DCE'd, reducing peak memory.
  • Execute the traced graph under torch.no_grad(), since the graph already contains explicit backward ops. Without this, PyTorch builds a redundant autograd graph that keeps all forward intermediates alive via grad_fn references.
  • Add test_llama_1b_peak_memory: verifies that traced+AC peak memory is within 20% of eager+AC on Llama 1B (BS=2, seq=2048, bf16).
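Below is a toy, pure-Python model of the tag-based remat step described above. It is not torchtitan's remat_using_tags_for_fwd_loss_bwd_graph, whose code is not reproduced here. The `Node` dataclass, its `recompute`/`backward` flags and `remat()` are hypothetical stand-ins for FX nodes, the PREFER_RECOMPUTE tag and the is_backward annotation.

The toy works in two steps:
  • Each forward node that is flagged for recompute and read by backward is cloned into backward just before its first use.
  • A reachability-based DCE then drops forward nodes that nothing needs any more.

```python
from dataclasses import dataclass


@dataclass
class Node:
    name: str
    inputs: list              # names of producer nodes or graph inputs
    recompute: bool = False   # stand-in for a PREFER_RECOMPUTE tag
    backward: bool = False    # stand-in for remat_pass_tag == "is_backward"


def remat(nodes, outputs):
    """Toy tag-based remat over a topologically ordered node list."""
    by_name = {n.name: n for n in nodes}
    forward = [n for n in nodes if not n.backward]
    new_backward = []
    clones = {}  # forward node name -> name of its recomputed copy

    def use(name):
        n = by_name.get(name)
        if n is None or n.backward or not n.recompute:
            return name  # graph input, backward value, or saved activation
        if name not in clones:
            # Resolve inputs first, so any recomputed inputs land before the copy.
            copy = Node(name + "_recomp", [use(i) for i in n.inputs], backward=True)
            new_backward.append(copy)
            clones[name] = copy.name
        return clones[name]

    for n in nodes:
        if n.backward:
            new_backward.append(Node(n.name, [use(i) for i in n.inputs], n.recompute, True))

    # Dead-code elimination: keep only nodes reachable from the outputs.
    graph = forward + new_backward
    index = {n.name: n for n in graph}
    live, stack = set(), list(outputs)
    while stack:
        name = stack.pop()
        if name in live or name not in index:
            continue
        live.add(name)
        stack.extend(index[name].inputs)
    return [n for n in graph if n.name in live]


# The issue's layer (mm -> gelu -> mm) with mm tagged for recompute.
layer_graph = [
    Node("mm", ["x", "w1"], recompute=True),
    Node("gelu", ["mm"]),
    Node("mm_1", ["gelu", "w2"], recompute=True),
    Node("loss", ["mm_1"]),
    Node("seed", ["loss"], backward=True),
    Node("grad_w2", ["gelu", "seed"], backward=True),
    Node("grad_gelu", ["seed", "w2"], backward=True),
    Node("gelu_bwd", ["grad_gelu", "mm"], backward=True),
    Node("grad_w1", ["x", "gelu_bwd"], backward=True),
]
result = remat(layer_graph, ["loss", "grad_w1", "grad_w2"])
print([n.name for n in result])
# ['mm', 'gelu', 'mm_1', 'loss', 'seed', 'grad_w2', 'grad_gelu', 'mm_recomp', 'gelu_bwd', 'grad_w1']

# A tagged forward node used only by backward is dropped from the forward.
dce_result = remat(
    [Node("a", ["x"], recompute=True), Node("loss", ["x"]),
     Node("seed", ["loss"], backward=True), Node("g", ["a", "seed"], backward=True)],
    ["loss", "g"],
)
print([n.name for n in dce_result])  # ['loss', 'seed', 'a_recomp', 'g']
```

In the first graph the forward mm survives, because gelu still consumes it in the forward pass. Only gelu_bwd is rewired to the recomputed copy, which mirrors the vanilla-AC graph in the issue.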

Changed files

  • torchtitan/experiments/graph_trainer/make_fx_tracer.py (modified, +16/-3)
  • torchtitan/experiments/graph_trainer/passes.py (modified, +41/-0)
  • torchtitan/experiments/graph_trainer/tests/test_trace_module.py (modified, +117/-0)
  • torchtitan/experiments/graph_trainer/trainer.py (modified, +4/-0)


🐛 Describe the bug

Consider the following program:

import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils.checkpoint import (
    checkpoint,
    create_selective_checkpoint_contexts,
    CheckpointPolicy,
)

# Single layer: mm → gelu → mm
def layer(x, w1, w2):
    return F.gelu(x @ w1) @ w2

def train(x, w1, w2, ctx_fn=None):
    if ctx_fn:
        y = checkpoint(layer, x, w1, w2, use_reentrant=False, context_fn=ctx_fn)
    else:
        y = checkpoint(layer, x, w1, w2, use_reentrant=False)
    loss = y.sum()
    return (loss, *torch.autograd.grad(loss, [w1, w2]))

x = torch.randn(2, 4)
w1 = torch.randn(4, 8, requires_grad=True)
w2 = torch.randn(8, 4, requires_grad=True)

# Vanilla AC: mm is recomputed
gm_ac = make_fx(lambda x, w1, w2: train(x, w1, w2),
                tracing_mode="fake", _allow_non_fake_inputs=True)(x, w1, w2)

# SAC: mm marked PREFER_RECOMPUTE but NOT recomputed
def sac_ctx():
    def policy(ctx, op, *a, **kw):
        if op == torch.ops.aten.mm.default:
            return CheckpointPolicy.PREFER_RECOMPUTE
        return CheckpointPolicy.MUST_SAVE
    return create_selective_checkpoint_contexts(policy)

gm_sac = make_fx(lambda x, w1, w2: train(x, w1, w2, ctx_fn=sac_ctx),
                 tracing_mode="fake", _allow_non_fake_inputs=True)(x, w1, w2)

for label, gm in [("VANILLA AC", gm_ac), ("SAC", gm_sac)]:
    mm = sum(1 for n in gm.graph.nodes if "mm" in str(n.target) and n.op == "call_function")
    print(f"{label}: mm={mm}")
    print(gm)
    print()

This outputs the following:

VANILLA AC: mm=6
<lambda>()



def forward(self, x_1, w1_1, w2_1):
    mm = torch.ops.aten.mm.default(x_1, w1_1)
    gelu = torch.ops.aten.gelu.default(mm);  mm = None
    mm_1 = torch.ops.aten.mm.default(gelu, w2_1);  gelu = None
    sum_1 = torch.ops.aten.sum.default(mm_1);  mm_1 = None
    ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
    expand = torch.ops.aten.expand.default(ones_like, [2, 4]);  ones_like = None
    mm_2 = torch.ops.aten.mm.default(x_1, w1_1);  w1_1 = None
    detach = torch.ops.aten.detach.default(mm_2)
    gelu_1 = torch.ops.aten.gelu.default(mm_2);  mm_2 = None
    detach_1 = torch.ops.aten.detach.default(w2_1);  w2_1 = None
    detach_2 = torch.ops.aten.detach.default(gelu_1);  gelu_1 = None
    detach_3 = torch.ops.aten.detach.default(detach_1);  detach_1 = None
    detach_4 = torch.ops.aten.detach.default(detach_2);  detach_2 = None
    t = torch.ops.aten.t.default(detach_4);  detach_4 = None
    mm_3 = torch.ops.aten.mm.default(t, expand);  t = None
    t_1 = torch.ops.aten.t.default(detach_3);  detach_3 = None
    mm_4 = torch.ops.aten.mm.default(expand, t_1);  expand = t_1 = None
    detach_5 = torch.ops.aten.detach.default(detach);  detach = None
    gelu_backward = torch.ops.aten.gelu_backward.default(mm_4, detach_5);  mm_4 = detach_5 = None
    t_2 = torch.ops.aten.t.default(x_1);  x_1 = None
    mm_5 = torch.ops.aten.mm.default(t_2, gelu_backward);  t_2 = gelu_backward = None
    return (sum_1, mm_5, mm_3)
    
# To see more debug info, please use `graph_module.print_readable()`

SAC: mm=5
<lambda>()



def forward(self, x_1, w1_1, w2_1):
    mm = torch.ops.aten.mm.default(x_1, w1_1);  w1_1 = None
    detach = torch.ops.aten.detach.default(mm)
    gelu = torch.ops.aten.gelu.default(mm);  mm = None
    detach_1 = torch.ops.aten.detach.default(gelu)
    mm_1 = torch.ops.aten.mm.default(gelu, w2_1);  gelu = None
    detach_2 = torch.ops.aten.detach.default(mm_1);  detach_2 = None
    sum_1 = torch.ops.aten.sum.default(mm_1);  mm_1 = None
    ones_like = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format)
    expand = torch.ops.aten.expand.default(ones_like, [2, 4]);  ones_like = None
    detach_3 = torch.ops.aten.detach.default(detach);  detach = None
    detach_4 = torch.ops.aten.detach.default(w2_1);  w2_1 = None
    detach_5 = torch.ops.aten.detach.default(detach_1);  detach_1 = None
    detach_6 = torch.ops.aten.detach.default(detach_4);  detach_4 = None
    detach_7 = torch.ops.aten.detach.default(detach_5);  detach_5 = None
    t = torch.ops.aten.t.default(detach_7);  detach_7 = None
    mm_2 = torch.ops.aten.mm.default(t, expand);  t = None
    t_1 = torch.ops.aten.t.default(detach_6);  detach_6 = None
    mm_3 = torch.ops.aten.mm.default(expand, t_1);  expand = t_1 = None
    detach_8 = torch.ops.aten.detach.default(detach_3);  detach_3 = None
    gelu_backward = torch.ops.aten.gelu_backward.default(mm_3, detach_8);  mm_3 = detach_8 = None
    t_2 = torch.ops.aten.t.default(x_1);  x_1 = None
    mm_4 = torch.ops.aten.mm.default(t_2, gelu_backward);  t_2 = gelu_backward = None
    return (sum_1, mm_4, mm_2)
    
# To see more debug info, please use `graph_module.print_readable()`

If you look at gelu_backward, vanilla AC uses the recomputed mm while SAC doesn't. This is a bug because in the SAC case, we fall into https://github.com/pytorch/pytorch/blob/6c67f765a2e3c6711720c74e9fb3d34575fcec48/torch/utils/checkpoint.py#L1381 during torch.compile/make_fx. We do this because it is hard to figure out RNG states in the SAC case, as many forward nodes can have custom RNG, while in the vanilla case there is only one.
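The RNG argument can be illustrated with Python's `random` module standing in for PyTorch's generators. This is an analogy only, not how torch.utils.checkpoint stores or restores state. `region` is a hypothetical checkpointed region containing three random ops, such as dropout masks. The toy shows three cases:
  • Replaying the whole region needs only one snapshot taken at entry.
  • Re-running a single op from that entry snapshot yields the wrong draw.
  • Keeping one snapshot per random op fixes that.

```python
import random


def region(rng, n=3):
    """Stand-in for a checkpointed region with n random ops (e.g. dropout)."""
    return [rng.random() for _ in range(n)]


rng = random.Random(1234)

# Vanilla-AC-like: one snapshot at region entry suffices, because the whole
# region is replayed in the same order.
entry_state = rng.getstate()
original = region(rng)
rng.random()                      # later ops advance the generator
rng.setstate(entry_state)
replayed = region(rng)
print(replayed == original)       # True

# Selective recompute: ops 0 and 1 were saved, only op 2 is re-run.
# Restoring the entry snapshot hands op 2 the draw that belonged to op 0.
rng.setstate(entry_state)
naive = rng.random()
print(naive == original[0])       # True: op 0's value, not op 2's

# In this toy, the fix is one snapshot per random op.
rng.setstate(entry_state)
per_op_states = []
for _ in range(3):
    per_op_states.append(rng.getstate())
    rng.random()
rng.setstate(per_op_states[2])
recomputed_op2 = rng.random()
print(recomputed_op2 == original[2])  # True
```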

This is not immediately blocking for TorchTitan because we already plan to use graph-based AC. But it does lead to annoying behaviour: a user runs a non-strict trace expecting AC/SAC to work, but it actually doesn't. I think we should make the following changes to PyTorch:

  1. If SAC runs under non-strict tracing or torch.compile, return early. This behaviour doesn't exist today: torch.compile still runs the AC machinery under the hood while ignoring recompute ops.
  2. In the non-strict case, warn that SAC is not guaranteed to work and recommend using graph-based AC.
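Here is a minimal sketch of what proposals 1 and 2 could look like as a guard. `should_skip_sac` and its two boolean flags are hypothetical. The issue does not say how PyTorch would detect non-strict tracing, so the sketch leaves that to the caller.

```python
import warnings


def should_skip_sac(uses_sac: bool, is_non_strict_trace: bool) -> bool:
    """Hypothetical guard: warn and tell the caller to return early instead of
    running the SAC machinery while a non-strict trace is active."""
    if uses_sac and is_non_strict_trace:
        warnings.warn(
            "Selective activation checkpointing is not guaranteed to recompute "
            "ops under non-strict tracing / torch.compile; consider graph-based AC.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    skip_traced = should_skip_sac(uses_sac=True, is_non_strict_trace=True)
    skip_eager = should_skip_sac(uses_sac=True, is_non_strict_trace=False)

print(skip_traced, skip_eager, len(caught))  # True False 1
```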

Versions

main

cc @soulitzer @ezyang @EikanWang @jgong5 @wenzhe-nrv @chauhang @penguinwu

extent analysis

TL;DR

SAC does not work as expected under non-strict tracing. To address this, consider modifying PyTorch to return early when SAC is used with non-strict tracing, and to emit a warning recommending graph-based AC.

Guidance

  • Identify the conditions under which SAC runs under non-strict tracing, and make the code return early in those cases.
  • Emit a warning telling users that SAC may not work as expected under non-strict tracing, and recommend graph-based AC instead.
  • Review the PyTorch codebase to make sure the changes do not introduce unintended behavior or conflict with other features.
  • Consider adding documentation or tests that verify SAC's behavior in each tracing mode.

Example

The issue does not include a patch; the proposed changes would live in PyTorch itself rather than in user code.

Notes

The proposed solution focuses on modifying PyTorch to handle the limitations of SAC in non-strict tracing. However, the actual implementation may require additional considerations, such as ensuring backward compatibility and handling edge cases.

Recommendation

Modify PyTorch to return early when SAC is used with non-strict tracing, and emit a warning recommending graph-based AC. This acknowledges SAC's limitations under non-strict tracing and steers users toward a more reliable option.
