pytorch: Improve the reinplace FX pass to handle in-place mutations on views of graph inputs



🚀 The feature, motivation and pitch

Summary

When torch.compile compiles a function that does a scatter write (buf_flat[indices] = values) into a view of a graph input, the reinplace pass fails to convert the functional index_put back to an in-place index_put_. This causes Inductor to:

  1. Clone the entire buffer (34 MiB for 4096 pages) into a fresh allocation
  2. Perform the scatter into the clone
  3. Copy the entire buffer back to the original input

The scatter itself touches only ~34 KiB (257 tokens), so the two full-buffer memcpys dominate the runtime. This pattern is common in KV-cache updates, for example:

def store(buf, indices, values):
    buf.view(-1)[indices] = values
    return buf

This produces the following functional graph:

%view = reshape(%arg0_1, [-1])
%index_put = index_put(%view, [%arg1_1], %arg2_1)
%view_1 = reshape(%index_put, [1024, 1024])
%copy_ = copy_(%arg0_1, %view_1)
return (copy_,)
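The cost of this functional graph can be reproduced in eager mode. The sketch below is an emulation, not Inductor's generated code: it assumes a CPU uint8 buffer shaped [1024, 1024] to match the graph, and `functional_store` / `inplace_store` are hypothetical helper names. It runs the out-of-place `index_put` followed by `copy_` next to the in-place `index_put_` on the view, and checks that both leave the same bytes in the input while only the functional version materializes a full-size temporary.

```python
import torch


def functional_store(arg0, indices, values):
    # Mirrors the functional graph: index_put returns a NEW full-size tensor.
    view = arg0.reshape(-1)                       # aliases arg0 (contiguous input)
    index_put = view.index_put((indices,), values)
    view_1 = index_put.reshape(arg0.shape)
    # copy_ then writes the whole buffer back into the original input.
    arg0.copy_(view_1)
    return arg0, index_put


def inplace_store(arg0, indices, values):
    # Mirrors the expected graph: index_put_ mutates arg0 through the view.
    view = arg0.reshape(-1)
    view.index_put_((indices,), values)
    return arg0


torch.manual_seed(0)
indices = torch.randint(0, 1024 * 1024, (256,))
values = torch.ones(256, dtype=torch.uint8)

buf_a = torch.zeros(1024, 1024, dtype=torch.uint8)
buf_b = torch.zeros(1024, 1024, dtype=torch.uint8)

out_a, temp = functional_store(buf_a, indices, values)
out_b = inplace_store(buf_b, indices, values)

print(torch.equal(out_a, out_b))            # True: same final contents
print(temp.data_ptr() != buf_a.data_ptr())  # True: functional path allocated a copy
print(temp.numel() == buf_a.numel())        # True: ...of the entire buffer
```

Only 256 bytes change, yet the functional path allocates and then copies back all 1,048,576 bytes of the buffer.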

Root Cause

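The reinplace pass bails out for this pattern at https://github.com/pytorch/pytorch/blob/81fa97792064b7a5cd0f2b48c05a5856ce160a3c/torch/_inductor/fx_passes/reinplace.py#L547-L554, so the functional index_put is never converted back to index_put_.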

Expected Behavior

The expected graph should use the in-place index_put_:

%view = reshape(%arg0_1, [-1])
%index_put = index_put_(%view, [%arg1_1], %arg2_1)
return (%index_put,)
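The requested rewrite can be illustrated on a toy IR. This sketch is hypothetical: the `Node` class and the `reinplace_view_scatter` / `count_uses` functions are invented for illustration, do not mirror `torch/_inductor/fx_passes/reinplace.py` or the `torch.fx` API, and do not model why the real pass bails out. It matches reshape → index_put → reshape → copy_ back into the same input and, only when no intermediate has another user, turns the scatter into index_put_ and drops the trailing reshape and copy_. Because copy_ returns its destination, the sketch redirects users of copy_ to the mutated input arg0_1 rather than to the result of index_put_.

```python
from dataclasses import dataclass


@dataclass
class Node:
    name: str    # SSA name, e.g. "index_put"
    op: str      # e.g. "reshape", "index_put", "copy_", "output"
    args: tuple  # input node/placeholder names or hashable literals


def count_uses(nodes):
    uses = {}
    for n in nodes:
        for a in n.args:
            uses[a] = uses.get(a, 0) + 1
    return uses


def reinplace_view_scatter(nodes):
    """Toy rewrite on a hypothetical IR (not Inductor's reinplace pass).

    view = reshape(base); put = index_put(view, ...); back = reshape(put);
    copy_(base, back)   ==>   view = reshape(base); index_put_(view, ...)
    """
    by_name = {n.name: n for n in nodes}
    uses = count_uses(nodes)
    for copy_node in nodes:
        if copy_node.op != "copy_":
            continue
        base, src = copy_node.args
        back = by_name.get(src)
        if back is None or back.op != "reshape":
            continue
        put = by_name.get(back.args[0])
        if put is None or put.op != "index_put":
            continue
        view = by_name.get(put.args[0])
        if view is None or view.op != "reshape" or view.args[0] != base:
            continue
        # Conservative bail-outs: each intermediate has exactly one user and
        # `base` is read only by the view and the copy_.
        if any(uses.get(n.name, 0) != 1 for n in (view, put, back)):
            continue
        if uses.get(base, 0) != 2:
            continue
        put.op = "index_put_"  # scatter now writes through the view into base
        dead = {back.name, copy_node.name}
        rewritten = []
        for n in nodes:
            if n.name in dead:
                continue
            # copy_ returned its destination, so its users now read `base`.
            n.args = tuple(base if a == copy_node.name else a for a in n.args)
            rewritten.append(n)
        return rewritten
    return nodes  # pattern not found: graph unchanged


graph = [
    Node("view", "reshape", ("arg0_1", (-1,))),
    Node("index_put", "index_put", ("view", "arg1_1", "arg2_1")),
    Node("view_1", "reshape", ("index_put", (1024, 1024))),
    Node("copy_", "copy_", ("arg0_1", "view_1")),
    Node("output", "output", ("copy_",)),
]
for node in reinplace_view_scatter(graph):
    print(node.name, node.op, node.args)
```

The single-user checks are deliberately conservative: they ensure that no other node reads the input before the scatter or depends on the reshape and copy_ that the rewrite deletes.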

Performance

With a 128 MiB buffer and only 33 KiB of scattered writes, torch.compile is 161× slower than eager execution on an RTX 6000:

Buffer: 128 MiB, Scatter: 33.0 KiB
  Eager:   0.0042 ms
  Compile: 0.6744 ms
  Ratio:   161.21x

Reproducer:

import torch


def store_via_view(buf: torch.Tensor, indices: torch.Tensor, values: torch.Tensor):
    buf_flat = buf.view(-1)
    buf_flat[indices] = values
    return buf


def time_cuda(fn, warmup=20, iters=200):
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    end.synchronize()
    return start.elapsed_time(end) / iters


def main():
    device = torch.device("cuda")
    num_tokens = 256
    token_bytes = 132
    scatter_bytes = num_tokens * token_bytes
    buf_size = 128 * 1024 * 1024  # 128 MiB

    buf = torch.zeros(1, buf_size, dtype=torch.uint8, device=device)
    indices = torch.randint(0, buf_size, (scatter_bytes,), dtype=torch.long, device=device)
    values = torch.ones(scatter_bytes, dtype=torch.uint8, device=device)

    # Eager
    eager_ms = time_cuda(lambda: store_via_view(buf, indices, values))

    # Compile
    torch._dynamo.reset()
    compiled = torch.compile(store_via_view, fullgraph=True)
    compiled(buf, indices, values)
    torch.cuda.synchronize()
    compile_ms = time_cuda(lambda: compiled(buf, indices, values))

    print(f"Buffer: 128 MiB, Scatter: {scatter_bytes / 1024:.1f} KiB")
    print(f"  Eager:   {eager_ms:.4f} ms")
    print(f"  Compile: {compile_ms:.4f} ms")
    print(f"  Ratio:   {compile_ms / eager_ms:.2f}x")


if __name__ == "__main__":
    main()
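A back-of-the-envelope count of memory traffic shows why the clone and copy-back steps dominate. This sketch assumes that each full-buffer copy reads and writes every byte exactly once, and it ignores the reads of `indices` and `values`; `traffic_bytes` is a hypothetical helper and the numbers are estimates, not measurements.

```python
MiB = 1024 * 1024
KiB = 1024


def traffic_bytes(buf_bytes, scatter_bytes, reinplaced):
    """Rough estimate of bytes moved by one store() call (assumptions above)."""
    if reinplaced:
        return scatter_bytes            # scatter writes directly into the input
    clone = 2 * buf_bytes               # read input, write fresh allocation
    copy_back = 2 * buf_bytes           # read clone, write original input
    return clone + scatter_bytes + copy_back


buf_bytes = 128 * MiB
scatter_bytes = 256 * 132               # num_tokens * token_bytes from the reproducer
functional = traffic_bytes(buf_bytes, scatter_bytes, reinplaced=False)
inplace = traffic_bytes(buf_bytes, scatter_bytes, reinplaced=True)

print(f"scatter:          {scatter_bytes / KiB:.1f} KiB")
print(f"functional path: ~{functional / MiB:.0f} MiB")
print(f"in-place path:   ~{inplace / KiB:.1f} KiB")
print(f"byte ratio:      ~{functional / inplace:,.0f}x")
```

This estimates bytes moved only. The measured 161× ratio is much smaller than the byte ratio, likely because the ~4 µs eager scatter is bounded by fixed costs such as kernel launch rather than by memory bandwidth.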

cc: @kshitij12345


cc @jerryzh168 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @aakhundov @coconutruben @jataylo @bdhirsh @bobrenjc93 @aorenste
