pytorch - ✅(Solved) Fix [RFC] Make torch.library.custom_op faster in the common case [1 pull requests, 2 comments, 2 participants]

GitHub stats
pytorch/pytorch#177109 · Fetched 2026-04-08 00:22:15
Comments: 2 · Participants: 2 · Timeline events: 31 · Reactions: 0
Timeline (top): subscribed ×10, mentioned ×9, labeled ×7, commented ×2

Root Cause

This seems fine, because whatever we've done, we have made torch.library.custom_op faster overall. I'm too scared to touch custom operators created by torch.library.define/impl because I'm worried users are sensitive to this, but in theory we could change those as well.

Fix Action

Fix / Workaround

At the start of invoking a Python torch.library.custom_op (one where all the implementations are written in Python), we perform one dispatch key computation. It produces the dispatch key set, which tells us whether we're looking at "normal eager mode".

"Normal eager mode" is a dispatch key set that looks like one of:

  • plain cpu tensor: [CPU, ADInplaceOrView, AutogradCPU, AutocastCPU]
  • plain cuda tensor: [CUDA, ADInplaceOrView, AutogradCUDA, AutocastCUDA]
  • plain cpu tensor variant 2: [CPU, AutocastCPU] (happens with inference_mode)
  • plain cuda tensor variant 2: [CUDA, AutocastCUDA] (happens with inference_mode)

This dispatch key set tells us whether we're on the fast path or the slow path.

  • the slow path is: we do the standard dispatch of the custom operator again. This is 0.7us slower than today, because the dispatch key computation we just did takes 0.7us.
  • the fast path is: we dispatch the entirety of the custom operator in Python, so we avoid crossing into C++ unless the kernel itself needs to. This avoids round trips like Python -> C++ (dispatch) -> Python (autograd kernel) -> C++ (dispatch) -> Python (backend kernel).
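
The fast/slow decision described above can be modeled as a set-membership test. The sketch below is a pure-Python model, not PyTorch's implementation: the key names are copied from the list above, while NORMAL_EAGER_KEYSETS and choose_path are hypothetical names, and real dispatch key sets are C++ bitsets rather than Python frozensets.

```python
# Pure-Python model of the fast/slow path decision; no torch required.
# Key names come from the list above; all other names are hypothetical.

NORMAL_EAGER_KEYSETS = frozenset({
    frozenset({"CPU", "ADInplaceOrView", "AutogradCPU", "AutocastCPU"}),
    frozenset({"CUDA", "ADInplaceOrView", "AutogradCUDA", "AutocastCUDA"}),
    frozenset({"CPU", "AutocastCPU"}),    # inference_mode variant
    frozenset({"CUDA", "AutocastCUDA"}),  # inference_mode variant
})


def choose_path(keyset):
    """Return "fast" only when the key set exactly matches a normal eager set."""
    return "fast" if frozenset(keyset) in NORMAL_EAGER_KEYSETS else "slow"


if __name__ == "__main__":
    print(choose_path(["CPU", "AutocastCPU"]))
    # Any extra key (here a "Python" key, as a tensor subclass might add)
    # falls through to the standard dispatcher.
    print(choose_path(["CPU", "ADInplaceOrView", "AutogradCPU", "AutocastCPU", "Python"]))
```

The exact-match test is deliberate in this model: any key outside the four listed sets sends the call to the slow path, which mirrors the idea that only plain eager tensors skip the dispatcher.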

PR fix notes

PR #178216: Speed up torch.library.custom_op common eager calls

Description (problem / solution / changelog)

Addresses #177109. To reduce overhead of micro custom_op kernel calls, this PR adds a Python fast path for common eager cases.

Safety checks are still preserved. The implementation keeps aliasing checks, autograd handling, mutable-op version-counter handling, and dispatcher fallback for cases like autocast, TorchFunctionMode, tensor subclasses, tensor-list schemas, meta tensors, view ops, and torch.compile.

How this maps to #177109

Achieved:

  • Common eager custom_op calls are much faster
  • No-grad custom_op is now essentially in the same range as define/impl
  • The eager inference_mode variant is handled on the fast path too

Not yet:

  • The one-dispatch-key-computation design is not implemented yet; this PR relies on sequential if guards instead
  • It does not reach the lower prototype-style ~3.4 us result

Benchmarks

Stabilized run on ARM aarch64 (GB200 / Neoverse V2), using blocked_autorange(min_run_time=5.0). The branch point is 928cada7041ff3621600d35912fba27d69958432.

Scenario                      Branch point    Current
custom_op (no grad)           21.0 us         7.8 us
custom_op via ops.*           14.6 us         8.0 us
custom_op (grad)              29.9 us         23.4 us
custom_mutate_op (no grad)    19.4 us         9.7 us
custom_op (inference_mode)    8.6 us          5.7 us
define/impl                   7.3 us          7.7 us
direct call                   1.1 us          1.1 us
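
To read the table as ratios, the short sketch below divides the branch-point time by the current time for each scenario. The timings are copied from the table; the names timings_us and speedup are only illustrative.

```python
# Timings in microseconds, copied from the benchmark table above
# as (branch point, current).
timings_us = {
    "custom_op (no grad)": (21.0, 7.8),
    "custom_op via ops.*": (14.6, 8.0),
    "custom_op (grad)": (29.9, 23.4),
    "custom_mutate_op (no grad)": (19.4, 9.7),
    "custom_op (inference_mode)": (8.6, 5.7),
    "define/impl": (7.3, 7.7),
    "direct call": (1.1, 1.1),
}


def speedup(before_us, after_us):
    """A ratio above 1 means the current build is faster than the branch point."""
    return before_us / after_us


speedups = {name: speedup(b, a) for name, (b, a) in timings_us.items()}

if __name__ == "__main__":
    for name, ratio in speedups.items():
        print(f"{name:30s} {ratio:5.2f}x")
```

On these numbers the no-grad custom_op case improves by roughly 2.7x and the mutating op by 2x, while define/impl comes out slightly below 1x (7.3 us to 7.7 us). The table does not say whether that last difference is measurement noise.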

Test plan

  • python test/test_custom_ops.py TestCustomOpFastPath -v
  • python test/test_custom_ops.py TestCustomOpAPI -v

Benchmark script

<details>
import torch
from torch import Tensor
import torch.utils.benchmark as benchmark


MIN_RUN_TIME_S = 5.0


@torch.library.custom_op("benchlib::clone_custom_op", mutates_args=())
def clone_custom_op(
    a: Tensor, b: Tensor, c: Tensor, d: Tensor, e: Tensor,
    f: Tensor, g: Tensor, h: Tensor, i: Tensor, j: Tensor,
    k: Tensor, l: Tensor, m: Tensor, n: Tensor,
) -> Tensor:
    return a.clone()


@torch.library.custom_op("benchlib::mutate_op", mutates_args=("a",))
def mutate_op(
    a: Tensor, b: Tensor, c: Tensor, d: Tensor, e: Tensor,
    f: Tensor, g: Tensor, h: Tensor, i: Tensor, j: Tensor,
    k: Tensor, l: Tensor, m: Tensor, n: Tensor,
) -> Tensor:
    a.zero_()
    return a.clone()


lib = torch.library.Library("benchlib2", "DEF")
lib.define(
    "clone_define_impl(Tensor a, Tensor b, Tensor c, Tensor d, Tensor e,"
    " Tensor f, Tensor g, Tensor h, Tensor i, Tensor j,"
    " Tensor k, Tensor l, Tensor m, Tensor n) -> Tensor"
)
lib.impl("clone_define_impl", lambda a, *_: a.clone(), "CPU")
lib.impl("clone_define_impl", lambda a, *_: a.clone(), "CUDA")


def clone_direct(a, b, c, d, e, f, g, h, i, j, k, l, m, n):
    return a.clone()


def bench(fn, args):
    t = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    m = t.blocked_autorange(min_run_time=MIN_RUN_TIME_S)
    return m.median * 1e6


def bench_inference_mode(fn):
    t = benchmark.Timer(
        stmt="fn(*args)",
        globals={"fn": fn},
        setup=(
            "ctx = torch.inference_mode(); ctx.__enter__(); "
            "args = [torch.randn(4) for _ in range(14)]"
        ),
    )
    m = t.blocked_autorange(min_run_time=MIN_RUN_TIME_S)
    return m.median * 1e6


def main():
    tensors = [torch.randn(4) for _ in range(14)]
    grad_tensors = [torch.randn(4, requires_grad=True) for _ in range(14)]
    op_define = torch.ops.benchlib2.clone_define_impl
    op_custom_via_ops = torch.ops.benchlib.clone_custom_op

    print(f"PyTorch {torch.__version__}")
    print(f"{'Scenario':30s} {'us':>8s}")
    print("-" * 40)
    for label, fn, args in [
        ("custom_op (no grad)", clone_custom_op, tensors),
        ("custom_op via ops.*", op_custom_via_ops, tensors),
        ("custom_op (grad)", clone_custom_op, grad_tensors),
        ("mutate_op (no grad)", mutate_op, tensors),
        ("define/impl", op_define, tensors),
        ("direct call", clone_direct, tensors),
    ]:
        us = bench(fn, args)
        print(f"{label:30s} {us:8.1f}")

    us = bench_inference_mode(clone_custom_op)
    print(f"{'custom_op (inference_mode)':30s} {us:8.1f}")


if __name__ == "__main__":
    main()
</details>

Changed files

  • test/test_custom_ops.py (modified, +388/-0)
  • torch/_C/__init__.pyi.in (modified, +3/-0)
  • torch/_library/custom_ops.py (modified, +144/-0)
  • torch/csrc/autograd/init.cpp (modified, +92/-0)

Original issue text

I benchmarked a custom operator that accepts 14 inputs and clones one of its inputs. The numbers today (for one call, on my machine) are:

  • 10.7us (torch.library.custom_op)
  • 6.3us (torch.library.define / torch.library.impl)
  • 1.4us (no custom operator - call the function directly)

Pitch:

  • First, we're going to make torch.library.custom_op faster (to 6.3us) by deleting all of the safety checks (that's the main delta from torch.library.define/impl).
  • Next, we're going to make the common case of torch.library.custom_op (plain Tensors in eager-mode PyTorch, no subclasses) faster. I have a Claude prototype that brings this down to 3.4us (see next section).
  • The implication is that the uncommon case of torch.library.custom_op (everything else not mentioned) ends up at around 7us.

Downsides

Anything that is not the "fast path" is slower.
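
The trade-off can be put in numbers with a small cost model. This is a sketch that takes the RFC's figures at face value (6.3us after removing the safety checks, 3.4us on the fast path, 0.7us for the extra key set computation); the function names and the idea of a "fast fraction" are assumptions added for illustration.

```python
# Per-call latencies in microseconds, taken at face value from the RFC text.
BASELINE_US = 6.3      # custom_op once the safety checks are removed
FAST_PATH_US = 3.4     # prototype result for the common eager case
KEYSET_CHECK_US = 0.7  # extra dispatch key computation paid on the slow path
SLOW_PATH_US = BASELINE_US + KEYSET_CHECK_US  # about 7us, as stated in the pitch


def expected_latency_us(fast_fraction):
    """Average per-call latency when fast_fraction of calls hit the fast path."""
    return fast_fraction * FAST_PATH_US + (1.0 - fast_fraction) * SLOW_PATH_US


def break_even_fraction():
    """Fast-path share at which the average equals the no-fast-path baseline."""
    return KEYSET_CHECK_US / (SLOW_PATH_US - FAST_PATH_US)


if __name__ == "__main__":
    print(f"break-even fast fraction: {break_even_fraction():.3f}")
    for p in (0.0, 0.5, 1.0):
        print(f"fast fraction {p:.1f}: {expected_latency_us(p):.2f} us")
```

Under this model the fast path pays off on average once a bit under a fifth of calls qualify for it (0.7 / 3.6), and every slow-path call stays 0.7us above the 6.3us baseline.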

Discussion

The Claude prototype is still about 2us slower than invoking the function directly (3.4us versus 1.4us). At a minimum it will be one dispatch key set computation slower than invoking the function directly. Is that too much overhead?

Alternatives

???

cc @jerryzh168 @chauhang @penguinwu @bdhirsh @bobrenjc93 @aorenste

extent analysis

Fix Plan

Step 1: Remove Safety Checks from torch.library.custom_op

  • Delete all safety checks from torch.library.custom_op to match the implementation of torch.library.define/torch.library.impl.
  • This should bring the performance of torch.library.custom_op down to 6.3us.

Step 2: Optimize Common Case of torch.library.custom_op

  • Use the Claude prototype to optimize the common case of torch.library.custom_op (plain Tensors in eager-mode PyTorch, no subclasses).
  • This should bring the performance of torch.library.custom_op down to 3.4us.

Step 3: Implement Dispatch Key Computation

  • At the start of invoking a Python torch.library.custom_op, perform one dispatch key computation to determine if we're in the "normal eager mode" fast path.
  • The dispatch key computation should check for the following conditions:
    • Plain CPU tensor: [CPU, ADInplaceOrView, AutogradCPU, AutocastCPU]
    • Plain CUDA tensor: [CUDA, ADInplaceOrView, AutogradCUDA, AutocastCUDA]
    • Plain CPU tensor variant 2 (inference_mode): [CPU, AutocastCPU]
    • Plain CUDA tensor variant 2 (inference_mode): [CUDA, AutocastCUDA]

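Example Code

import torch

def dispatch_key_computation(tensor):
    # Illustrative only: returns the key names listed above as strings,
    # not a real dispatch key set computed by PyTorch.

    # Only exact torch.Tensor instances (no subclasses) count as plain
    if type(tensor) is not torch.Tensor:
        return []

    # Check if tensor is on CPU or CUDA; other devices are not on the list
    if tensor.is_cuda:
        device = 'CUDA'
    elif tensor.device.type == 'cpu':
        device = 'CPU'
    else:
        return []

    # Tensors created under inference_mode carry no autograd keys
    if tensor.is_inference():
        return [device, 'Autocast' + device]

    # Otherwise the tensor carries the full eager-mode key set
    return [device, 'ADInplaceOrView', 'Autograd' + device, 'Autocast' + device]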
