pytorch - ✅(Solved) Fix [Memory Leak] torch.compile'd flex_attention_backward retains ~12GB stale tensors after first execution [1 pull requests, 7 comments, 3 participants]

GitHub stats
pytorch/pytorch#177869 (fetched 2026-04-08 01:03:10)
Comments: 7
Participants: 3
Timeline: 59
Reactions: 0
Timeline (top): mentioned ×17, subscribed ×17, referenced ×11, commented ×7

PR fix notes

PR #178357: Let FlexBackward be standalone invoakable

Description (problem / solution / changelog)

Stack from ghstack (oldest at bottom):

  • -> #178357

Summary

This is kind of a big PR, so hopefully I'll make it easier to review.

TL;DR: This lets users invoke flex-backward individually, without needing to go through the forward.

Issue: https://github.com/pytorch/pytorch/issues/177869

Usage:

	
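import torch
from torch.nn.attention.flex_attention import create_block_mask
from torch._higher_order_ops.flex_attention import (
    flex_attention as flex_attention_hop,
)

# q, k, v, scale, mask_buffer, and device are assumed to be defined by the caller.

def score_mod(score, b, h, m, n):
    return score

def mask_mod(b, h, m, n):
    return m + mask_buffer >= n

block_mask = create_block_mask(
    mask_mod,
    B=2,
    H=2,
    Q_LEN=128,
    KV_LEN=128,
    device=device,
)

# Run the forward HOP directly to obtain the output and logsumexp.
out, logsumexp, _ = flex_attention_hop(
    q,
    k,
    v,
    score_mod,
    block_mask.as_tuple(),
    scale,
    {},
)

@torch.compile(fullgraph=True)
def compiled_bw(query, key, value, fwd_out, lse, grad_out):
    return torch.ops.higher_order.flex_attention_backward(
        query,
        key,
        value,
        fwd_out,
        lse,
        grad_out,
        None,  # grad_logsumexp
        score_mod,  # fw_graph: a plain Python function on the direct-call path
        None,  # joint_graph: recreated during tracing
        block_mask.as_tuple(),
        scale,
        {},  # kernel_options
        (),  # score_mod_other_buffers
        (),  # mask_mod_other_buffers
    )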

How does it work

Existing

Unless we are running compiled backward, we never actually hit the VT (VariableTracker) for the backward impl. What happens is that we speculate the fwd, then manually invoke the AOT Autograd machinery to create the joint and set up our autograd node. AOTAutograd proper then kicks in and we compile the backward HOP, avoiding Dynamo.

If we are doing compiled backward, we do in fact run this VT, but it's minimal and mostly set up in either the fwd HOP VT or the autograd node.

New

We now recreate the joint_graph in the VT for flex-backward, here: https://github.com/pytorch/pytorch/pull/178357/changes#diff-d6c3a57794a9a8a79eacedc57d68ff631fe80776f8701398eb3ef349973b033bR4161, which is basically an exact copy of what we do in the HOP autograd function.

Claude decided to add some more error checking, and I'm cool with it. The main new logic comes from figuring out which path we are in: the AOTAutograd path or the user direct-call path.

The default path will never see this.

The default compiled-backward path uses the traced score_mod from Dynamo, which will be an fx.GraphModule; in the new user flow it will be a non-traced Python function. That difference is what we use to decide whether to create the joint graph and do all the proxying.

One big gotcha

This is fundamentally an API for very advanced users, and I'm hesitant to make it too nice, since I really want users to know what they're doing. For instance, the output of flex attention (the blessed op) is in regular ln space, but the output from the HOP is in log2 space, and the expected lse of flex attention backward is also in log2 space.
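The ln versus log2 gotcha above comes down to a constant factor. The sketch below assumes the quantity in question is the logsumexp (lse); the helper names are hypothetical and not part of PyTorch. Since log2(S) · ln(2) = ln(S), converting between the two spaces is a single multiplication or division.

```python
import math


def lse_log2_to_ln(lse_log2):
    """Convert a base-2 log-sum-exp value to natural-log space."""
    return lse_log2 * math.log(2.0)


def lse_ln_to_log2(lse_ln):
    """Convert a natural-log log-sum-exp value to base-2 space."""
    return lse_ln / math.log(2.0)


# Round trip on a scalar; the same arithmetic applies elementwise to a tensor.
s = 10.0
ln_value = lse_log2_to_ln(math.log2(s))
print(math.isclose(ln_value, math.log(s)))
```

Under this assumption, an lse obtained in ln space would need the second helper applied before being handed to the backward HOP.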

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @Lucaskabela @Chillee @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv

Changed files

  • test/dynamo/test_higher_order_ops.py (modified, +0/-1)
  • test/inductor/test_flex_attention.py (modified, +204/-0)
  • torch/_dynamo/variables/higher_order_ops.py (modified, +222/-1)
  • torch/_higher_order_ops/flex_attention.py (modified, +14/-1)
  • torch/_inductor/kernel/flex/flex_attention.py (modified, +4/-0)
  • torch/testing/_internal/hop_db.py (modified, +128/-3)

🐛 Describe the bug

When flex_attention_backward is wrapped with torch.compile(), the first execution of the compiled function permanently retains its output tensors (dk, dv) in GPU memory (~12GB). Subsequent calls do not leak; only the first execution does. This causes OOM on memory-constrained setups.

Key evidence from CUDA memory traces

  1. First vs subsequent alloc patterns are identical — no extra alloc/free pairs, ruling out autotuning as the cause.
  2. The only call stack difference is _ops.py:520 (first call, _dispatch_cache miss) vs _ops.py:381 (subsequent calls, cache hit).
  3. Stale tensors are invisible to Python GC: a gc.get_objects() scan finds no references to these tensors; they appear to be held by a C++/Python hybrid structure.
  4. torch.cuda.empty_cache() has no effect — the tensors are still referenced, not just cached.
  5. Resetting the compiled function doesn't help — setting it to None + gc.collect() frees the old 12GB, but the next torch.compile + first execution leaks another 12GB.
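Item 3 above relies on a gc.get_objects() scan. A minimal sketch of such a scan follows, assuming the goal is to list every GC-tracked object that matches a predicate. The helper name find_tracked and the Marker class are hypothetical stand-ins so the sketch runs without a GPU; with PyTorch installed, a predicate such as `lambda o: torch.is_tensor(o) and o.is_cuda` would select CUDA tensors instead.

```python
import gc


def find_tracked(predicate):
    """Return GC-tracked objects for which predicate(obj) is truthy."""
    gc.collect()  # clear unreachable cycles first so they are not reported
    matches = []
    for obj in gc.get_objects():
        try:
            if predicate(obj):
                matches.append(obj)
        except Exception:
            # Some objects raise during isinstance or attribute checks; skip them.
            continue
    return matches


class Marker:
    """Stand-in for a tensor type (hypothetical, for illustration only)."""


keep_alive = Marker()
hits = find_tracked(lambda o: isinstance(o, Marker))
print(len(hits))
```

If a scan like this returns nothing for the retained tensors, as the reporter observed, the owner is probably not reachable through GC-tracked Python objects.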

Minimal reproduction pattern

import torch
from torch._higher_order_ops.flex_attention import flex_attention_backward

def _flex_attn_bw_wrapper(*args, **kwargs):
    return flex_attention_backward(*args, **kwargs)

compiled_bw = torch.compile(_flex_attn_bw_wrapper)

# First call: compile + execute → dk/dv from this call are never freed
dq, dk, dv, _ = compiled_bw(
    query, key, value, out, logsumexp, grad_out, grad_logsumexp,
    fw_graph, joint_graph, block_mask, scale, kernel_options,
    score_mod_other_buffers, mask_mod_other_buffers,
)

# After this call completes (and dq/dk/dv go out of scope),
# ~12GB (first execution's dk + dv) remains permanently allocated.
# Subsequent calls do NOT leak additional memory.

Context: This pattern is used in ring-attention / sequence-parallel implementations where torch.compile wrapping is needed for performance.

Environment

  • PyTorch: 2.9.1 (nightly)
  • cuda: 12.8

Questions

  1. Where exactly are the first-execution output tensors being retained? (Suspected: _dispatch_cache closure, CompiledFunction.compiled_bw class attribute chain, or Inductor codegen closure)
  2. Is this specific to torch.compile wrapping HigherOrderOperators, or a general torch.compile issue?
  3. Is there a recommended workaround that actually releases the stale memory?

Error logs

No response

Versions

PyTorch version: 2.9.1+cu128

Is debug build: False

CUDA used to build PyTorch: 12.8

ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)

GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0

Clang version: Could not collect

CMake version: version 4.1.2

Libc version: glibc-2.35

Python version: 3.11.14 | packaged by conda-forge | (main, Oct 22 2025, 22:46:25) [GCC 14.3.0] (64-bit runtime)

Python platform: Linux-5.4.241-1-x86_64-with-glibc2.35

Is CUDA available: True

CUDA runtime version: 12.8.93

CUDA_MODULE_LOADING set to:

GPU models and configuration:

GPU 0: NVIDIA H20

GPU 1: NVIDIA H20

GPU 2: NVIDIA H20

GPU 3: NVIDIA H20

GPU 4: NVIDIA H20

GPU 5: NVIDIA H20

GPU 6: NVIDIA H20

GPU 7: NVIDIA H20

Nvidia driver version: 535.247.01

cuDNN version: Probably one of the following:

/usr/lib/x86_64-linux-gnu/libcudnn.so.9.17.1

/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.17.1

/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.17.1

/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.17.1

/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.17.1

/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.17.1

/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.17.1

/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.17.1

Is XPU available: False

HIP runtime version: N/A

MIOpen runtime version: N/A

Is XNNPACK available: True

Caching allocator config: N/A

Versions of relevant libraries:

[pip3] numpy==2.3.4

[pip3] nvidia-cublas-cu12==12.8.4.1

[pip3] nvidia-cuda-cupti-cu12==12.8.90

[pip3] nvidia-cuda-nvrtc-cu12==12.8.93

[pip3] nvidia-cuda-runtime-cu12==12.8.90

[pip3] nvidia-cudnn-cu12==9.10.2.21

[pip3] nvidia-cudnn-frontend==1.19.1

[pip3] nvidia-cufft-cu12==11.3.3.83

[pip3] nvidia-curand-cu12==10.3.9.90

[pip3] nvidia-cusolver-cu12==11.7.3.90

[pip3] nvidia-cusparse-cu12==12.5.8.93

[pip3] nvidia-cusparselt-cu12==0.7.1

[pip3] nvidia-nccl-cu12==2.27.5

[pip3] nvidia-nvjitlink-cu12==12.8.93

[pip3] nvidia-nvtx-cu12==12.8.90

[pip3] optree==0.17.0

[pip3] torch==2.9.1+cu128

[pip3] torch_c_dlpack_ext==0.1.5

[pip3] torch_memory_saver==0.0.9

[pip3] torchao==0.9.0

[pip3] torchaudio==2.9.1+cu128

[pip3] torchcodec==0.7.0

[pip3] torchelastic==0.2.2

[pip3] torchvision==0.24.1+cu128

[pip3] triton==3.5.1

[conda] numpy 2.3.4 py311h2e04523_0 conda-forge

[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi

[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi

[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi

[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi

[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi

[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi

[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi

[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi

[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi

[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi

[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi

[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi

[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi

[conda] optree 0.17.0 pypi_0 pypi

[conda] torch 2.9.1+cu128 pypi_0 pypi

[conda] torchaudio 2.9.1+cu128 pypi_0 pypi

[conda] torchelastic 0.2.2 pypi_0 pypi

[conda] torchvision 0.24.1+cu128 pypi_0 pypi

[conda] triton 3.5.1 pypi_0 pypi

cc @chauhang @penguinwu @Chillee @drisspg @yanboliang @BoyuanFeng @liangel-02 @howardzhang-cv

extent analysis

Fix Plan

The evidence in the issue limits which workarounds can help when torch.compile() wraps flex_attention_backward:

  • Clear the cache: torch.cuda.empty_cache() on its own does not release the retained memory (evidence item 4), because the tensors are still referenced rather than merely cached.
  • Delete the compiled function: dropping every reference to the compiled function (for example del compiled_bw, or setting it to None) and then calling gc.collect() frees the retained ~12GB (evidence item 5), but the next torch.compile + first execution retains another ~12GB.
  • Use a context manager: this only packages the previous step so that the compiled function is dropped after use; it does not prevent the first-execution retention.

Here's an example code snippet that demonstrates these steps:

import gc

import torch
from torch._higher_order_ops.flex_attention import flex_attention_backward

def _flex_attn_bw_wrapper(*args, **kwargs):
    return flex_attention_backward(*args, **kwargs)

# Context manager that drops its own reference to the compiled function on exit
class CompiledFunctionContext:
    def __init__(self, func):
        self.func = torch.compile(func)

    def __enter__(self):
        return self.func

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.func = None

# Use the context manager to compile and execute the function
with CompiledFunctionContext(_flex_attn_bw_wrapper) as compiled_bw:
    dq, dk, dv, _ = compiled_bw(
        query, key, value, out, logsumexp, grad_out, grad_logsumexp,
        fw_graph, joint_graph, block_mask, scale, kernel_options,
        score_mod_other_buffers, mask_mod_other_buffers,
    )

# The name bound by "as" still references the compiled function, so drop it too,
# then collect garbage before releasing cached blocks.
del compiled_bw
gc.collect()
torch.cuda.empty_cache()

Verification

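To verify that the fix worked, compare GPU memory usage before and after the first execution of the compiled function. torch.cuda.memory_allocated() and torch.cuda.memory_stats() report memory held by live tensors, whereas nvidia-smi also counts blocks cached by PyTorch's allocator, so the torch.cuda counters are the more direct signal.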
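The before-and-after comparison can be sketched as a small helper. This is a minimal sketch, not the reporter's measurement method: the name retained_bytes is hypothetical, and the allocator is injected as a callable so the example runs without a GPU. A fake counter stands in for the real allocator here and mimics the reported pattern, where only the first call retains memory.

```python
import gc


def retained_bytes(fn, read_allocated, synchronize=lambda: None):
    """Return the bytes still allocated after fn() ran and its result was dropped."""
    gc.collect()
    synchronize()
    before = read_allocated()
    result = fn()
    del result  # drop the outputs so only retained memory stays allocated
    gc.collect()
    synchronize()
    return read_allocated() - before


# Fake allocator: the first call "retains" 12 units, later calls retain nothing.
state = {"allocated": 0, "calls": 0}


def fake_step():
    state["calls"] += 1
    if state["calls"] == 1:
        state["allocated"] += 12


first = retained_bytes(fake_step, lambda: state["allocated"])
second = retained_bytes(fake_step, lambda: state["allocated"])
print(first, second)  # 12 0
```

With PyTorch, the same helper could be called as `retained_bytes(lambda: compiled_bw(*args), torch.cuda.memory_allocated, torch.cuda.synchronize)`; a large first value followed by roughly zero on later calls would match the behavior described in the issue.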

Extra Tips

  • Make sure to properly clean up any unnecessary tensors or variables to avoid memory leaks.
  • Consider using a memory profiler to identify any other potential memory leaks in your code.
  • If the issue persists, try updating to a PyTorch build that includes PR #178357, which this page lists as the fix.
