pytorch: `torch._higher_order_ops.scan` computes incorrect gradients for closed-over parameters [1 pull request]

Fix Action

Fixed

🐛 Describe the bug

torch._higher_order_ops.scan can produce correct forward values but incorrect gradients for closed-over parameters when the scan carry does not require grad.

This affects recurrent code that closes over trainable weights in the scan body. For example, a scan-based RNN/LSTM can match a Python loop in the forward pass while silently training with different parameter gradients.

Reproducer

from __future__ import annotations

import torch
from torch._higher_order_ops import scan


def loop_recurrence(x: torch.Tensor, a: torch.Tensor, carry0: torch.Tensor):
    carry = carry0
    ys = []
    for x_t in x:
        carry = a * carry + x_t
        ys.append(carry)
    return carry, torch.stack(ys)


def scan_recurrence(x: torch.Tensor, a: torch.Tensor, carry0: torch.Tensor):
    def step(carry: torch.Tensor, x_t: torch.Tensor):
        # `a` is closed over by the scan body; it is not part of the carry or xs.
        carry_next = a * carry + x_t
        return carry_next.clone(), carry_next.clone()

    return scan(step, carry0, x, dim=0)


def value_and_grad(recurrence, include_final_carry: bool):
    dtype = torch.float64
    x = torch.tensor([0.11, -0.37, 0.23, 0.41, -0.19], dtype=dtype)
    a = torch.tensor(0.7, dtype=dtype, requires_grad=True)
    # carry0 deliberately does not require grad: this is the condition that triggers the bug.
    carry0 = torch.tensor(0.2, dtype=dtype)

    carry, ys = recurrence(x, a, carry0)
    loss = ys.square().sum()
    if include_final_carry:
        loss = loss + carry.square()
    (grad_a,) = torch.autograd.grad(loss, a)
    return loss.detach(), grad_a.detach(), carry.detach(), ys.detach()


for include_final_carry in (False, True):
    scan_loss, scan_grad, scan_carry, scan_ys = value_and_grad(
        scan_recurrence, include_final_carry
    )
    loop_loss, loop_grad, loop_carry, loop_ys = value_and_grad(
        loop_recurrence, include_final_carry
    )

    print(f"include_final_carry={include_final_carry}")
    print(f"torch={torch.__version__}")
    print(f"forward loss: scan={scan_loss.item():.17g} loop={loop_loss.item():.17g}")
    print(f"grad dloss/da: scan={scan_grad.item():.17g} loop={loop_grad.item():.17g}")

    torch.testing.assert_close(scan_carry, loop_carry)
    torch.testing.assert_close(scan_ys, loop_ys)
    torch.testing.assert_close(scan_loss, loop_loss)
    torch.testing.assert_close(scan_grad, loop_grad)

Expected Behavior

The scan recurrence and Python-loop recurrence should produce matching forward values, losses, and gradients with respect to the closed-over parameter a.

Actual Behavior

Forward values and losses match, but the gradient dloss/da differs between the scan and loop versions.

Example output:

include_final_carry=False
forward loss: scan=0.35571607672499994 loop=0.35571607672499994
grad dloss/da: scan=0.19074693349999999 loop=0.26872833559999998

With include_final_carry=True, the gradient also differs.
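To see which of the two gradients is the correct one without relying on autograd, a central finite difference can be taken on the same recurrence. The following is a minimal sketch in plain Python, assuming ordinary floats stand in for the float64 tensors; the helper names `loss` and `central_difference` are illustrative and not part of the issue, and the two reference values are the example output above rounded to ten decimals.

```python
# Plain-Python model of the reproducer's recurrence; no PyTorch involved.
X = [0.11, -0.37, 0.23, 0.41, -0.19]
CARRY0 = 0.2

# Gradients reported in the example output, rounded to ten decimals.
LOOP_GRAD = 0.2687283356
SCAN_GRAD = 0.1907469335


def loss(a: float) -> float:
    """Sum of squared outputs of carry_t = a * carry_{t-1} + x_t."""
    carry = CARRY0
    total = 0.0
    for x_t in X:
        carry = a * carry + x_t
        total += carry * carry
    return total


def central_difference(f, a: float, eps: float = 1e-6) -> float:
    """Second-order finite-difference estimate of df/da."""
    return (f(a + eps) - f(a - eps)) / (2.0 * eps)


fd_grad = central_difference(loss, 0.7)
print(f"finite difference  = {fd_grad:.8f}")
print(f"|fd - loop grad|   = {abs(fd_grad - LOOP_GRAD):.2e}")
print(f"|fd - scan grad|   = {abs(fd_grad - SCAN_GRAD):.2e}")
```

The estimate agrees with the Python-loop gradient to well within finite-difference error and differs from the scan gradient by roughly 0.078, which is consistent with the loop value being the reference.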

Root Cause

scan_autograd builds the per-step joint graph from the real initial carry tensors. If the initial carry does not require grad, AOTAutograd traces the single-step backward with a zero carry gradient. The reverse scan then carries zero cotangents across timesteps, so closed-over parameters miss the indirect recurrent contribution through the carry state.
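For the reproducer's recurrence with include_final_carry=False, the effect can be written out explicitly. The notation below (c_t for the carry, a bar for a cotangent, T for the number of steps) is introduced here for illustration and does not appear in the PyTorch source.

```latex
\begin{align*}
c_t &= a\,c_{t-1} + x_t, \qquad y_t = c_t, \qquad L = \sum_{t=1}^{T} y_t^2 \\
\bar{c}_t &= 2\,c_t + a\,\bar{c}_{t+1}, \qquad \bar{c}_{T+1} = 0 \\
\frac{\partial L}{\partial a} &= \sum_{t=1}^{T} \bar{c}_t\, c_{t-1}
\end{align*}
```

If the carry cotangent is traced as zero, the recurrent term a \bar{c}_{t+1} drops out, leaving \bar{c}_t = 2 c_t and a gradient of \sum_t 2 c_t c_{t-1}. Only the direct contribution of a at each step survives.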

The carry cotangent is internal recurrence state and is needed to propagate gradients through time even when the user does not request a gradient for the initial carry itself.
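The explanation above can be checked numerically with a hand-written reverse pass. This is a minimal sketch in plain Python, not the `scan_autograd` implementation: the function names and the `propagate_carry_cotangent` flag are hypothetical, and the flag only models the assumption that the carry gradient is zeroed between steps.

```python
# Plain-Python model of the reproducer; it does not call torch or scan.
X = [0.11, -0.37, 0.23, 0.41, -0.19]
A = 0.7
CARRY0 = 0.2


def forward(a, carry0, xs):
    """Return [carry_0, ..., carry_T] for carry_t = a * carry_{t-1} + x_t."""
    carries = [carry0]
    for x_t in xs:
        carries.append(a * carries[-1] + x_t)
    return carries


def grad_a(a, carries, propagate_carry_cotangent):
    """Reverse pass for loss = sum_t carry_t ** 2 (include_final_carry=False)."""
    grad = 0.0
    carry_cotangent = 0.0  # cotangent arriving from step t + 1
    for t in range(len(carries) - 1, 0, -1):
        # Cotangent of carry_t: direct term from y_t plus the recurrent term.
        g_t = 2.0 * carries[t] + carry_cotangent
        grad += g_t * carries[t - 1]  # d(carry_t)/da = carry_{t-1}
        # d(carry_t)/d(carry_{t-1}) = a; dropping it models a zeroed carry gradient.
        carry_cotangent = a * g_t if propagate_carry_cotangent else 0.0
    return grad


carries = forward(A, CARRY0, X)
full = grad_a(A, carries, propagate_carry_cotangent=True)
truncated = grad_a(A, carries, propagate_carry_cotangent=False)
print(f"with carry cotangent:    {full:.10f}")
print(f"without carry cotangent: {truncated:.10f}")
```

With the cotangent propagated, the result matches the Python-loop gradient from the example output (about 0.2687283356). With it zeroed, the result matches the scan gradient (about 0.1907469335), which supports the diagnosis that the reverse scan was carrying zero cotangents across timesteps.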

Versions

Collecting environment information...
PyTorch version: 2.13.0a0+git65e408e
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 26.3 (arm64)
GCC version: Could not collect
Clang version: 21.0.0 (clang-2100.0.123.102)
CMake version: version 4.3.2
Libc version: N/A

Python version: 3.12.10 (main, May 22 2025, 01:38:44) [Clang 20.1.4 ] (64-bit runtime)
Python platform: macOS-26.3-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU: Apple M4 Pro

Versions of relevant libraries:
[pip3] numpy==2.4.6
[pip3] optree==0.19.1
[pip3] torch==2.13.0a0+git65e408e
[conda] Could not collect

cc @ezyang @albanD @gqchen @nikitaved @soulitzer @Varal7 @bobrenjc93 @chauhang @penguinwu @ydwu4 @bdhirsh @aorenste
