pytorch - ✅(Solved) Fix torch.multinomial: add flag to skip input validation (66% of GPU time) [1 pull requests, 7 comments, 3 participants]

GitHub stats
pytorch/pytorch#177127 (fetched 2026-04-08 00:22:04)
Comments: 7
Participants: 3
Timeline events: 37
Reactions: 0
Timeline (top): subscribed ×12, mentioned ×10, commented ×7, labeled ×6

PR fix notes

PR #180444: Add validate parameter to torch.multinomial to skip input validation

Description (problem / solution / changelog)

Summary

Adds a validate keyword argument (default True) to torch.multinomial. When validate=False, the 10 GPU validation kernels (aminmax, sum, assert_async, etc.) on the fast path (!with_replacement || n_sample == 1) are skipped entirely.

Fixes #177127

Motivation

As profiled in the issue, torch.multinomial spends ~66% of GPU time (~107 µs out of ~161 µs on an RTX 3090 with V=128,000) on input validation: checking for negative values, NaN, Inf, and zero-sum distributions. This validation is unnecessary when the caller knows the input is valid, e.g. when probabilities come directly from softmax() over finite logits, which yields values in [0, 1], no NaN/Inf, and a sum of approximately 1.0.

This matters for LLM inference, where torch.multinomial is called on every decode step with softmax output. Skipping validation yields an estimated ~3x speedup for this hot-path operation.

Changes

File | Change
native_functions.yaml | Add bool validate=True as keyword-only arg to both multinomial and multinomial.out
Distributions.cpp | Guard the fast-path validation block with if (validate)
Distributions.mm (MPS) | Same guard for the MPS backend
derivatives.yaml | Update signature
_meta_registrations.py | Accept new kwarg in meta function
overrides.py | Update lambda signature
_torch_docs.py | Document the validate kwarg with usage guidance
test_torch.py | Add test_multinomial_validate_false covering 1D/2D, replacement/no-replacement, method variant, and softmax input
common_methods_invocations.py | Add validate=False sample inputs to OpInfo

Usage

# Default behavior unchanged (validation on):
torch.multinomial(probs, num_samples=1)

# Skip validation when input is known-valid (e.g. from softmax):
torch.multinomial(probs, num_samples=1, validate=False)

# Method variant also supported:
probs.multinomial(num_samples=1, validate=False)
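
The validate keyword exists only in builds that include this PR, so code that must also run on older releases may want a fallback. The following is a minimal sketch, assuming an unsupported keyword surfaces as a TypeError; sample_next_token is a hypothetical helper name, not part of the PR:

```python
import torch

def sample_next_token(probs, num_samples=1):
    """Sample indices, skipping validation where the build supports it.

    Hypothetical helper: assumes a build without the `validate` keyword
    rejects it with TypeError, in which case we fall back to the default call.
    """
    try:
        return torch.multinomial(probs, num_samples=num_samples, validate=False)
    except TypeError:
        # Older build: the keyword is unknown, so validation stays on.
        return torch.multinomial(probs, num_samples=num_samples)

probs = torch.randn(16).softmax(dim=0)
token = sample_next_token(probs)
```

On a build without the keyword every call pays for the failed first attempt, so a real decode loop would probably probe once at startup and cache the result.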

Design Decisions

  • Keyword-only: validate is keyword-only to prevent positional ambiguity with existing args
  • Default True: Full backward compatibility — existing code is unaffected
  • Name choice: validate echoes optional-check flags elsewhere in PyTorch (e.g. validate_args in torch.distributions, error_if_nonfinite in torch.nn.utils.clip_grad_norm_)
  • Both CPU and GPU: Validation is guarded in both Distributions.cpp (CPU/CUDA) and Distributions.mm (MPS)

Testing

  • Added test_multinomial_validate_false with coverage for all code paths (1D, 2D, with/without replacement, method variant, softmax input)
  • Added validate=False sample inputs to OpInfo for broader test matrix coverage
  • Existing tests pass unchanged (validate=True is the default)

cc @jerryzh168 @ptrblck @msaroufim @eqy @tinglvv @nWEIdia

Changed files

  • aten/src/ATen/VmapModeRegistrations.cpp (modified, +2/-2)
  • aten/src/ATen/functorch/BatchRulesRandomness.cpp (modified, +3/-3)
  • aten/src/ATen/native/Distributions.cpp (modified, +15/-11)
  • aten/src/ATen/native/mps/operations/Distributions.mm (modified, +14/-11)
  • aten/src/ATen/native/native_functions.yaml (modified, +2/-2)
  • test/test_torch.py (modified, +39/-0)
  • tools/autograd/derivatives.yaml (modified, +1/-1)
  • torch/_meta_registrations.py (modified, +1/-1)
  • torch/_torch_docs.py (modified, +6/-1)
  • torch/overrides.py (modified, +1/-1)
  • torch/testing/_internal/common_methods_invocations.py (modified, +2/-0)


🚀 Feature Request

Motivation

torch.multinomial on CUDA launches 10 validation kernels before the actual 3 sampling kernels. NCU profiling shows these checks consume 66% of the total GPU time (~107 us out of ~161 us for V=128,000 on an RTX 3090).

This validation is unnecessary when the caller knows the input is valid, e.g. when probabilities come directly from softmax() over finite logits, which guarantees: values in [0, 1], no NaN/Inf, and a sum of 1.0 up to floating-point rounding.
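
The softmax guarantee can be checked directly. The sketch below is illustrative only: is_valid_distribution is a hypothetical helper that repeats the three checks in plain Python, and the tolerance is an assumption. It also shows that the guarantee depends on the logits being finite:

```python
import torch

def is_valid_distribution(probs, atol=1e-4):
    """Return True if probs is finite, non-negative, and sums to ~1 along the last dim."""
    row_sums = probs.sum(dim=-1)
    finite = bool(torch.isfinite(probs).all())
    non_negative = bool((probs >= 0).all())
    sums_to_one = bool(torch.allclose(row_sums, torch.ones_like(row_sums), atol=atol))
    return finite and non_negative and sums_to_one

logits = torch.randn(128_000)
ok_for_finite_logits = is_valid_distribution(logits.softmax(dim=0))

# A single NaN logit propagates through softmax and breaks the guarantee.
bad_logits = logits.clone()
bad_logits[0] = float("nan")
ok_for_nan_logits = is_valid_distribution(bad_logits.softmax(dim=0))
```

Here ok_for_finite_logits is True and ok_for_nan_logits is False, so skipping validation is only safe when the logits themselves are known to be finite.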

This matters for LLM inference, where torch.multinomial is called on every decode step with softmax output.

NCU profile

Reproduction script:

import torch, nvtx

vocab = 128_000
logits = torch.randn(size=[vocab], dtype=torch.float32, device="cuda")
probs = logits.softmax(dim=0)
with nvtx.annotate("multinomial"):
    sample = torch.multinomial(probs, num_samples=1)

Profiling command (with the script above saved as slow-validation.py):

ncu --nvtx --nvtx-include "multinomial/" --metrics gpu__time_duration.sum python slow-validation.py

Validation kernels (66%, ~107 us):

# | Kernel | Time (us) | Purpose
1 | reduce_kernel (MinNanFunctor) | 29.28 | Check all values >= 0
2 | vectorized_elementwise_kernel (compare_scalar) | 2.72 | Compare min result
3 | reduce_kernel (MaxNanFunctor) | 28.80 | Check no NaN/Inf
4 | vectorized_elementwise_kernel (compare_scalar) | 2.72 | Compare max result
5 | vectorized_elementwise_kernel (BitwiseAnd) | 2.62 | Combine checks
6 | _assert_async_cuda_kernel | 3.94 | Assert on GPU
7 | reduce_kernel (sum) | 28.64 | Check sum > 0
8 | vectorized_elementwise_kernel (CompareEq) | 2.37 | Compare sum result
9 | vectorized_elementwise_kernel (bitwise_not) | 2.56 | Negate for assert
10 | _assert_async_cuda_kernel | 3.81 | Assert on GPU

Actual sampling kernels (34%, ~54 us):

# | Kernel | Time (us) | Purpose
11 | distribution_elementwise_grid_stride_kernel | 5.63 | Generate Gumbel noise
12 | vectorized_elementwise_kernel (Div) | 3.87 | Divide by probs
13 | reduce_kernel (ArgMax) | 44.29 | Argmax to get sample

Proposal

Add a validate keyword argument (default True) to torch.multinomial:

# Current behavior (validation on):
torch.multinomial(probs, num_samples=1)

# Skip validation when caller guarantees valid input:
torch.multinomial(probs, num_samples=1, validate=False)

This would skip the 10 validation kernels in MultinomialKernel.cu when validate=False, giving a ~3x speedup for this op.
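
The 66% and ~3x figures follow from the per-kernel times in the two tables above. A short check of the arithmetic, using only those reported numbers:

```python
# Per-kernel GPU times (us) copied from the NCU tables in this issue.
validation_us = [29.28, 2.72, 28.80, 2.72, 2.62, 3.94, 28.64, 2.37, 2.56, 3.81]
sampling_us = [5.63, 3.87, 44.29]

validation_total = sum(validation_us)
sampling_total = sum(sampling_us)
total = validation_total + sampling_total

validation_share = validation_total / total
speedup = total / sampling_total

print(f"validation: {validation_total:.2f} us, sampling: {sampling_total:.2f} us")
# -> validation: 107.46 us, sampling: 53.79 us
print(f"validation share: {validation_share:.1%}, speedup: {speedup:.2f}x")
# -> validation share: 66.6%, speedup: 3.00x
```

The speedup counts GPU kernel time only. Kernel-launch and host-side overhead are not in these numbers, so end-to-end gains may differ.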

Alternatives considered

  • torch.distributions.Multinomial has validate_args, but that only controls Python-level constraint checks, not the CUDA kernel-level validation.
  • Fused sampling kernels (e.g. flashinfer) avoid this entirely, but for users who want to stay with PyTorch's API, a flag would help.
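
For reference, the first alternative looks like the sketch below. validate_args=False turns off the Python-level constraint checks of torch.distributions; per the point above it does not touch the kernel-level validation, so it is not a substitute for the proposed flag. The vocabulary size is an arbitrary illustrative value:

```python
import torch
from torch.distributions import Multinomial

vocab = 1000  # illustrative size, not taken from the issue
probs = torch.randn(vocab).softmax(dim=0)

# validate_args=False skips the Python-side constraint checks on probs.
dist = Multinomial(total_count=1, probs=probs, validate_args=False)

# sample() returns per-category counts; with total_count=1 this is a one-hot vector.
counts = dist.sample()
token = int(counts.argmax())
```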

cc @jerryzh168 @ptrblck @msaroufim @eqy @tinglvv @nWEIdia @pytorch/cpu-kernels

extent analysis

Fix Plan

Add validate keyword argument to torch.multinomial

The fix adds a keyword-only validate argument (default True) to torch.multinomial. When validate=False, the fast path skips launching the validation kernels. Per PR #180444 the guard lives in Distributions.cpp (CPU/CUDA) and Distributions.mm (MPS), not in Python.

Code Changes

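import torch

# Python-level sketch of the flag's effect on the no-replacement fast path.
# The real change is in ATen (Distributions.cpp); this wrapper is illustrative only.
def multinomial(probs, num_samples, validate=True):
    if validate:
        # The checks described in the issue: finite, non-negative, non-zero sum.
        assert bool(torch.isfinite(probs).all()), "probs contains inf or nan"
        assert bool((probs >= 0).all()), "probs contains negative values"
        assert bool((probs.sum(dim=-1) > 0).all()), "probs sums to zero"
    # Sampling step (assumed here: exponential noise, equivalent to the Gumbel-max trick).
    noise = torch.empty_like(probs).exponential_(1)
    return torch.topk(probs / noise, num_samples, dim=-1).indices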

Example Usage

import torch

vocab = 128_000
logits = torch.randn(size=[vocab], dtype=torch.float32, device="cuda")
probs = logits.softmax(dim=0)

# Validation on (default behavior):
sample = torch.multinomial(probs, num_samples=1)

# Skip validation when caller guarantees valid input (requires a build with the validate kwarg):
sample = torch.multinomial(probs, num_samples=1, validate=False)

Verification

To verify the fix, profile torch.multinomial with ncu twice: once with the reproduction script as written, and once after editing its torch.multinomial call to pass validate=False. The script takes no command-line arguments, so the flag has to be changed in the source.

ncu --nvtx --nvtx-include "multinomial/" --metrics gpu__time_duration.sum python slow-validation.py

With validate=False, the 10 validation kernels from the profile should no longer appear, leaving only the 3 sampling kernels.
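
Where ncu is not available, a rough wall-clock comparison can serve as a sanity check. This is a sketch under stated assumptions: time_multinomial is a hypothetical helper, the warm-up and run counts are arbitrary, and wall-clock time includes launch overhead, so the numbers will not match the per-kernel ncu figures:

```python
import time
import torch

def time_multinomial(probs, runs=100, **kwargs):
    """Average wall-clock seconds per torch.multinomial call (hypothetical helper)."""
    on_cuda = probs.is_cuda
    for _ in range(10):  # warm-up so one-time setup costs are excluded
        torch.multinomial(probs, num_samples=1, **kwargs)
    if on_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(runs):
        torch.multinomial(probs, num_samples=1, **kwargs)
    if on_cuda:
        torch.cuda.synchronize()  # wait for the timed kernels to finish
    return (time.perf_counter() - start) / runs

device = "cuda" if torch.cuda.is_available() else "cpu"
probs = torch.randn(128_000, device=device).softmax(dim=0)

baseline = time_multinomial(probs)
print(f"default (validation on): {baseline * 1e6:.1f} us per call")
# On a build that includes the validate kwarg, compare against:
# time_multinomial(probs, validate=False)
```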
