pytorch - 💡(How to fix) torch.compile(backend=eager) with invoke_subgraph is broken on full fwd+bwd+loss [1 comment, 2 participants]

GitHub stats
pytorch/pytorch#178520 · Fetched 2026-04-08 01:35:51
Comments: 1 · Participants: 2 · Timeline: 30 · Reactions: 0
Timeline (top): mentioned ×10, subscribed ×10, labeled ×7, commented ×1

Error Message

AssertionError: fake_mode should be enabled for HOPs

Raised from InvokeSubgraphAutogradOp.backward (torch/_higher_order_ops/invoke_subgraph.py:663). The compiled graph's torch.autograd.grad(loss, [...], allow_unused=True) call enters the C++ autograd engine (torch/autograd/graph.py:877), which calls the invoke_subgraph backward; there, detect_fake_mode(primals + filtered_grad_outs) finds no fake mode and the assertion fires.

Fix Action

Fix / Workaround

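The issue body does not propose a fix. Per the author's diagnosis, the eager backend runs the captured torch.autograd.grad call on real tensors without AOTAutograd, while InvokeSubgraphAutogradOp.backward expects the fake tensors that AOTAutograd supplies. A plausible but unverified workaround is to compile through an AOTAutograd-based backend, for example backend="aot_eager", instead of a backend that returns the Dynamo graph unchanged.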


🐛 Describe the bug

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._higher_order_ops.invoke_subgraph import mark_compile_region


@mark_compile_region
def block_fwd(
    x: torch.Tensor,
    w1: torch.Tensor, b1: torch.Tensor,
    w2: torch.Tensor, b2: torch.Tensor,
) -> torch.Tensor:
    h = F.relu(F.linear(x, w1, b1))
    return x + F.linear(h, w2, b2)


class Model(nn.Module):
    def __init__(self, dim: int, n_layers: int) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, dim * 2), nn.Linear(dim * 2, dim))
            for _ in range(n_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for blk in self.blocks:
            fc1, fc2 = blk[0], blk[1]
            x = block_fwd(x, fc1.weight, fc1.bias, fc2.weight, fc2.bias)
        return x


DIM = 4
N_LAYERS = 2
DEVICE = "cuda"

torch.manual_seed(0)
model = Model(DIM, N_LAYERS).to(DEVICE)
x = torch.randn(4, DIM, device=DEVICE)


def train_step(x: torch.Tensor) -> torch.Tensor:
    out = model(x)
    loss = out.sum()
    loss.backward()
    return loss.detach()


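def backend(gm, example_inputs):
    # Print the Dynamo-captured FX graph and run it as-is, without AOTAutograd
    # (equivalent to backend="eager").
    print(gm.graph)
    return gm


torch._dynamo.reset()
# With trace_autograd_ops=True, Dynamo captures loss.backward() in the graph
# as a torch.autograd.grad call (visible in the printed graph).
with torch._dynamo.config.patch(trace_autograd_ops=True):
    compiled = torch.compile(train_step, backend=backend, fullgraph=True)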
    model.zero_grad()
    loss = compiled(x.clone())
    print(loss)

This prints the following graph:

graph():
    %g_model_modules_blocks_modules_0_modules_0_parameters_weight_ : torch.nn.parameter.Parameter [num_users=3] = placeholder[target=G_model_modules_blocks_modules_0_modules_0_parameters_weight_]
    %g_model_modules_blocks_modules_0_modules_0_parameters_bias_ : torch.nn.parameter.Parameter [num_users=3] = placeholder[target=G_model_modules_blocks_modules_0_modules_0_parameters_bias_]
    %g_model_modules_blocks_modules_0_modules_1_parameters_weight_ : torch.nn.parameter.Parameter [num_users=3] = placeholder[target=G_model_modules_blocks_modules_0_modules_1_parameters_weight_]
    %g_model_modules_blocks_modules_0_modules_1_parameters_bias_ : torch.nn.parameter.Parameter [num_users=3] = placeholder[target=G_model_modules_blocks_modules_0_modules_1_parameters_bias_]
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %g_model_modules_blocks_modules_1_modules_0_parameters_weight_ : torch.nn.parameter.Parameter [num_users=3] = placeholder[target=G_model_modules_blocks_modules_1_modules_0_parameters_weight_]
    %g_model_modules_blocks_modules_1_modules_0_parameters_bias_ : torch.nn.parameter.Parameter [num_users=3] = placeholder[target=G_model_modules_blocks_modules_1_modules_0_parameters_bias_]
    %g_model_modules_blocks_modules_1_modules_1_parameters_weight_ : torch.nn.parameter.Parameter [num_users=3] = placeholder[target=G_model_modules_blocks_modules_1_modules_1_parameters_weight_]
    %g_model_modules_blocks_modules_1_modules_1_parameters_bias_ : torch.nn.parameter.Parameter [num_users=3] = placeholder[target=G_model_modules_blocks_modules_1_modules_1_parameters_bias_]
    %subgraph_0 : [num_users=1] = get_attr[target=subgraph_0]
    %invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_0, subgraph_0, %l_x_, %g_model_modules_blocks_modules_0_modules_0_parameters_weight_, %g_model_modules_blocks_modules_0_modules_0_parameters_bias_, %g_model_modules_blocks_modules_0_modules_1_parameters_weight_, %g_model_modules_blocks_modules_0_modules_1_parameters_bias_), kwargs = {})
    %getitem_16 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
    %subgraph_1 : [num_users=1] = get_attr[target=subgraph_1]
    %invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%subgraph_1, subgraph_1, %getitem_16, %g_model_modules_blocks_modules_1_modules_0_parameters_weight_, %g_model_modules_blocks_modules_1_modules_0_parameters_bias_, %g_model_modules_blocks_modules_1_modules_1_parameters_weight_, %g_model_modules_blocks_modules_1_modules_1_parameters_bias_), kwargs = {})
    %getitem_17 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
    %loss : [num_users=2] = call_method[target=sum](args = (%getitem_17,), kwargs = {})
    %grad : [num_users=8] = call_function[target=torch.autograd.grad](args = (%loss, [%g_model_modules_blocks_modules_0_modules_0_parameters_weight_, %g_model_modules_blocks_modules_0_modules_0_parameters_bias_, %g_model_modules_blocks_modules_0_modules_1_parameters_weight_, %g_model_modules_blocks_modules_0_modules_1_parameters_bias_, %g_model_modules_blocks_modules_1_modules_0_parameters_weight_, %g_model_modules_blocks_modules_1_modules_0_parameters_bias_, %g_model_modules_blocks_modules_1_modules_1_parameters_weight_, %g_model_modules_blocks_modules_1_modules_1_parameters_bias_]), kwargs = {allow_unused: True})
    %getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%grad, 0), kwargs = {})
    %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%grad, 1), kwargs = {})
    %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%grad, 2), kwargs = {})
    %getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%grad, 3), kwargs = {})
    %getitem_12 : [num_users=1] = call_function[target=operator.getitem](args = (%grad, 4), kwargs = {})
    %getitem_13 : [num_users=1] = call_function[target=operator.getitem](args = (%grad, 5), kwargs = {})
    %getitem_14 : [num_users=1] = call_function[target=operator.getitem](args = (%grad, 6), kwargs = {})
    %getitem_15 : [num_users=1] = call_function[target=operator.getitem](args = (%grad, 7), kwargs = {})
    %_set_grad_enabled : [num_users=0] = call_function[target=torch._C._set_grad_enabled](args = (False,), kwargs = {})
    %new_grad_strided : [num_users=2] = call_function[target=torch.empty_like](args = (%g_model_modules_blocks_modules_0_modules_0_parameters_weight_,), kwargs = {})
    %copy_ : [num_users=0] = call_method[target=copy_](args = (%new_grad_strided, %getitem_8), kwargs = {})
    %new_grad_strided_1 : [num_users=2] = call_function[target=torch.empty_like](args = (%g_model_modules_blocks_modules_0_modules_0_parameters_bias_,), kwargs = {})
    %copy__1 : [num_users=0] = call_method[target=copy_](args = (%new_grad_strided_1, %getitem_9), kwargs = {})
    %new_grad_strided_2 : [num_users=2] = call_function[target=torch.empty_like](args = (%g_model_modules_blocks_modules_0_modules_1_parameters_weight_,), kwargs = {})
    %copy__2 : [num_users=0] = call_method[target=copy_](args = (%new_grad_strided_2, %getitem_10), kwargs = {})
    %new_grad_strided_3 : [num_users=2] = call_function[target=torch.empty_like](args = (%g_model_modules_blocks_modules_0_modules_1_parameters_bias_,), kwargs = {})
    %copy__3 : [num_users=0] = call_method[target=copy_](args = (%new_grad_strided_3, %getitem_11), kwargs = {})
    %new_grad_strided_4 : [num_users=2] = call_function[target=torch.empty_like](args = (%g_model_modules_blocks_modules_1_modules_0_parameters_weight_,), kwargs = {})
    %copy__4 : [num_users=0] = call_method[target=copy_](args = (%new_grad_strided_4, %getitem_12), kwargs = {})
    %new_grad_strided_5 : [num_users=2] = call_function[target=torch.empty_like](args = (%g_model_modules_blocks_modules_1_modules_0_parameters_bias_,), kwargs = {})
    %copy__5 : [num_users=0] = call_method[target=copy_](args = (%new_grad_strided_5, %getitem_13), kwargs = {})
    %new_grad_strided_6 : [num_users=2] = call_function[target=torch.empty_like](args = (%g_model_modules_blocks_modules_1_modules_1_parameters_weight_,), kwargs = {})
    %copy__6 : [num_users=0] = call_method[target=copy_](args = (%new_grad_strided_6, %getitem_14), kwargs = {})
    %new_grad_strided_7 : [num_users=2] = call_function[target=torch.empty_like](args = (%g_model_modules_blocks_modules_1_modules_1_parameters_bias_,), kwargs = {})
    %copy__7 : [num_users=0] = call_method[target=copy_](args = (%new_grad_strided_7, %getitem_15), kwargs = {})
    %_set_grad_enabled_1 : [num_users=0] = call_function[target=torch._C._set_grad_enabled](args = (True,), kwargs = {})
    %detach : [num_users=1] = call_method[target=detach](args = (%loss,), kwargs = {})
    return (detach, new_grad_strided, new_grad_strided_1, new_grad_strided_2, new_grad_strided_3, new_grad_strided_4, new_grad_strided_5, new_grad_strided_6, new_grad_strided_7)

With a crash later on:

     20 loss = getitem_17.sum();  getitem_17 = None
---> 21 grad = torch.autograd.grad(loss, [g_model_modules_blocks_modules_0_modules_0_parameters_weight_, g_model_modules_blocks_modules_0_modules_0_parameters_bias_, g_model_modules_blocks_modules_0_modules_1_parameters_weight_, g_model_modules_blocks_modules_0_modules_1_parameters_bias_, g_model_modules_blocks_modules_1_modules_0_parameters_weight_, g_model_modules_blocks_modules_1_modules_0_parameters_bias_, g_model_modules_blocks_modules_1_modules_1_parameters_weight_, g_model_modules_blocks_modules_1_modules_1_parameters_bias_], allow_unused = True)
     22 getitem_8 = grad[0]
     23 getitem_9 = grad[1]
File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2799/bento_kernel_pytorch_binary-inplace#link-tree/torch/autograd/__init__.py:530, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)
    526     result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
    527         grad_outputs_
    528     )
    529 else:
--> 530     result = _engine_run_backward(
    531         outputs,
    532         grad_outputs_,
    533         retain_graph,
    534         create_graph,
    535         inputs,
    536         allow_unused,
    537         accumulate_grad=False,
    538     )
    539 if materialize_grads:
    540     if any(
    541         result[i] is None and not is_tensor_like(inputs[i])
    542         for i in range(len(inputs))
    543     ):
File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2799/bento_kernel_pytorch_binary-inplace#link-tree/torch/autograd/graph.py:877, in _engine_run_backward(t_outputs, *args, **kwargs)
    875     unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
    876 try:
--> 877     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    878         t_outputs, *args, **kwargs
    879     )  # Calls into the C++ engine to run the backward pass
    880 finally:
    881     if attach_logging_hooks:
File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2799/bento_kernel_pytorch_binary-inplace#link-tree/torch/autograd/function.py:317, in BackwardCFunction.apply(self, *args)
    311     raise RuntimeError(
    312         "Implementing both 'backward' and 'vjp' for a custom "
    313         "Function is not allowed. You should only implement one "
    314         "of them."
    315     )
    316 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 317 return user_fn(self, *args)
File /data/users/tmanlaibaatar/.bento/kernels/bento_kernel_pytorch/2799/bento_kernel_pytorch_binary-inplace#link-tree/torch/_higher_order_ops/invoke_subgraph.py:663, in InvokeSubgraphAutogradOp.backward(ctx, *grad_outs)
    661 fake_mode = detect_fake_mode(primals + filtered_grad_outs)
    662 if fake_mode is None:
--> 663     raise AssertionError("fake_mode should be enabled for HOPs")
    664 state = _CacheKeyState(fake_mode.shape_env)
    666 tangent_metadata: list[object] = []
AssertionError: fake_mode should be enabled for HOPs

As you can see, the autograd.grad call is preserved in the graph. When we call the training step, there is no AOTAutograd to compile the backward properly, and InvokeSubgraphAutogradOp is designed to work only on fake tensors, which AOTAutograd supplies.

Versions

main

cc @chauhang @penguinwu @ydwu4 @bdhirsh @bobrenjc93 @aorenste

extent analysis

Fix Plan

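To address the issue, AOTAutograd needs to handle the backward pass so that InvokeSubgraphAutogradOp.backward runs on fake tensors. torch.compile has no aot_autograd option; AOTAutograd is selected through the backend instead, either with the built-in backend="aot_eager" or by wrapping a custom compiler with torch._dynamo.backends.common.aot_autograd. This follows the issue author's diagnosis but is not confirmed in the issue as a fix.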

Here are the steps:

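  • Replace backend=backend with an AOTAutograd-based backend such as backend="aot_eager".
  • To keep printing graphs, wrap the printing function with aot_autograd(fw_compiler=..., bw_compiler=...) so that it receives the graphs AOTAutograd traces.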
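To keep the graph-printing behavior of the repro's backend while getting AOTAutograd, one option is the aot_autograd helper from torch._dynamo.backends.common together with functorch.compile.make_boxed_func, the pattern PyTorch's custom-backend documentation uses. The sketch below exercises it on a stand-in nn.Linear model with a sum loss on CPU, not on the issue's invoke_subgraph repro; the names print_compiler and printing_aot_backend are hypothetical, and whether this avoids the assertion under trace_autograd_ops=True is untested.

```python
import torch
import torch.fx
import torch.nn as nn
from functorch.compile import make_boxed_func
from torch._dynamo.backends.common import aot_autograd

captured = []  # every graph handed to the compiler, kept for inspection


def print_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Called by AOTAutograd with graphs it traced on fake tensors
    # (a forward graph, and a backward graph once backward is needed).
    print(gm.graph)
    captured.append(gm)
    # Adapt gm.forward to AOTAutograd's boxed (list-of-args) calling convention.
    return make_boxed_func(gm.forward)


# Route Dynamo graphs through AOTAutograd, printing both fw and bw graphs.
printing_aot_backend = aot_autograd(fw_compiler=print_compiler, bw_compiler=print_compiler)

# Stand-in model and loss (not the issue's repro).
torch.manual_seed(0)
model = nn.Linear(4, 4)
x = torch.randn(8, 4)


def loss_fn(inp: torch.Tensor) -> torch.Tensor:
    return model(inp).sum()


torch._dynamo.reset()
compiled = torch.compile(loss_fn, backend=printing_aot_backend)
loss = compiled(x)
loss.backward()  # runs the AOTAutograd-compiled backward graph
```

With the repro, the same printing_aot_backend would be passed as backend= in place of the original backend function.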

Example code:

torch._dynamo.reset()
with torch._dynamo.config.patch(trace_autograd_ops=True):
    # "aot_eager" routes the captured graph through AOTAutograd before running it eagerly.
    compiled = torch.compile(train_step, backend="aot_eager", fullgraph=True)
    model.zero_grad()
    loss = compiled(x.clone())
    print(loss)

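If you keep a custom backend, note that the Dynamo graph contains torch.ops.higher_order.invoke_subgraph call_function nodes, as in the printed graph. InvokeSubgraphAutogradOp is the autograd.Function that implements their backward, not a graph node. A backend that runs a graph containing these nodes on real tensors, without AOTAutograd, will hit the same assertion during backward.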
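A backend could detect such graphs before running them. The sketch below is hypothetical: find_invoke_subgraph_nodes and checking_backend are illustrative names, and matching on the target's __name__ being "invoke_subgraph" is an assumption based on the printed target torch.ops.higher_order.invoke_subgraph, not a documented API.

```python
import torch
import torch.fx


def find_invoke_subgraph_nodes(graph: torch.fx.Graph) -> list:
    """Return call_function nodes whose target is named "invoke_subgraph".

    Name-based matching is a heuristic (assumption): it avoids depending on
    whether torch.ops.higher_order.invoke_subgraph is registered in this build.
    """
    return [
        node
        for node in graph.nodes
        if node.op == "call_function"
        and getattr(node.target, "__name__", None) == "invoke_subgraph"
    ]


def checking_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical eager-style backend that warns when the graph contains
    # invoke_subgraph regions, whose backward expects fake tensors.
    hits = find_invoke_subgraph_nodes(gm.graph)
    if hits:
        print(f"{len(hits)} invoke_subgraph node(s); consider an AOTAutograd backend")
    return gm
```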

Verification

To verify that the fix worked, we can check the following:

  • Calling the compiled train_step completes without the fake_mode AssertionError.
  • train_step returns the same loss, and leaves the same parameter gradients, as an uncompiled run of the same step.
  • The backward through the invoke_subgraph regions runs without error.

Extra Tips

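  • The aot_eager backend and the aot_autograd helper have been available since torch 2.0. The repro itself relies on the private trace_autograd_ops config and mark_compile_region and was filed against main, so use a recent nightly or source build to reproduce it.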
  • If you encounter any issues with the backend function, try printing the compiled graph to verify that it contains the expected nodes.
