pytorch - ✅(Solved) Fix torch.func transforms silently return zeros under inference_mode [2 pull requests, 4 comments, 3 participants]

GitHub stats
pytorch/pytorch#177318, fetched 2026-04-08 00:42:11
Comments: 4
Participants: 3
Timeline: 86
Reactions: 0
Timeline (top): mentioned ×29, subscribed ×29, referenced ×11, labeled ×8

Error Message

torch.func.{grad, vjp, jvp} silently produce incorrect results (zeros) when called under torch.inference_mode(). No error is raised.

Root Cause

Functorch transforms set up their own autograd context on entry: enable_grad() for grad/vjp, _set_fwd_grad_enabled(True) for jvp (see NOTE [grad and vjp interaction with no_grad]). However, inference_mode also excludes autograd dispatch keys from the TLS dispatch key set, and the mode-enabling calls don't undo that. So autograd dispatch keys remain excluded and differentiation silently produces zeros.

The transforms already correctly handle no_grad — this is the same class of issue with a different TLS mechanism.
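
The distinction between the two thread-local mechanisms can be illustrated with a toy model in plain Python. This is only a sketch: the names `toy_no_grad`, `toy_inference_mode`, `toy_enable_grad`, and the `"Autograd"` key are hypothetical and do not correspond to PyTorch internals. It assumes, as the paragraph above describes, that a transform only flips the grad-mode flag on entry and never touches the excluded key set.

```python
import threading
from contextlib import contextmanager

# Toy model only: illustrative names, NOT PyTorch's implementation.
_tls = threading.local()


def state():
    """Return this thread's toy state, creating defaults on first use."""
    if not hasattr(_tls, "grad_enabled"):
        _tls.grad_enabled = True  # models the grad-mode flag
        _tls.excluded = set()     # models the excluded dispatch keys
    return _tls


@contextmanager
def toy_set_grad_enabled(flag):
    """Flip only the grad-mode flag, restoring it on exit."""
    s = state()
    prev = s.grad_enabled
    s.grad_enabled = flag
    try:
        yield
    finally:
        s.grad_enabled = prev


def toy_no_grad():
    return toy_set_grad_enabled(False)


def toy_enable_grad():
    return toy_set_grad_enabled(True)


@contextmanager
def toy_inference_mode():
    """Disable the flag AND exclude the autograd key."""
    s = state()
    prev_flag, prev_excluded = s.grad_enabled, set(s.excluded)
    s.grad_enabled = False
    s.excluded = s.excluded | {"Autograd"}
    try:
        yield
    finally:
        s.grad_enabled, s.excluded = prev_flag, prev_excluded


def can_differentiate():
    s = state()
    return s.grad_enabled and "Autograd" not in s.excluded


def probe(outer):
    """Enter `outer`, then do what a transform does on entry: enable grad."""
    with outer():
        with toy_enable_grad():
            return can_differentiate()


print(probe(toy_no_grad))         # True: re-enabling the flag is enough
print(probe(toy_inference_mode))  # False: the key is still excluded
```

In this model, re-enabling the flag recovers from `toy_no_grad` but not from `toy_inference_mode`, which mirrors the failure described in the issue.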

Fix Action

Fix / Workaround

Either make the transforms produce correct results by also disabling inference_mode on transform entry, matching the existing no_grad handling, or raise an explicit error. The two linked PRs take one approach each: PR #177479 raises an error, and PR #177596 disables inference_mode inside grad, vjp, and jvp.

PR fix notes

PR #177479: fix(func): raise explicit error for grad/vjp/jvp in inference_mode

Description (problem / solution / changelog)

Summary

Prevent torch.func.{grad, vjp, jvp} from silently returning incorrect zero results when called under torch.inference_mode() by raising an explicit error.

What does this PR do?

Background

Currently, gradient transforms (grad/vjp/jvp) in torch.func fail silently (return zeros) when run in inference_mode():

  • Root cause: inference_mode() excludes autograd dispatch keys from TLS, and functorch's internal enable_grad() does not undo this exclusion
  • This behavior is unsafe (hard to debug) and inconsistent with PyTorch's "fail fast" philosophy

Solution

Add a check at the C++ entry of grad/vjp/jvp transforms:

  1. Detect if the current context is inference_mode() (via c10::InferenceMode::is_enabled())
  2. Use TORCH_CHECK to raise a clear error if so (instead of failing silently)
  3. Error message example: torch.func.{grad, vjp, jvp} are not supported in torch.inference_mode(). Please exit inference_mode before calling these transforms, as autograd is required for differentiation.
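
The PR places this check in C++. A rough Python-level analogue, for readers who want the same fail-fast behavior on a build without the PR, might look like the sketch below. The decorator name `reject_inference_mode` is hypothetical; `torch.is_inference_mode_enabled()` is the public query for the current mode. Note that `grad(f)` only builds a function, so the check is attached to the returned function, where differentiation actually runs.

```python
import functools

import torch
from torch.func import grad, jvp


def reject_inference_mode(fn):
    """Hypothetical helper: make `fn` raise instead of running under inference_mode."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Checked on every call, because that is when differentiation runs.
        if torch.is_inference_mode_enabled():
            name = getattr(fn, "__name__", "transform")
            raise RuntimeError(
                f"{name} is not supported in torch.inference_mode(); "
                "exit inference_mode before calling it."
            )
        return fn(*args, **kwargs)

    return wrapper


safe_jvp = reject_inference_mode(jvp)
# grad(f) returns a function; guard that function, not grad itself.
sum_of_squares_grad = reject_inference_mode(grad(lambda t: (t**2).sum()))

x = torch.randn(3)
_, tangent = safe_jvp(lambda t: t**2, (x,), (torch.ones(3),))
gradient = sum_of_squares_grad(x)

try:
    with torch.inference_mode():
        safe_jvp(lambda t: t**2, (x,), (torch.ones(3),))
    raised = False
except RuntimeError:
    raised = True
```

Outside inference_mode both guarded calls behave like the unguarded transforms; inside it, the guarded call raises before any differentiation is attempted.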

Rationale

  • inference_mode() is explicitly designed for non-differentiable inference workloads — allowing gradient transforms to run here is semantically inconsistent
  • Explicit error > silent failure: aligns with PyTorch's design principles and avoids hard-to-debug user issues
  • Consistent with existing handling of no_grad() (functorch already handles no_grad() correctly, this extends the same "fail fast" logic to inference_mode())

Test Plan

Unit Test

  • Unit test for grad/vjp/jvp under inference_mode(): verify that an error is raised (rather than zeros being returned) and that results are still correct without inference_mode()
  • Test location: test/functorch/test_eager_transforms.py
  • Adds a new test function named test_inference_mode_gradient_transforms in class TestGradTransform

User Test

Screenshot "result": https://github.com/user-attachments/assets/bdb05c53-d442-4596-ae95-906f0f0d29b1

Related Issues

Fixes #177318

Changed files

  • test/functorch/test_eager_transforms.py (modified, +46/-0)
  • torch/csrc/functorch/init.cpp (modified, +14/-0)

PR #177596: [functorch] Fix grad/vjp/jvp returning zeros under inference_mode

Description (problem / solution / changelog)

Fixes #177318

inference_mode excludes autograd dispatch keys from TLS, causing functorch TensorWrappers to lack autograd metadata. The transforms already handle no_grad by saving prev_grad_mode and enabling grad; this extends the same treatment to inference_mode for grad, vjp, and jvp. vmap and functionalize are not addressed in this PR.

Add _disable_inference_mode() to surgically disable inference_mode without clobbering grad_mode/fw_grad_mode (which inference_mode(False) would do, breaking the prev_grad_mode invariant). Wrap grad_increment_nesting, jvp_increment_nesting, and vjp's backward closure.
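
The claim that `inference_mode(False)` also changes grad mode can be observed from Python. The snippet below is a small probe, not part of the PR; the function name `observe_grad_mode` is made up for illustration. It records `torch.is_grad_enabled()` inside a nested `torch.inference_mode(False)` under `torch.no_grad()`, and again after leaving the inner context.

```python
import torch


def observe_grad_mode():
    """Return torch.is_grad_enabled() inside and after inference_mode(False)."""
    with torch.no_grad():
        with torch.inference_mode(False):
            inside = torch.is_grad_enabled()
        # Back under no_grad only: the previous state is restored on exit.
        after = torch.is_grad_enabled()
    return inside, after


inside, after = observe_grad_mode()
print(f"grad enabled inside inference_mode(False): {inside}")
print(f"grad enabled after leaving it, still under no_grad: {after}")
```

If `inside` prints as enabled even though the surrounding context is `no_grad`, that is the clobbering the PR description refers to, and it is why a narrower helper is needed to preserve the prev_grad_mode invariant.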

My hope is that this can unblock inference_mode workflows where one needs gradients through a model with frozen weights, taken with respect to inputs (e.g., particle positions) rather than model parameters. This happens, for example, in molecular dynamics with a neural network potential function, where the forces used in the particle update steps are computed from the gradient of the potential.
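
As a minimal sketch of that use case, the example below computes forces as the negative gradient of an energy with respect to positions using `torch.func.grad`. The harmonic energy is a toy stand-in chosen for illustration, not a neural network potential, and the function names are hypothetical. The call is made outside inference_mode, so it does not depend on either PR.

```python
import torch
from torch.func import grad


def harmonic_energy(positions):
    """Toy stand-in for a learned potential: E = 0.5 * sum of squared coordinates."""
    return 0.5 * (positions**2).sum()


def forces(positions):
    # Force is the negative gradient of the energy with respect to positions.
    return -grad(harmonic_energy)(positions)


positions = torch.randn(4, 3)  # 4 particles in 3D
f = forces(positions)
print(f.shape)  # torch.Size([4, 3])
```

For this energy the force on each coordinate is simply its negation, which makes the result easy to check by hand.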

Changed files

  • aten/src/ATen/functorch/DynamicLayer.cpp (modified, +7/-5)
  • aten/src/ATen/functorch/DynamicLayer.h (modified, +4/-2)
  • aten/src/ATen/functorch/Interpreter.h (modified, +14/-6)
  • test/functorch/test_eager_transforms.py (modified, +76/-0)
  • torch/_functorch/eager_transforms.py (modified, +33/-6)
  • torch/csrc/functorch/init.cpp (modified, +40/-2)

Code Example

import torch
from torch.func import jvp

x = torch.randn(3)
with torch.inference_mode():
    out, tangent = jvp(lambda x: x**2, (x,), (torch.ones(3),))
    print(tangent)  # tensor([0., 0., 0.]); should be 2*x

🐛 Describe the bug

torch.func.{grad, vjp, jvp} silently produce incorrect results (zeros) when called under torch.inference_mode(). No error is raised.

import torch
from torch.func import jvp

x = torch.randn(3)
with torch.inference_mode():
    out, tangent = jvp(lambda x: x**2, (x,), (torch.ones(3),))
    print(tangent)  # tensor([0., 0., 0.]) — should be 2*x

Root cause

Functorch transforms set up their own autograd context on entry: enable_grad() for grad/vjp, _set_fwd_grad_enabled(True) for jvp (see NOTE [grad and vjp interaction with no_grad]). However, inference_mode also excludes autograd dispatch keys from the TLS dispatch key set, and the mode-enabling calls don't undo that. So autograd dispatch keys remain excluded and differentiation silently produces zeros.

The transforms already correctly handle no_grad — this is the same class of issue with a different TLS mechanism.

Expected behavior

Either produce correct results (by also disabling inference_mode on transform entry, matching the existing no_grad handling) or raise an error.

cc @ezyang @gchanan @kadeng @msaroufim @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @Chillee @samdow @kshitij12345

extent analysis

Fix Plan

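To fix the issue, the torch.func transforms need to handle torch.inference_mode() correctly, for example by disabling inference_mode on entry to the transforms, similar to how no_grad is handled. Until such a fix is available, a user-side wrapper can approximate this; the code below is an untested workaround sketch, not the upstream change.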

Modified Code

import torch
from torch.func import jvp

def modified_jvp(func, inputs, tangents):
    # Leave inference_mode for the duration of the transform. jvp sets up
    # forward-mode AD itself, so no extra mode-setting calls are needed.
    with torch.inference_mode(False):
        return jvp(func, inputs, tangents)

# Created outside inference_mode so both are ordinary (non-inference) tensors.
x = torch.randn(3)
v = torch.ones(3)
with torch.inference_mode():
    out, tangent = modified_jvp(lambda x: x**2, (x,), (v,))
    print(tangent)  # intended result: 2*x

Alternatively, we can raise an error when torch.inference_mode() is active:

def modified_jvp(func, inputs, tangents):
    # torch.is_inference_mode_enabled() reports whether inference_mode is active.
    if torch.is_inference_mode_enabled():
        raise RuntimeError("jvp is not supported under inference_mode")
    return jvp(func, inputs, tangents)

Verification

To verify the fix, run the modified code and check that the output is correct. You can also test the error-raising version to ensure it correctly raises an error when torch.inference_mode() is active.
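
One way to carry out this check is a small probe that classifies what a given PyTorch build does. This is a sketch under assumptions: the function name and the four labels are made up here, and it assumes an error from a fixed build would surface as a `RuntimeError`.

```python
import torch
from torch.func import jvp


def probe_jvp_under_inference_mode():
    """Classify this build's behavior: 'correct', 'zeros', 'error', or 'other'."""
    x = torch.tensor([1.0, 2.0, 3.0])
    expected = 2 * x  # analytical tangent of x**2 along a vector of ones
    try:
        with torch.inference_mode():
            _, tangent = jvp(lambda t: t**2, (x,), (torch.ones(3),))
            if not isinstance(tangent, torch.Tensor):
                return "other"
            if torch.allclose(tangent, expected):
                return "correct"
            if torch.count_nonzero(tangent).item() == 0:
                return "zeros"
            return "other"
    except RuntimeError:
        return "error"


# Reference computed outside inference_mode, where jvp is not affected.
x_ref = torch.tensor([1.0, 2.0, 3.0])
_, reference = jvp(lambda t: t**2, (x_ref,), (torch.ones(3),))

behavior = probe_jvp_under_inference_mode()
print(f"jvp under inference_mode: {behavior}")
```

A result of "zeros" corresponds to the bug as reported; "error" or "correct" corresponds to one of the two proposed fixes.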

Extra Tips

  • When working with torch.func transforms, be aware of the interaction with torch.inference_mode() and torch.no_grad().
  • Consider adding tests to cover these scenarios to prevent regressions in the future.
