pytorch - 💡(How to fix) Fix `torch.compile` silently produces all-zero `x.grad` for `diagonal_scatter(...).sum().backward()` [1 participants]

pytorch/pytorch#182012 (fetched 2026-05-01 05:32:50)

🐛 Describe the bug

torch.compile(backend='inductor') silently produces x.grad = 0 for torch.diagonal_scatter(x, src, offset).sum().backward(), while eager mode correctly produces x.grad = 1 at all non-diagonal positions (and 0 on the diagonal, since those positions were overwritten by src).

The forward output is correct; only the backward gradient for x is wrong. src.grad is computed correctly. aot_eager also produces the correct gradient, confirming the bug is in Inductor's code generation, not in AOTAutograd or the decomposition.
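The backend comparison described above can be sketched as a small helper. This is a minimal sketch, not code from the issue: the names `x_grad` and `bisect_backends` are hypothetical, and it assumes a PyTorch build where the `aot_eager` and `inductor` backends are available on your device.

```python
import torch

def fn(x, src):
    return torch.diagonal_scatter(x, src, 0).sum()

def x_grad(x, src, backend=None):
    """Gradient of fn w.r.t. x; backend=None means plain eager."""
    # Fresh leaf copies so repeated calls do not accumulate into .grad.
    x = x.detach().clone().requires_grad_(True)
    src = src.detach().clone().requires_grad_(True)
    f = fn if backend is None else torch.compile(fn, backend=backend)
    f(x, src).backward()
    return x.grad

def bisect_backends(x, src, backends=("aot_eager", "inductor")):
    """Return {backend: bool}: whether each backend's x.grad matches eager."""
    reference = x_grad(x, src)
    return {b: torch.allclose(x_grad(x, src, b), reference) for b in backends}
```

On an affected build, calling `bisect_backends(x, src)` with CUDA tensors would be expected to report a match for `aot_eager` and a mismatch for `inductor`, which is the split the issue uses to localize the bug to Inductor.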

Minimal reproducer

import torch

x = torch.randn(8, 8, device='cuda', requires_grad=True)
src = torch.randn(8, device='cuda', requires_grad=True)

def fn(x, src):
    return torch.diagonal_scatter(x, src, 0).sum()

# Eager: correct
x1 = x.clone().detach().requires_grad_(True)
s1 = src.clone().detach().requires_grad_(True)
fn(x1, s1).backward()
print(f"eager x.grad:\n{x1.grad}")
# tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
#         [1., 0., 1., 1., 1., 1., 1., 1.],
#         ...])  # 1 everywhere except diagonal

# Compiled: wrong
torch._dynamo.reset()
x2 = x.clone().detach().requires_grad_(True)
s2 = src.clone().detach().requires_grad_(True)
torch.compile(fn)(x2, s2).backward()
print(f"compiled x.grad:\n{x2.grad}")
# tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 0., 0., 0.],
#         ...])  # ALL ZEROS — incorrect

Root cause analysis

The backward FX graph for diagonal_scatter(x, src).sum() is:

expand(tangent_scalar, [N, N])  →  clone  →  diagonal_scatter(clone, zeros)  →  return

Inductor lowers clone(expand(scalar, [N, N])) into a Triton kernel that writes a single element to buf0, then returns reinterpret_tensor(buf0, (N, N), (0, 0), 0) — a stride-(0, 0) view where all N² positions alias one memory cell. The subsequent diagonal_scatter kernel writes 0.0 to that single cell (to zero the diagonal), so the entire gradient reads back as 0.

Correct behavior: clone must materialize the expanded tensor into a contiguous buffer so that diagonal_scatter can selectively zero only the diagonal positions.

This is not the same root cause as #180164 / #180771 (slice_scatter backward), which was in AOTAutograd's meta_slice_scatter preserving zero strides. Here aot_eager is correct; only Inductor's codegen is wrong.
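The aliasing effect can be reproduced in plain eager PyTorch on CPU. The following is only an analogy under stated assumptions: it uses `torch.as_strided` to build the stride-(0, 0) view by hand and is not the code Inductor generates. The name `buf0` mirrors the buffer name quoted above; everything else is illustrative.

```python
import torch

N = 4

# buf0 holds the single element the kernel wrote (the scalar tangent, here 1.0).
buf0 = torch.ones(1)

# Stride-(0, 0) view: all N*N positions alias buf0's one memory cell.
aliased = torch.as_strided(buf0, (N, N), (0, 0))

# Zeroing a single "diagonal" position overwrites the shared cell,
aliased[0, 0] = 0.0
# so every position now reads back as 0.
all_zero = bool((aliased == 0).all())

# Materializing first gives each position its own cell (strides (N, 1)),
# so zeroing the diagonal leaves the off-diagonal ones intact.
materialized = torch.ones(1).expand(N, N).clone(memory_format=torch.contiguous_format)
fixed = torch.diagonal_scatter(materialized, torch.zeros(N))
```

Here `fixed` has the shape of the correct gradient: ones everywhere except a zero diagonal.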

Affected configurations

  • Shapes: all tested sizes (4×4 through 256×256), square and non-square, all diagonal offsets
  • Dtypes: float32 and float64
  • Devices: CUDA (tested on Tesla T4)
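  • Trigger pattern: diagonal_scatter(x, src, offset) followed by .sum() (any dim variant), either directly or through ops that leave the backward's stride-0 expand intact (see Characterization)
  • Inserting a pointwise op with a non-identity backward between diagonal_scatter and sum (e.g. .relu(), .mul(2)) avoids the bug by preventing the problematic fusion; .add(1) does not help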

Characterization

BUG (x.grad = all zeros) — ops that preserve the stride-0 expand in backward:

| Pattern | x.grad |
| --- | --- |
| .sum() | all zeros |
| .sum(dim=0), .sum(dim=1), .sum(dim=(0,1)) | all zeros |
| .add(1).sum() | all zeros |
| .clone().sum() | all zeros |
| .contiguous().sum() | all zeros |
| .view(64).sum(), .reshape(4,16).sum(), .flatten().sum() | all zeros |
| .t().sum(), .transpose(0,1).sum() | all zeros |

Correct — ops that force materialization of the expanded tensor:

| Pattern | x.grad |
| --- | --- |
| .mean() | correct |
| .relu().sum() | correct |
| .mul(2).sum(), .neg().sum() | correct |
| .abs().sum(), .square().sum() | correct |
| .exp().sum(), .sigmoid().sum(), .tanh().sum() | correct |
| @ w .sum() (matmul) | correct |
| .amax(), .norm(), .logsumexp(dim=-1).sum() | correct |

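The dividing line: ops whose backward passes the expanded scalar tangent through unchanged (sum, adding a constant, clone, views, transposes) leave the gradient in a stride-0 buffer. Ops whose backward computes on the tangent (non-trivial pointwise ops such as relu, mul, or exp) and non-uniform reductions force Inductor to materialize the full buffer. A uniform gradient alone does not trigger the bug: .mul(2).sum() and .neg().sum() produce uniform gradients and are correct.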
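A sweep over a handful of the patterns from the tables above might look like the sketch below. It is a hedged illustration, not the script used in the issue: `PATTERNS`, `make_fn`, `x_grad`, and `sweep` are hypothetical names, and the `compile_fn` parameter exists only so the harness can be exercised without a compiler toolchain.

```python
import torch

# A few post-ops from the tables; each maps the scattered tensor to a scalar.
PATTERNS = {
    ".sum()": lambda y: y.sum(),
    ".add(1).sum()": lambda y: y.add(1).sum(),
    ".clone().sum()": lambda y: y.clone().sum(),
    ".t().sum()": lambda y: y.t().sum(),
    ".mean()": lambda y: y.mean(),
    ".relu().sum()": lambda y: y.relu().sum(),
    ".mul(2).sum()": lambda y: y.mul(2).sum(),
}

def make_fn(post):
    """Build diagonal_scatter followed by the given post-op."""
    def fn(x, src):
        return post(torch.diagonal_scatter(x, src, 0))
    return fn

def x_grad(f, x, src):
    """Gradient of f w.r.t. x, computed on a fresh leaf copy."""
    x = x.detach().clone().requires_grad_(True)
    f(x, src.detach().clone()).backward()
    return x.grad

def sweep(x, src, compile_fn=torch.compile):
    """Return {pattern: bool}; True means the compiled x.grad matches eager."""
    results = {}
    for name, post in PATTERNS.items():
        fn = make_fn(post)
        results[name] = torch.allclose(x_grad(compile_fn(fn), x, src), x_grad(fn, x, src))
    return results
```

With CUDA inputs on an affected build, `sweep(x, src)` would be expected to return False for the first four patterns and True for the rest, matching the tables.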

Versions

  • PyTorch: 2.13.0.dev20260429+cu126
  • Triton: 3.2.0+git4b3bb1e8
  • CUDA: 12.6
  • GPU: Tesla T4 (sm_75)
  • Python: 3.11

cc @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

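Work around the bug by inserting a pointwise operation with a non-identity backward (e.g. .relu() or .mul(2)) between diagonal_scatter and sum, which prevents the problematic fusion, or by compiling with backend='aot_eager'. Ops such as .add(1) or .clone() do not help.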

Guidance

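  • Find call sites where diagonal_scatter feeds sum directly, or through .add, .clone, .contiguous, views, or transposes, and insert a pointwise op that the issue reports as avoiding the bug, such as .relu() or .mul(2).
  • These ops change the forward value (.relu() clamps negative entries to 0; .mul(2) doubles the sum), so confirm the change is acceptable or compensate for it in the surrounding computation.
  • Compare the compiled x.grad against eager for the shapes, dtypes, and devices you use; the issue only reports tests on CUDA with float32 and float64.
  • As a temporary alternative, compile with backend='aot_eager' instead of Inductor; the issue reports that it produces the correct gradient.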
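The guidance above can be combined into a guarded compile step: try Inductor, compare its gradients with eager on example inputs, and fall back to `aot_eager` on a mismatch. This is a minimal sketch under assumptions. `grads` and `compile_with_grad_check` are hypothetical helpers, not PyTorch APIs, and the check only covers the example inputs you pass in.

```python
import torch

def grads(f, inputs):
    """Run f on fresh leaf copies of inputs and return their gradients.

    Assumes f returns a scalar and every input receives a gradient.
    """
    leaves = [t.detach().clone().requires_grad_(True) for t in inputs]
    f(*leaves).backward()
    return [t.grad for t in leaves]

def compile_with_grad_check(fn, example_inputs, compile_fn=torch.compile):
    """Compile fn with Inductor; fall back to aot_eager if its gradients
    disagree with eager on example_inputs."""
    reference = grads(fn, example_inputs)
    compiled = compile_fn(fn, backend="inductor")
    candidate = grads(compiled, example_inputs)
    if all(torch.allclose(r, c) for r, c in zip(reference, candidate)):
        return compiled
    # Gradients differ: use the backend the issue reports as correct.
    return compile_fn(fn, backend="aot_eager")
```

The fallback trades Inductor's generated kernels for correctness, so it is best treated as a stopgap until the codegen is fixed.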

Example

import torch

x = torch.randn(8, 8, device='cuda', requires_grad=True)
src = torch.randn(8, device='cuda', requires_grad=True)

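def fn(x, src):
    # .relu() avoids the bug per the issue, but it also changes the forward value
    # (negative entries are clamped to 0), so it is not a drop-in replacement for .sum()
    return torch.diagonal_scatter(x, src, 0).relu().sum()

# Compiled: reported correct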
torch._dynamo.reset()
x2 = x.clone().detach().requires_grad_(True)
s2 = src.clone().detach().requires_grad_(True)
torch.compile(fn)(x2, s2).backward()
print(f"compiled x.grad:\n{x2.grad}")

Notes

The workaround changes the forward computation, so it does not apply where the exact value of diagonal_scatter(...).sum() is required. The root cause remains in Inductor's code generation; a permanent fix would have clone materialize the expanded tensor into a contiguous buffer.

Recommendation

Where the changed forward value is acceptable, insert a pointwise operation such as .relu() or .mul(2) between diagonal_scatter and sum. Otherwise, compile the affected function with backend='aot_eager' until Inductor is fixed.
