pytorch - [MPS] Add graph capture/replay API to eliminate per-op CPU dispatch overhead (1 participant)

GitHub stats
pytorch/pytorch#180397 (fetched 2026-04-16 06:34:50)
Comments: 0 · Participants: 1 · Timeline events: 29 · Reactions: 0
Timeline (top): mentioned ×12, subscribed ×12, labeled ×5


🚀 The feature, motivation and pitch

MPS currently has no equivalent of CUDA Graphs. Every PyTorch op on MPS goes through Python dispatch, MPSGraph compilation/validation, cache lookup, and command buffer encoding on each call. For dispatch-heavy models (deep transformers, small LLMs, activation chains), this CPU overhead is significant relative to actual GPU compute time.

I propose adding a graph capture/replay API to the MPS backend, exposed as torch.mps.graph_capture() and torch.mps.graph_replay(). It records MPS operations on the first pass and replays them in a single dispatch on subsequent passes, skipping the per-op CPU overhead.

The proposed API:

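import torch

# `model`, `batch`, `seq`, `d_model` and `loader` come from user code;
# `model` is assumed to already live on the "mps" device.
model.eval()
x = torch.randn(batch, seq, d_model, device="mps")  # static input; its shape is fixed at capture
results = []

# Capture pass: ops run normally and get recorded (proposed API)
with torch.no_grad():
    with torch.mps.graph_capture():
        out = model(x)

# Replay: re-encode all recorded ops in a single dispatch, with no Python/compilation overhead
for batch_data in loader:
    x.copy_(batch_data)        # update the input in-place; batch_data must match x's shape
    torch.mps.graph_replay()   # replay recorded ops; results are written into `out`
    results.append(out.cpu())  # .cpu() copies, so earlier results survive the next replay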

Also available as a class-based API mirroring torch.cuda.CUDAGraph:

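g = torch.mps.MPSGraph()
with torch.mps.graph(g):       # capture the forward pass into g
    out = model(x)

x.copy_(new_input)             # refresh the captured input in-place
g.replay()                     # re-run the recorded ops; `out` holds the new result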
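The capture/replay semantics can be illustrated without Metal at all. The following is a minimal pure-Python sketch, not the proposed implementation: `ToyGraph`, `capture` and `record` are hypothetical names, "ops" are plain callables, Python lists stand in for MPS tensors, and `frontend_calls` is a made-up counter standing in for the per-op CPU work (dispatch, validation, cache lookup). Capture runs each op once through that front end and records it; replay re-runs the recorded ops against the same buffers and never touches the front end.

```python
from contextlib import contextmanager
from typing import Callable, List


class ToyGraph:
    """Hypothetical stand-in for the proposed torch.mps.MPSGraph."""

    def __init__(self) -> None:
        self._ops: List[Callable[[], None]] = []
        self._capturing = False
        self.frontend_calls = 0  # counts simulated per-op dispatch work

    @contextmanager
    def capture(self):
        self._ops.clear()
        self._capturing = True
        try:
            yield self
        finally:
            self._capturing = False

    def record(self, op: Callable[[], None]) -> None:
        """Eager path: pay the front-end cost, run the op, and record it while capturing."""
        self.frontend_calls += 1
        op()
        if self._capturing:
            self._ops.append(op)

    def replay(self) -> None:
        """Re-run every recorded op in order, skipping the front end entirely."""
        for op in self._ops:
            op()


# Static buffers stand in for device tensors whose storage must not be reallocated.
x = [1.0, 2.0, 3.0]
h = [0.0] * 3
out = [0.0] * 3


def scale() -> None:  # h = 2 * x
    for i, v in enumerate(x):
        h[i] = 2.0 * v


def relu() -> None:  # out = max(h, 0)
    for i, v in enumerate(h):
        out[i] = max(v, 0.0)


g = ToyGraph()
with g.capture():
    g.record(scale)
    g.record(relu)

x[:] = [-1.0, 0.5, 4.0]  # update the input in-place, as the proposal requires
g.replay()
print(out, g.frontend_calls)  # [0.0, 1.0, 8.0] 2
```

Further replays after more in-place updates reuse the same two recorded ops, so `frontend_calls` stays at 2 no matter how many batches are processed.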

The implementation captures both MPSGraph ops (matmul, linear, etc.) and raw Metal kernel dispatches (elementwise ops, activations, etc.) through a single recording point at MPSStream::commandEncoder(). A wrapper object, MPSRecordingEncoder, intercepts five key Metal encoder methods (setComputePipelineState:, setBuffer:offset:atIndex:, setBytes:length:atIndex:, dispatchThreads:threadsPerThreadgroup:, dispatchThreadgroups:threadsPerThreadgroup:), recording the kernel state while forwarding each call to the underlying encoder. As a result, every Metal kernel is captured automatically, with no recording code needed at individual call sites.
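The interception idea behind MPSRecordingEncoder is a forwarding proxy. Below is a Python sketch of that pattern under loose assumptions: `FakeEncoder`, `RecordingEncoder` and `replay` are hypothetical, the method names are Pythonized versions of the five Metal selectors named above, and a real implementation would be Objective-C++ inside the MPS backend. Each intercepted call is logged and forwarded; any other method passes straight through. `setBytes` payloads are snapshotted at record time, because Metal copies those bytes when the call is encoded, so later edits to the caller's memory must not leak into replay.

```python
from typing import Any, List, Tuple

Call = Tuple[str, tuple]


class FakeEncoder:
    """Stand-in for an MTLComputeCommandEncoder: it only logs what it is asked to do."""

    def __init__(self) -> None:
        self.log: List[Call] = []

    def set_compute_pipeline_state(self, pso: str) -> None:
        self.log.append(("set_compute_pipeline_state", (pso,)))

    def set_buffer(self, buf: str, offset: int, index: int) -> None:
        self.log.append(("set_buffer", (buf, offset, index)))

    def set_bytes(self, data: bytes, index: int) -> None:
        self.log.append(("set_bytes", (bytes(data), index)))

    def dispatch_threads(self, grid: tuple, group: tuple) -> None:
        self.log.append(("dispatch_threads", (grid, group)))

    def dispatch_threadgroups(self, groups: tuple, group: tuple) -> None:
        self.log.append(("dispatch_threadgroups", (groups, group)))

    def end_encoding(self) -> None:  # not intercepted; simply forwarded
        self.log.append(("end_encoding", ()))


class RecordingEncoder:
    """Forwards every call to the wrapped encoder and records the intercepted ones."""

    INTERCEPTED = {
        "set_compute_pipeline_state",
        "set_buffer",
        "set_bytes",
        "dispatch_threads",
        "dispatch_threadgroups",
    }

    def __init__(self, inner: Any) -> None:
        self._inner = inner
        self.recorded: List[Call] = []

    def __getattr__(self, name: str) -> Any:
        target = getattr(self._inner, name)
        if name not in self.INTERCEPTED:
            return target  # pass-through for everything else

        def wrapper(*args: Any) -> None:
            # Snapshot mutable payloads (setBytes data) so the recording is self-contained.
            snap = tuple(bytes(a) if isinstance(a, (bytearray, memoryview)) else a for a in args)
            self.recorded.append((name, snap))
            target(*args)

        return wrapper


def replay(recorded: List[Call], encoder: Any) -> None:
    """Re-issue the recorded calls, in order, on a fresh encoder."""
    for name, args in recorded:
        getattr(encoder, name)(*args)


live = FakeEncoder()
rec = RecordingEncoder(live)
params = bytearray(b"\x02")
rec.set_compute_pipeline_state("relu_f32")
rec.set_buffer("x", 0, 0)
rec.set_buffer("out", 0, 1)
rec.set_bytes(params, 2)
rec.dispatch_threads((1024, 1, 1), (256, 1, 1))
rec.end_encoding()
params[0] = 9  # mutated after recording; the snapshot keeps b"\x02"

fresh = FakeEncoder()
replay(rec.recorded, fresh)
print(len(rec.recorded), fresh.log == live.log[:-1])  # 5 True
```

Centralizing the wrapper at the point where encoders are handed out is what lets every kernel be recorded without touching individual call sites, which is the property the paragraph above emphasizes.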

Additionally includes:

  • Executable caching: compiled MPSGraphExecutable objects cached per graph to skip re-validation
  • Buffer size validation on replay: asserts buffer sizes match what was recorded, catching tensor reallocation with a clear error instead of silent incorrectness

Constraints:

  • Tensor shapes must not change between capture and replays
  • Input data must be updated in-place via .copy_() before each replay
  • MPS profiling must be disabled during capture
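The in-place requirement follows from what a recording holds: references to the captured buffers, not to Python names. The sketch below uses a hypothetical `capture_double` helper and lists standing in for tensors. It contrasts `x[:] = new`, the analogue of `x.copy_(new)`, with rebinding `x = new`, which replay never sees, and adds a simple shape guard in the spirit of the first constraint.

```python
from typing import Callable, List


def capture_double(src: List[float], dst: List[float]) -> Callable[[], None]:
    """Record 'dst = 2 * src' against these exact list objects (hypothetical helper)."""
    recorded_len = len(src)

    def replay() -> None:
        if len(src) != recorded_len or len(dst) != recorded_len:
            raise ValueError("shape changed since capture; re-capture the graph")
        for i, v in enumerate(src):
            dst[i] = 2.0 * v

    replay()  # the capture pass runs the op once, like eager execution
    return replay


x = [1.0, 2.0]
out = [0.0, 0.0]
replay = capture_double(x, out)

x[:] = [5.0, 6.0]  # in-place update: the recording sees it
replay()
print(out)  # [10.0, 12.0]

x = [7.0, 8.0]  # rebinding: the recording still points at the old list
replay()
print(out)  # [10.0, 12.0] (stale result, and no error is raised)
```

The second case is why the constraint matters: rebinding does not fail, it just replays on stale data.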

Results from a 100-configuration benchmark suite (LLMs, diffusion, dispatch-heavy models) on an Apple M4 Pro, using 5 warmup runs and 10 trials per configuration, reported as a 10% trimmed mean:

| Category | Improvement | Examples |
|---|---|---|
| Dispatch-heavy (deep/tiny LLMs) | +26-32% | tiny-llm-128-L64: +32%, tiny-llm-256-L128: +28% |
| Small LLMs | +4% | qwen3-0.5b: +4.4% |
| Diffusion (Flux MM-DiT) | +1-5% | flux-s64: +5.2% |
| GPU-bound (7B+, large batch) | ~0% | Expected; dispatch overhead is negligible |

Zero regressions across all 100 configs. Gains are concentrated on dispatch-heavy models where CPU overhead dominates GPU compute time.
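For readers reproducing the numbers: a 10% trimmed mean over 10 trials drops the single fastest and single slowest run before averaging. A stdlib-only sketch follows. The improvement formula is an assumption (a throughput-style speedup, baseline time over captured time minus one), since the issue does not say how its percentages are computed, and the timings are made up for illustration.

```python
from typing import Sequence


def trimmed_mean(samples: Sequence[float], proportion: float = 0.1) -> float:
    """Mean after dropping int(proportion * n) samples from each end of the sorted data."""
    data = sorted(samples)
    k = int(proportion * len(data))
    kept = data[k:len(data) - k]
    if not kept:
        raise ValueError("proportion too large for this many samples")
    return sum(kept) / len(kept)


def improvement_pct(baseline_s: float, captured_s: float) -> float:
    """Assumed definition: speedup as a percentage (+50% means 1.5x faster)."""
    return (baseline_s / captured_s - 1.0) * 100.0


# Illustrative timings in milliseconds (not from the issue), each with two outliers.
eager = [9.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 20.0]
captured = [5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 12.0]

e, c = trimmed_mean(eager), trimmed_mean(captured)
print(e, c, improvement_pct(e, c))  # 10.0 8.0 25.0
```

Under a latency-reduction definition, (baseline - captured) / baseline, the same timings would read as 20%, so the definition matters when comparing against the table.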

Alternatives

Relying on eager execution without capture. For large GPU-bound models the dispatch overhead is negligible, but for dispatch-heavy workloads (small models, deep stacks, inference serving with many small batches) the CPU overhead adds up significantly.

Additional context

No response

cc @jhavukainen @jerryzh168 @kulinseth @malfet @DenisVieriu97 @aditvenk

Extended analysis

TL;DR

The issue proposes a torch.mps.graph_capture() / torch.mps.graph_replay() API that records MPS operations once and replays them, to cut per-op CPU overhead in dispatch-heavy models on MPS. It is a feature request, so the API is usable only once it lands in PyTorch.

Guidance

  • Wrap the first forward pass in the torch.mps.graph_capture() context manager to record its MPS operations.
  • Call torch.mps.graph_replay() on subsequent passes to replay the recorded operations in a single dispatch, skipping per-op CPU overhead.
  • Keep tensor shapes unchanged between capture and replay, and update input data in-place with .copy_() before each replay.
  • Disable MPS profiling during capture.


Notes

The proposed API targets dispatch-heavy models, where CPU overhead dominates GPU compute time. For GPU-bound models the dispatch overhead is negligible, so the optimization is unlikely to yield significant gains.

Recommendation

Adopt the proposed torch.mps.graph_capture() / torch.mps.graph_replay() API, once available, to reduce CPU overhead in dispatch-heavy models: the author's benchmark on an Apple M4 Pro reports gains of up to 32% on dispatch-heavy configurations and no regressions across 100 configurations.
