pytorch - 💡(How to fix) `torch.compile` silently drops higher-order grad graph metadata from `autograd.grad(create_graph=True)` outputs [1 comment, 2 participants]

GitHub stats
pytorch/pytorch#181581 (fetched 2026-04-28 06:24:36)
Comments: 1, Participants: 2, Timeline events: 44, Reactions: 0
Timeline (top): mentioned ×17, subscribed ×17, labeled ×9, commented ×1


🐛 Describe the bug

Description

When using torch.autograd.grad(..., create_graph=True) inside a function compiled with torch.compile, the returned gradient tensor can silently lose its higher-order gradient graph metadata for some ops (mm, matmul, addmm, relu, leaky_relu). In eager mode the same code returns a gradient tensor with requires_grad=True and a valid grad_fn, while the compiled version returns a tensor with requires_grad=False and grad_fn=None.

This is surprising because create_graph=True asks autograd to construct the derivative graph for higher-order differentiation. If this pattern is currently unsupported by torch.compile / AOTAutograd for some operators, it would be helpful to either preserve eager-compatible behavior or fail explicitly instead of returning a tensor that looks like a normal value but no longer participates in the expected gradient graph.

The silent loss of graph metadata breaks double/higher-order backward, which is required by:

  • MAML (Model-Agnostic Meta-Learning) — inner-loop gradient update must be differentiable
  • WGAN-GP — gradient penalty requires gradients of gradients
  • Physics-informed neural networks — PDE losses involve gradients of network outputs

This also reproduces with backend='aot_eager', so it does not appear to be Inductor-specific.
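To make the MAML case concrete, here is a minimal eager-mode sketch on a toy linear-regression task (the data, shapes and the name `meta_grad` are illustrative, not from the issue). The inner update `w1 = w0 - lr * g` is only differentiable through `g` because `g` was produced with `create_graph=True`; without it the meta-gradient silently collapses to the first-order approximation, which is the same kind of silent change the compiled path introduces.

```python
import torch

torch.manual_seed(0)

# Toy support/query batches for one hypothetical task.
x_s, y_s = torch.randn(16, 4), torch.randn(16, 1)
x_q, y_q = torch.randn(16, 4), torch.randn(16, 1)
lr = 0.1


def meta_grad(second_order: bool) -> torch.Tensor:
    """Return d(query loss)/d(w0) after one inner SGD step."""
    w0 = torch.zeros(4, 1, requires_grad=True)
    inner_loss = ((x_s @ w0 - y_s) ** 2).mean()
    (g,) = torch.autograd.grad(inner_loss, w0, create_graph=second_order)
    # With create_graph=False, g has no grad_fn, so w1 depends on w0
    # only through the identity term and the Hessian term is dropped.
    w1 = w0 - lr * g
    outer_loss = ((x_q @ w1 - y_q) ** 2).mean()
    outer_loss.backward()
    return w0.grad


full = meta_grad(second_order=True)
first_order = meta_grad(second_order=False)
print(torch.allclose(full, first_order))  # False: the second-order term matters
```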

Minimal reproducer

import torch

x = torch.randn(4, 8, device='cuda', requires_grad=True)
w = torch.randn(8, 8, device='cuda', requires_grad=True)

def f(x, w):
    y = torch.mm(x, w).sum()
    grad_x, = torch.autograd.grad(y, x, create_graph=True)
    return grad_x.sum()

# Eager: requires_grad=True, grad_fn=<SumBackward0>
r_eager = f(x, w)
print(r_eager.requires_grad)  # True 

# Compile: requires_grad=False, grad_fn=None
torch._dynamo.reset()
r_compile = torch.compile(f)(x.clone().detach().requires_grad_(True),
                              w.clone().detach().requires_grad_(True))
print(r_compile.requires_grad)  # False

# Consequence: higher-order backward fails later
r_compile.backward()
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
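For reference, the value eager mode should produce in this reproducer can be worked out by hand: with y = sum(x @ w), autograd gives grad_x = ones(4, 8) @ w.T, so grad_x.sum() = 4 * w.sum(), and the second backward should fill w.grad with 4s (grad_x does not depend on x). The CPU-only eager sketch below checks exactly that; it could serve as an oracle for a compiled run once the metadata is preserved.

```python
import torch

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

y = torch.mm(x, w).sum()
(grad_x,) = torch.autograd.grad(y, x, create_graph=True)
r = grad_x.sum()

# grad_x[i, k] = sum_j w[k, j], so r == 4 * w.sum() (4 = number of rows of x).
assert torch.allclose(r, 4 * w.sum(), atol=1e-4)

r.backward()
# d r / d w is 4 everywhere.
print(torch.allclose(w.grad, torch.full((8, 8), 4.0)))  # True
```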

Real-world impact: WGAN-GP gradient penalty

import torch
import torch.nn as nn

G = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()
D = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1)).cuda()

def gradient_penalty(D, real, fake):
    alpha = torch.rand(real.size(0), 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interp = D(interp)
    grad, = torch.autograd.grad(d_interp.sum(), interp, create_graph=True)
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()

real = torch.randn(8, 64, device='cuda')
fake = G(torch.randn(8, 16, device='cuda'))

# Eager: works
gp = gradient_penalty(D, real, fake.detach())
gp.backward()  # correct

# Compile: fails
torch._dynamo.reset()
gp = torch.compile(gradient_penalty)(D, real, fake.detach())
gp.backward()  #  RuntimeError

Affected ops (systematic test, 22 ops)

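| Op | Eager requires_grad | Compile requires_grad | Status |
|---|---|---|---|
| torch.mm | True | False | differs from eager |
| x @ w (matmul 2D) | True | False | differs from eager |
| torch.addmm | True | False | differs from eager |
| torch.relu | True | False | differs from eager |
| F.leaky_relu | True | False | differs from eager |
| F.linear (no bias) | True | True | ✅ OK |
| F.linear (with bias) | True | True | ✅ OK |
| torch.sigmoid | True | True | ✅ OK |
| torch.tanh | True | True | ✅ OK |
| F.gelu | True | True | ✅ OK |
| F.silu | True | True | ✅ OK |
| x * y | True | True | ✅ OK |
| x / y | True | True | ✅ OK |
| x ** 2 | True | True | ✅ OK |
| x.exp() | True | True | ✅ OK |
| x.sin() | True | True | ✅ OK |
| F.layer_norm | True | True | ✅ OK |
| F.softmax | True | True | ✅ OK |
| torch.mv | True | True | ✅ OK |
| x.sum() | False | False | ✅ OK |
| x.mean() | False | False | ✅ OK |
| (x @ w).norm() | True | True | ✅ OK |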

Note: F.linear works correctly despite using mm/addmm internally — the decomposition path in AOT Autograd differs.

Trigger conditions

| Condition | Required? | Notes |
|---|---|---|
| create_graph=True | Yes | create_graph=False correctly returns requires_grad=False |
| Op | Specific | mm, matmul, addmm, relu, leaky_relu |
| Backend | Any | Reproduces with both inductor and aot_eager |
| Device | Any | Tested on CUDA (CPU blocked by unrelated C++20 compiler issue) |
| Input dims | Any | 2D matmul is sufficient to trigger |
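The author's test harness is not included in the report. The sketch below is a hypothetical reconstruction of that kind of probe for four ops from the op table, with a `create_graph` switch for the first trigger condition. It reports `requires_grad` of `grad_x.sum()` for an eager run and for a run wrapped by `compile_fn`; the default wrapper uses `backend="aot_eager"`, which the report says also reproduces the issue, and `identity` runs everything eagerly.

```python
from typing import Callable, Dict, Tuple

import torch

# Four ops from the table; each maps (x, w) -> tensor.
OPS: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {
    "torch.mm": lambda x, w: torch.mm(x, w),
    "torch.relu": lambda x, w: torch.relu(x),
    "torch.sigmoid": lambda x, w: torch.sigmoid(x),
    "x.sum()": lambda x, w: x.sum(),
}


def identity(fn):
    """No-op 'compiler': runs fn eagerly."""
    return fn


def aot_eager(fn):
    """Compile with the aot_eager backend (no Inductor codegen)."""
    torch._dynamo.reset()  # start each probe from a clean cache
    return torch.compile(fn, backend="aot_eager")


def probe(op, compile_fn, create_graph: bool = True) -> bool:
    """Return requires_grad of grad_x.sum() when f is run through compile_fn."""

    def f(x, w):
        y = op(x, w).sum()
        (grad_x,) = torch.autograd.grad(y, x, create_graph=create_graph)
        return grad_x.sum()

    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, 8, requires_grad=True)
    return compile_fn(f)(x, w).requires_grad


def run_table(compile_fn=aot_eager) -> Dict[str, Tuple[bool, bool]]:
    """Map op name -> (eager requires_grad, compiled requires_grad)."""
    return {
        name: (probe(op, identity), probe(op, compile_fn))
        for name, op in OPS.items()
    }
```

Per the report, calling `run_table()` on an affected build would be expected to show `(True, False)` for torch.mm and torch.relu and matching values for the other two; that expectation comes from the table, not from running this sketch.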

Error message

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

The issue only becomes visible when .backward() is called on the result. The forward computation and first-order gradient values appear correct; the surprising part is that requires_grad / grad_fn metadata is lost silently even though create_graph=True was requested.
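Until the compiled path either preserves the metadata or raises, callers can fail fast themselves. The sketch below uses a hypothetical helper, `require_higher_order_graph` (not a PyTorch API), that checks a value built from `torch.autograd.grad(..., create_graph=True)`, so the problem is reported at the call site instead of at a later `.backward()`. It is meant for computed outputs: a leaf tensor has `grad_fn=None` and would also be rejected.

```python
import torch


def require_higher_order_graph(t: torch.Tensor, what: str = "result") -> torch.Tensor:
    """Raise if `t` cannot take part in another backward pass.

    Hypothetical helper for computed (non-leaf) outputs that are
    expected to feed a second backward pass.
    """
    if not t.requires_grad or t.grad_fn is None:
        raise RuntimeError(
            f"{what} lost its autograd graph (requires_grad={t.requires_grad}, "
            f"grad_fn={t.grad_fn}); higher-order backward will fail"
        )
    return t


def f(x, w):
    y = torch.mm(x, w).sum()
    (grad_x,) = torch.autograd.grad(y, x, create_graph=True)
    return grad_x.sum()


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

r = require_higher_order_graph(f(x, w))  # passes in eager mode
# The compiled call would be wrapped the same way:
#     require_higher_order_graph(torch.compile(f)(x, w))
# which, per the report, should raise here on affected builds.
```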

Versions


PyTorch: 2.13.0.dev20260425+cu126
Python: 3.11.15
OS: Linux-5.4.0-42-generic-x86_64-with-glibc2.31
CUDA: 12.6
GPU: Tesla T4

cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @bdhirsh @bobrenjc93 @aorenste

Extent analysis

TL;DR

The issue can be worked around by keeping functions that call torch.autograd.grad(..., create_graph=True) out of torch.compile, so that higher-order gradient computations run eagerly.

Guidance

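  • Identify which of the affected ops (mm, matmul, addmm, relu, leaky_relu) sit between the inputs and the torch.autograd.grad(..., create_graph=True) call. The author's table reports that F.linear (with or without bias) kept its metadata under compile, so rewriting plain matmuls as F.linear may help; running that region eagerly is the more reliable fix.
  • Check requires_grad and grad_fn on the value returned by the compiled function and compare them with eager mode, so the metadata loss is caught at the call site rather than at .backward().
  • Switching backends does not help: the report reproduces the issue with both inductor and aot_eager. Only CUDA was tested, so CPU behavior is unknown.
  • Reduce the compiled region to a single op, as in the minimal reproducer, to find which op in a larger model triggers the behavior.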
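For the first bullet, the author's table reports that F.linear kept its gradient metadata under compile while torch.mm, @ and torch.addmm did not. The sketch below only checks the algebraic equivalence such a rewrite relies on (`torch.mm(x, w)` equals `F.linear(x, w.t())`, and `torch.addmm(b, x, w)` equals `F.linear(x, w.t(), b)`). Whether the rewritten code avoids the bug on a given PyTorch build is an assumption to verify, not something this snippet demonstrates.

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
w = torch.randn(8, 8)
b = torch.randn(8)

# F.linear(input, weight, bias) computes input @ weight.T + bias,
# so passing w.t() as the weight reproduces a plain matmul.
mm_ref = torch.mm(x, w)
mm_lin = F.linear(x, w.t())

addmm_ref = torch.addmm(b, x, w)  # b + x @ w, with b broadcast over rows
addmm_lin = F.linear(x, w.t(), b)

print(torch.allclose(mm_ref, mm_lin, atol=1e-5))
print(torch.allclose(addmm_ref, addmm_lin, atol=1e-5))
```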

Example

The minimal reproducer in the issue already isolates the failure to a single torch.mm call inside the compiled function, so a small snippet is enough to reproduce it.

Notes

The issue is tied to the interaction between torch.compile and certain ops rather than to a single backend, since it reproduces with both inductor and aot_eager. It was reported on one nightly build (2.13.0.dev20260425) and tested only on CUDA, so whether other versions or devices are affected is not established. The failure only becomes visible when .backward() is called on the result; the forward computation and first-order gradient values appear correct.

Recommendation

Apply workaround: avoid using torch.compile for functions that require higher-order gradient computations, as it may silently lose gradient metadata for certain ops.
