pytorch - ✅(Solved) Fix [COMPILE] torch compile is broken with custom ops with completely incorrect outputs most of the time in PyTorch 2.11 [2 pull requests, 2 comments, 1 participant]

GitHub stats
pytorch/pytorch#180642 (fetched 2026-04-17 08:25:47)
Comments: 2
Participants: 1
Timeline events: 123
Reactions: 1
Timeline (top): mentioned ×56, subscribed ×56, labeled ×9, commented ×2

Fix Action

Fixed

PR fix notes

PR #180670: [dynamo] Prune aliased autograd.Function side-effect outputs

Description (problem / solution / changelog)

Fix #180642

Summary

  1. What is the root cause?
     Dynamo's autograd.Function forward tracing can surface extra hidden outputs when side effects are allowed. In-place forward ops can create hidden outputs that alias the real forward outputs, and those aliased extras get wired into autograd_function_apply as distinct differentiable outputs.

  2. What is the proposed fix?
     Prune hidden side-effect outputs that alias the real forward outputs before wiring the autograd.Function backward graph, and add a regression test for the in-place forward mutation repro from the issue.

  3. Why is the proposed fix the right long-term fix?
     Those extra side-effect outputs are an internal tracing detail, not user-visible forward results. Deduplicating aliased extras preserves side-effect support while preventing autograd from seeing duplicate aliased outputs that can silently route gradients incorrectly.
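A simplified model of the pruning step described above, in plain Python: each traced value is represented by a hypothetical FakeTensor whose storage field identifies its underlying buffer, and a hidden side-effect output is dropped when it shares a buffer with a real forward output or with an extra that was already kept. This only illustrates the idea; the actual change lives in torch/_dynamo/variables/higher_order_ops.py and operates on Dynamo's own variable types, not on these hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a traced value; `storage` identifies its buffer."""
    name: str
    storage: int


def prune_aliased_side_effect_outputs(real_outputs, side_effect_outputs):
    """Keep every real forward output, plus only those hidden side-effect outputs
    whose buffer is not already exposed by a real output or an earlier kept extra."""
    exposed = {t.storage for t in real_outputs}
    kept = []
    for t in side_effect_outputs:
        if t.storage in exposed:
            # Aliases a buffer autograd already sees: an internal tracing detail,
            # so exposing it again would create a duplicate differentiable output.
            continue
        exposed.add(t.storage)
        kept.append(t)
    return list(real_outputs) + kept


# Hypothetical example: dx is a real output, and an in-place setitem on it
# surfaced a hidden extra output that aliases the same buffer.
dx = FakeTensor("dx", storage=1)
dW = FakeTensor("dW", storage=2)
dx_after_setitem = FakeTensor("dx_after_setitem", storage=1)
scratch = FakeTensor("scratch", storage=3)
outputs = prune_aliased_side_effect_outputs([dx, dW], [dx_after_setitem, scratch])
print([t.name for t in outputs])  # ['dx', 'dW', 'scratch']
```

The non-aliasing extra (scratch) is kept, which mirrors the PR's stated goal of preserving side-effect support while removing only the aliased duplicates.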

Testing

  • python test/dynamo/test_autograd_function.py -k test_inplace_forward_mutation_keeps_correct_grad
  • python test/dynamo/test_autograd_function.py -k test_aliasing_output
  • python test/dynamo/test_autograd_function.py -k test_nonlocal_list_mutation_in_autograd_function
  • python test/dynamo/test_autograd_function.py -k test_rewired_bwd_output
  • python test/dynamo/test_autograd_function.py -k test_udf_output

Drafted via Codex, published after manual review by @bobrenjc93

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @azahed98

Changed files

  • test/dynamo/test_autograd_function.py (modified, +25/-0)
  • torch/_dynamo/variables/higher_order_ops.py (modified, +93/-10)

PR #180675: Add test for autograd.Function fused fwd/bwd pattern (issue #180642)

Description (problem / solution / changelog)

The fix for this issue landed in #177368, but no test was added at the time to guard against future regressions. This PR adds a test covering the fused forward/backward pattern from the issue: an autograd.Function that pre-allocates gradient buffers via in-place ops and returns multiple outputs alongside saved tensors. Fixes #180642.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @kadeng @chauhang @amjames @Lucaskabela @jataylo @azahed98

Changed files

  • test/dynamo/test_autograd_function.py (modified, +39/-0)
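The fused forward/backward pattern that this regression test covers can be sketched without PyTorch. In the sketch below, the forward pass computes a mean-squared-error loss and, at the same time, writes the gradient into a pre-allocated buffer in place and saves it; backward then only scales the saved gradient by the incoming gradient. Ctx and FusedMeanSquaredError are hypothetical stand-ins: real code would subclass torch.autograd.Function and use ctx.save_for_backward / ctx.saved_tensors, and the issue's version calls a custom Triton cross-entropy op instead of this toy loss.

```python
class Ctx:
    """Hypothetical stand-in for the ctx object autograd passes to forward/backward."""

    def save_for_backward(self, *values):
        self.saved_tensors = values


class FusedMeanSquaredError:
    """Pure-Python mirror of a fused fwd/bwd autograd.Function (not PyTorch code)."""

    @staticmethod
    def forward(ctx, x, y):
        n = len(x)
        loss = 0.0
        dx = [0.0] * n  # pre-allocated gradient buffer, filled in place below
        for i, (xi, yi) in enumerate(zip(x, y)):
            diff = xi - yi
            loss += diff * diff
            dx[i] = 2.0 * diff  # gradient computed during the forward pass
        loss /= n
        for i in range(n):
            dx[i] /= n  # in-place normalization of the saved gradient
        ctx.save_for_backward(dx)
        return loss

    @staticmethod
    def backward(ctx, grad_loss):
        (dx,) = ctx.saved_tensors
        # Chain rule: scale the precomputed gradient; y gets no gradient.
        return [grad_loss * d for d in dx], None


ctx = Ctx()
loss = FusedMeanSquaredError.forward(ctx, [1.0, 2.0], [0.0, 0.0])
grad_x, grad_y = FusedMeanSquaredError.backward(ctx, 1.0)
print(loss, grad_x, grad_y)  # 2.5 [1.0, 2.0] None
```

The in-place writes into the gradient buffer during forward are the kind of side effect that, according to the discussion in the issue, changed how Dynamo traces such a Function.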


🐛 Describe the bug

[COMPILE] torch.compile is broken with custom ops in PyTorch 2.11: most of the time it produces completely incorrect outputs.

Versions

[Image]

Adding the print statement on line 51 gives the correct FX graph:

=== FX Graph ===
graph():
    %l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
    %l_w_ : torch.Tensor [num_users=3] = placeholder[target=L_W_]
    %l_y_ : torch.Tensor [num_users=1] = placeholder[target=L_y_]
    %l : [num_users=2] = call_function[target=torch.zeros](args = ((),), kwargs = {device: cuda:0, dtype: torch.float32})
    %dx : [num_users=2] = call_function[target=torch.empty_like](args = (%l_x_,), kwargs = {dtype: None, memory_format: torch.contiguous_format})
    %dW : [num_users=2] = call_function[target=torch.zeros_like](args = (%l_w_,), kwargs = {dtype: None, memory_format: torch.contiguous_format})
    %_x : [num_users=2] = call_function[target=operator.getitem](args = (%l_x_, slice(0, 3150, None)), kwargs = {})
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%l_w_, T), kwargs = {})
    %_h : [num_users=2] = call_function[target=operator.matmul](args = (%_x, %getattr_1), kwargs = {})
    %_dh : [num_users=3] = call_function[target=torch.empty_like](args = (%_h,), kwargs = {dtype: None, memory_format: torch.contiguous_format})
    %_y : [num_users=1] = call_function[target=operator.getitem](args = (%l_y_, slice(0, 3150, None)), kwargs = {})
    %_cross_entropy_forward_backward_triton_default : [num_users=0] = call_function[target=torch.ops.xma._cross_entropy_forward_backward_triton.default](args = (), kwargs = {x: %_h, labels: %_y, loss: %l, x_grad: %_dh, logits_multiplier: None, reduction: sum})
    %matmul_1 : [num_users=1] = call_function[target=operator.matmul](args = (%_dh, %l_w_), kwargs = {})
    %setitem : [num_users=0] = call_function[target=operator.setitem](args = (%dx, slice(0, 3150, None), %matmul_1), kwargs = {})
    %getattr_2 : [num_users=1] = call_function[target=builtins.getattr](args = (%_dh, T), kwargs = {})
    %addmm : [num_users=0] = call_function[target=torch.addmm](args = (%dW, %getattr_2, %_x), kwargs = {alpha: 1, beta: 1, out: %dW})
    %l_1 : [num_users=1] = call_function[target=operator.itruediv](args = (%l, 3150), kwargs = {})
    %dx_1 : [num_users=1] = call_function[target=operator.itruediv](args = (%dx, 3150), kwargs = {})
    %dW_1 : [num_users=1] = call_function[target=operator.itruediv](args = (%dW, 3150), kwargs = {})
    return (dx_1, dW_1, l_1)

Without it, I get the following graph:

=== FX Graph ===

graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l_labels_ : torch.Tensor [num_users=1] = placeholder[target=L_labels_]
    %l_weight_ : torch.Tensor [num_users=2] = placeholder[target=L_weight_]
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%l_weight_, T), kwargs = {})
    %fwd_body_0 : [num_users=1] = get_attr[target=fwd_body_0]
    %bwd_body_0 : [num_users=1] = get_attr[target=bwd_body_0]
    %autograd_function_apply : [num_users=1] = call_function[target=torch.ops.higher_order.autograd_function_apply](args = (%fwd_body_0, %bwd_body_0, %l_x_, %l_weight_, %getattr_1, %l_labels_), kwargs = {non_differentiable_idx: [], saved_for_backward_idx: [0, 1]})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%autograd_function_apply, 0), kwargs = {})
    return (getitem,)

@claude's suggestions:

  What happened

  Your custom torch.autograd.Function.forward has in-place side effects: it calls torch.ops.xma._cross_entropy_forward_backward_triton with x_grad=_dh (writing gradients into a pre-allocated buffer), and it also uses in-place setitem/addmm.

  Before aa23761208c: allow_side_effects=False caused speculate_subgraph to raise Unsupported when it hit those in-place ops. Dynamo graph-broke at Function.apply and fell back to a different path: your fwd+bwd computation was traced directly, producing the correct flat graph with all ops inlined.

  After aa23761208c: the forward traces successfully with side effects allowed. Dynamo now wraps it in the autograd_function_apply HOP (with fwd_body_0/bwd_body_0 subgraphs), which is the wrong graph you're seeing. The saved_for_backward_idx: [0, 1] was then added on top by d0379e1bfaa (Jan 7, 2026).

  This commit is part of the broader overhaul 5cf15aef144 (#166788), which was co-authored by Animesh Jain and approved by @zou3519.

Repro: run https://github.com/open-lm-engine/accelerated-model-architectures/blob/main/tests/functional/fused_linear_cross_entropy_test.py, restricted to the torch.compile test cases.

This happens with custom ops used inside autograd functions.

cc @ezyang @gchanan @kadeng @msaroufim @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @amjames @Lucaskabela @jataylo @azahed98 @bdhirsh @aorenste

extent analysis

TL;DR

As a workaround, the issue can be avoided by modifying the custom torch.autograd.Function.forward so that it has no in-place side effects.

Guidance

  • Identify the custom torch.autograd.Function.forward and refactor it to eliminate in-place operations, such as writing gradients into pre-allocated buffers or using in-place setitem/addmm.
  • Inspect the FX graph to verify that the refactored forward now traces as expected.
  • Run the repro test (https://github.com/open-lm-engine/accelerated-model-architectures/blob/main/tests/functional/fused_linear_cross_entropy_test.py) restricted to the torch.compile test cases to confirm the issue is resolved.
  • Review the broader overhaul commit (5cf15aef144) and its implications for custom ops used inside autograd functions.
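To tell which of the two graphs from the issue you got, it can be enough to look at the call_function targets in the printed FX graph: the correct graph inlines ops such as torch.ops.xma._cross_entropy_forward_backward_triton.default, while the graph that produced wrong results wraps the Function in torch.ops.higher_order.autograd_function_apply. The HOP form is not wrong by itself; per PR #180670, the incorrect gradients came from aliased side-effect outputs inside it. The helpers below are a hypothetical sketch that scans the printed text with a regular expression; one way to obtain such a printout is a custom torch.compile backend that prints gm.graph and returns gm.forward.

```python
import re

# Matches the target of each call_function node in a printed FX graph, e.g.
# "call_function[target=torch.zeros](args = ...)" -> "torch.zeros".
_CALL_TARGET = re.compile(r"call_function\[target=([^\]]+)\]")

HOP_TARGET = "torch.ops.higher_order.autograd_function_apply"


def call_targets(graph_text: str) -> list[str]:
    """Return the call_function targets in the order they appear."""
    return _CALL_TARGET.findall(graph_text)


def uses_autograd_function_hop(graph_text: str) -> bool:
    """True if the autograd.Function was wrapped in the autograd_function_apply HOP."""
    return HOP_TARGET in call_targets(graph_text)


def custom_op_targets(graph_text: str, namespace: str = "xma") -> list[str]:
    """Return inlined calls into a custom-op namespace (torch.ops.<namespace>.*)."""
    prefix = f"torch.ops.{namespace}."
    return [t for t in call_targets(graph_text) if t.startswith(prefix)]
```

Here "xma" is the custom-op namespace from the issue; pass a different namespace to check other custom ops.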

Example

The issue does not show the custom torch.autograd.Function.forward itself, so no exact code fix can be given here.

Notes

The issue appears to stem from commit aa23761208c, which allowed side effects in the autograd.Function forward. Working around it requires carefully examining the custom autograd functions for in-place operations.

Recommendation

Apply workaround: Modify the custom torch.autograd.Function.forward to avoid in-place side effects, as this is the most likely cause of the incorrect FX graph.
