pytorch - ✅ (Solved) Fix: `reinplace_inplaceable_ops` fails with `diagonal_scatter` on expanded (stride-zero) tensor in backward graph [1 pull request, 1 comment, 2 participants]

GitHub stats
pytorch/pytorch#178995 · Fetched 2026-04-08 02:24:40
Comments: 1 · Participants: 2 · Timeline events: 59 · Reactions: 0
Timeline (top): mentioned ×25, subscribed ×25, labeled ×5, cross-referenced ×2

_decompose_scatter_mutating in reinplace_inplaceable_ops fails when the input tensor to a scatter op has zero strides (from expand). The decomposition clones with clone_preserve_strides, preserving the aliased memory layout, then attempts copy_ on a view of the aliased clone, which fails.

  RuntimeError: more than one element of the written-to tensor refers to a single memory location

Root Cause

AOTAutograd's backward can produce valid graphs where expand with stride [0, 0] feeds into diagonal_scatter (e.g., from the backward of diagonal_scatter in the forward pass). This pattern appears when diagonal_scatter is used on a small tensor derived from a scalar-like gradient.

In _decompose_scatter_mutating (torch/_inductor/fx_passes/reinplace.py), the decomposition does:

  inp = graph_call_function(graph, aten.clone, inp)   # preserves stride [0, 0]
  tmp = inp                                           # views are applied to the clone
  tmp = graph_call_function(graph, view.target, tmp, ...)  # diagonal view of aliased memory
  graph_call_function(graph, aten.copy_.default, tmp, src)  # FAILS: aliased memory

graph_call_function runs the ops on fake tensors to compute metadata. aten.clone on a stride-[0, 0] tensor preserves the zero strides (via PyTorch's stride-preserving clone semantics). The subsequent aten.diagonal creates a view where multiple elements point to the same location. aten.copy_ then rejects this.
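The failing sequence can be reproduced in eager PyTorch. This is a minimal sketch, assuming the stride-preserving clone copies only the single backing element and re-views it with strides (0, 0). `clone_keep_strides` is a hypothetical helper written for this illustration, not Inductor's code, and real tensors stand in for the fake tensors used by graph_call_function.

```python
import torch


def clone_keep_strides(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: copy only the storage t addresses, then re-apply t's strides.

    Assumes t is non-empty.
    """
    # Number of storage elements addressed by t's (size, stride) layout.
    needed = 1 + sum((size - 1) * stride for size, stride in zip(t.shape, t.stride()))
    buf = t.as_strided((needed,), (1,), t.storage_offset()).clone()
    return buf.as_strided(t.shape, t.stride())


scalar_grad = torch.tensor(3.0)
inp = scalar_grad.expand(2, 2)      # stride (0, 0), as in the backward graph
aliased = clone_keep_strides(inp)   # still stride (0, 0): all four elements share one slot
diag = aliased.diagonal()           # size (2,), stride (0,)
src = torch.tensor([1.0, 2.0])

try:
    diag.copy_(src)                 # writing two values into one memory slot
except RuntimeError as err:
    print(err)
```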

Fix Action

Fixed

PR fix notes

PR #407: Fix reinplace_inplaceable_ops crash on scatter ops with stride-zero inputs

Description (problem / solution / changelog)

When the parallel graph contains scatter ops (diagonal_scatter, select_scatter, slice_scatter, as_strided_scatter) whose input has zero strides (from expand), inductor's reinplace_inplaceable_ops pass fails with RuntimeError: more than one element of the written-to tensor refers to a single memory location.

This happens because the backward decomposition of diagonal_scatter can produce expand with stride [0, 0] feeding into diagonal_scatter. The reinplace pass tries to decompose this into a clone + diagonal view + copy_, but the stride-preserving clone keeps the aliased memory, and copy_ on an aliased view is invalid.

This is a PyTorch bug (reported upstream in https://github.com/pytorch/pytorch/issues/178995). The same crash affects FSDP2 + torch.compile with per-block activation checkpointing on models that use diagonal_scatter. We work around it by inserting a clone before any scatter op whose input has zero strides, materializing the aliased view into contiguous memory before inductor sees it.

Authored with Claude.
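The PR's workaround (insert a clone before any scatter op whose input has zero strides) could look roughly like the FX pass below. This is a hypothetical sketch: the name `materialize_zero_stride_scatter_inputs` is invented here, the pass assumes producer nodes carry an example tensor in `node.meta["val"]`, and it does not reproduce the actual code in autoparallel/graph_passes/graph_utils.py.

```python
import torch
import torch.fx as fx

aten = torch.ops.aten

# Scatter ops listed in the issue; the real pass may match a different set.
SCATTER_OPS = {
    aten.diagonal_scatter.default,
    aten.select_scatter.default,
    aten.slice_scatter.default,
    aten.as_strided_scatter.default,
}


def materialize_zero_stride_scatter_inputs(graph: fx.Graph) -> int:
    """Hypothetical pass: give zero-stride scatter inputs a contiguous clone.

    Assumes producer nodes carry an example tensor in node.meta["val"].
    Returns the number of clone nodes inserted.
    """
    inserted = 0
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target not in SCATTER_OPS:
            continue
        inp = node.args[0]
        val = inp.meta.get("val") if isinstance(inp, fx.Node) else None
        if not isinstance(val, torch.Tensor) or 0 not in val.stride():
            continue
        with graph.inserting_before(node):
            clone = graph.call_function(
                aten.clone.default, (inp,), {"memory_format": torch.contiguous_format}
            )
        clone.meta["val"] = val.clone(memory_format=torch.contiguous_format)
        node.replace_input_with(inp, clone)  # scatter now reads the contiguous copy
        inserted += 1
    return inserted


# Tiny hand-built graph mirroring the issue: expand(scalar, [2, 2]) -> diagonal_scatter.
g = fx.Graph()
x = g.placeholder("x")
src = g.placeholder("src")
expanded = g.call_function(aten.expand.default, (x, [2, 2]))
expanded.meta["val"] = torch.empty(()).expand(2, 2)  # stride (0, 0)
out = g.call_function(aten.diagonal_scatter.default, (expanded, src))
g.output(out)

inserted = materialize_zero_stride_scatter_inputs(g)
gm = fx.GraphModule(torch.nn.Module(), g)
result = gm(torch.tensor(3.0), torch.tensor([1.0, 2.0]))
print(inserted)         # 1
print(result.tolist())  # [[1.0, 3.0], [3.0, 2.0]]
```

Running the pass a second time inserts nothing, because the new clone's metadata already has contiguous strides.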

Changed files

  • autoparallel/api.py (modified, +2/-0)
  • autoparallel/graph_passes/graph_utils.py (modified, +32/-0)


🐛 Describe the bug

Suggested Fix

In _decompose_scatter_mutating, force the clone to be contiguous when the input has zero strides:

  # Current:
  inp = graph_call_function(graph, aten.clone, inp)

  # Fix:
  if any(s == 0 for s in inp.meta["val"].stride()):
      inp = graph_call_function(graph, aten.clone, inp, memory_format=torch.contiguous_format)
  else:
      inp = graph_call_function(graph, aten.clone, inp)

This materializes the aliased view into contiguous memory before the mutation, which is both correct and efficient: the clone is already being done, just with different stride semantics.
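A minimal eager-mode sketch of the suggested condition follows. It runs on real tensors rather than FX nodes, and `clone_for_scatter` is a hypothetical name. It shows that only zero-stride inputs change layout, while dense layouts such as a transpose keep their strides.

```python
import torch


def clone_for_scatter(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical eager mirror of the suggested condition."""
    if any(s == 0 for s in t.stride()):
        # Zero strides: materialize into fresh, non-aliased contiguous memory.
        return t.clone(memory_format=torch.contiguous_format)
    # Otherwise keep the default (layout-preserving) clone.
    return t.clone()


expanded = torch.tensor(3.0).expand(2, 2)         # stride (0, 0)
transposed = torch.arange(6.0).reshape(2, 3).t()  # stride (1, 3), dense

print(clone_for_scatter(expanded).stride())    # (2, 1)
print(clone_for_scatter(transposed).stride())  # (1, 3)

# With the contiguous clone, clone + diagonal view + copy_ succeeds
# and matches the functional op.
out = clone_for_scatter(expanded)
out.diagonal().copy_(torch.tensor([1.0, 2.0]))
print(torch.equal(out, torch.diagonal_scatter(expanded, torch.tensor([1.0, 2.0]))))  # True
```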

Affected Ops

Any scatter op in _SCATTER_OP_TO_VIEW with a stride-zero input:

  • aten.diagonal_scatter
  • aten.select_scatter
  • aten.slice_scatter
  • aten.as_strided_scatter
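The same reasoning can be checked for three of the affected ops. This sketch assumes each scatter op pairs with the view it writes through (diagonal, select, slice), and it uses an expanded tensor as a stand-in for the aliased clone. as_strided_scatter is left out because its view depends on explicit size and stride arguments.

```python
import torch

SRC_1D = torch.tensor([1.0, 2.0])
SRC_ROW = torch.tensor([[1.0, 2.0]])

# (name, functional scatter op, view it writes through, src)
CASES = [
    ("diagonal_scatter", lambda t, s: torch.diagonal_scatter(t, s), lambda t: t.diagonal(), SRC_1D),
    ("select_scatter", lambda t, s: torch.select_scatter(t, s, 0, 0), lambda t: t.select(0, 0), SRC_1D),
    ("slice_scatter", lambda t, s: torch.slice_scatter(t, s, 0, 0, 1), lambda t: t[0:1], SRC_ROW),
]


def overlap_rejected(view, src) -> bool:
    """copy_ through a view of a stride-(0, 0) tensor should be refused."""
    aliased = torch.tensor(3.0).expand(2, 2)
    try:
        view(aliased).copy_(src)
    except RuntimeError as err:
        return "single memory location" in str(err)
    return False


def fixed_decomposition_matches(scatter, view, src) -> bool:
    """Contiguous clone + view + copy_ should equal the functional scatter op."""
    inp = torch.tensor(3.0).expand(2, 2)
    out = inp.clone(memory_format=torch.contiguous_format)
    view(out).copy_(src)
    return torch.equal(out, scatter(inp, src))


for name, scatter, view, src in CASES:
    print(name, overlap_rejected(view, src), fixed_decomposition_matches(scatter, view, src))
```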

Repro

The attached fx_graph_runnable_18.txt (generated by tlparse) reproduces the issue directly:

  python fx_graph_runnable_18.txt

This is a backward graph produced by aot_export_joint_with_descriptors for a model whose loss function, run with ignore_diagonal=True, uses torch.diagonal_scatter. The backward decomposition produces expand(scalar_grad, [2, 2]) with stride [0, 0], followed by diagonal_scatter.
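The backward pattern can be observed on a toy example. This is a hypothetical stand-in for the repro's loss (it does not use that model or its ignore_diagonal option): diagonal_scatter writes zeros onto the diagonal of a 2x2 tensor, and the result is summed. A tensor hook on the intermediate records the stride of the gradient that reaches diagonal_scatter's backward.

```python
import torch

x = torch.randn(2, 2, requires_grad=True)
y = torch.diagonal_scatter(x, torch.zeros(2))  # zero the diagonal in the forward pass

seen_strides = []
# Record the stride of the incoming gradient without modifying it (the hook returns None).
y.register_hook(lambda grad: seen_strides.append(grad.stride()))

y.sum().backward()  # sum's backward expands the scalar gradient to shape [2, 2]

print(seen_strides)     # strides of the gradient fed to diagonal_scatter's backward
print(x.grad.tolist())  # [[0.0, 1.0], [1.0, 0.0]]
```

In eager mode this runs without error; the failure in the issue appears only once such a graph reaches Inductor's reinplace pass.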

torch.compile's single-shot path avoids this because functionalization ordering prevents the pattern from reaching reinplace. The two-stage path (aot_export → transform → aot_compile, used by AutoParallel) exposes it.

fx_graph_runnable_18.txt

Versions

PyTorch nightly from 2.12.0.dev20260323+cu128

cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @muchulee8 @amjames @aakhundov @coconutruben @jataylo

extent analysis

TL;DR

The most likely fix is to modify the _decompose_scatter_mutating function to force the clone to be contiguous when the input has zero strides.

Guidance

  • Identify the input tensors with zero strides in the _decompose_scatter_mutating function and apply the suggested fix to ensure the clone is contiguous.
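  • Verify that the fix eliminates the error "RuntimeError: more than one element of the written-to tensor refers to a single memory location".
  • Expect only a small performance cost: the decomposition already clones the input, and the fix changes only the clone's layout. The contiguous clone does allocate one storage slot per element (4 for a [2, 2] expand), whereas a stride-preserving clone of the same tensor needs a single slot.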
  • Test the fix with the provided reproducible example fx_graph_runnable_18.txt to ensure it resolves the issue.
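The memory side of this trade-off can be measured with a small sketch. It assumes the stride-preserving clone keeps a single backing element, emulated here with `as_strided` on a one-element copy; this is an illustration, not Inductor's exact code path.

```python
import torch

scalar_grad = torch.tensor(3.0)      # float32, one element
expanded = scalar_grad.expand(2, 2)  # stride (0, 0), no new storage

# Emulated stride-preserving clone: copy the one backing element, re-view it as 2x2.
aliased = scalar_grad.clone().as_strided((2, 2), (0, 0))
# Suggested fix: a contiguous clone owns one slot per logical element.
contiguous = expanded.clone(memory_format=torch.contiguous_format)

print(aliased.untyped_storage().nbytes())     # 4
print(contiguous.untyped_storage().nbytes())  # 16
print(contiguous.stride())                    # (2, 1)
```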

Notes

The issue was reported against PyTorch nightly 2.12.0.dev20260323+cu128. The report does not cover other versions, where the patched code may differ.

Recommendation

Modify _decompose_scatter_mutating to force a contiguous clone when the input has zero strides. This avoids the RuntimeError, and the scatter still produces the same values because only the clone's memory layout changes.
