vllm - 💡(How to fix) Fix [Bug]: native Triton top-k/top-p kernel assumes contiguous logits


Error Message

ValueError: Out of range float values are not JSON compliant: nan

Root Cause

_topk_topp_kernel in vllm/v1/sample/ops/topk_topp_triton.py computes row pointers as LOGITS + row_id * VOCAB_SIZE, which is only valid for contiguous logits. apply_top_k_top_p_triton does not check contiguity, so a non-contiguous view with stride(0) != vocab_size (for example a sliced padded-vocab lm_head output) makes the kernel read and filter the wrong physical rows.

A logical row can end up entirely -inf, log_softmax turns it into NaN, and serializing the processed logprobs fails.

Fix Action

Make apply_top_k_top_p_triton call logits.contiguous() when it receives non-contiguous logits, matching the guard forward_cuda already applies before calling flashinfer_sample.


Your current environment

vLLM Version: 0.20.1
PyTorch: 2.11.0+cu130
Triton: 3.6.0
CUDA available: True
GPU: NVIDIA GH200 120GB
Driver: 565.57.01
OS: SUSE Linux Enterprise Server 15 SP6 (aarch64)

🐛 Describe the bug

vllm/v1/sample/ops/topk_topp_triton.py::_topk_topp_kernel computes row pointers as LOGITS + row_id * VOCAB_SIZE. That is correct only for contiguous logits. The wrapper apply_top_k_top_p_triton checks rank and dtype but does not check contiguity, so when the wrapper is handed a logits tensor with stride(0) != vocab_size (e.g. a sliced padded-vocab view), the kernel reads the wrong physical row.
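A pure-Python model of the addressing mismatch makes the drift concrete. It assumes, as described above, that the kernel starts row row_id at flat offset row_id * VOCAB_SIZE while the tensor really stores logical row r at r * stride(0), with stride(0) = vocab + pad. kernel_window_start and start_drift are hypothetical helpers for illustration, not vLLM code.

```python
def kernel_window_start(row_id, vocab, row_stride):
    """Map the kernel's assumed row start back to a logical (row, col).

    The kernel is modelled as starting row `row_id` at flat offset
    row_id * vocab (LOGITS + row_id * VOCAB_SIZE). Logical row r really
    starts at r * row_stride, so divmod recovers which logical row and
    column that offset lands in. A column >= vocab means it is padding.
    """
    offset = row_id * vocab
    return divmod(offset, row_stride)


def start_drift(row_id, vocab, row_stride):
    """How many elements before logical row `row_id` the kernel starts."""
    return row_id * row_stride - row_id * vocab


# Toy sizes from the reproducer: vocab=1024, pad=8, so stride(0) == 1032.
vocab, pad = 1024, 8
for row_id in range(4):
    print(row_id, kernel_window_start(row_id, vocab, vocab + pad),
          start_drift(row_id, vocab, vocab + pad))
# 0 (0, 0) 0
# 1 (0, 1024) 8      <- starts in row 0's padding
# 2 (1, 1016) 16     <- starts in the tail of logical row 1
# 3 (2, 1008) 24
```

With the production sizes from the failing Olmo3 run (vocab=100278, row stride 100288), kernel_window_start(13, 100278, 100288) returns (12, 100158): the window for row 13 begins 130 elements early, covering the last 120 logits of logical row 12 plus its 10 padding slots. That is consistent with the captured kernel maxima matching the previous logical row, although this model says nothing about how the kernel then chooses its threshold.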

This is observable end-to-end via processed_logprobs: the kernel can mask an entire logical row to -inf; logits.log_softmax(dim=-1, dtype=torch.float32) then emits NaN for that row, and JSON serialization fails with:

ValueError: Out of range float values are not JSON compliant: nan
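The failure chain described above can be modelled on the CPU. This is a pure-Python stand-in: log_softmax here is a hand-written, max-subtracting implementation (not torch's), and the standard-library json module with allow_nan=False plays the role of the strict serializer. In vLLM the actual error comes from pydantic, as the traceback further down shows.

```python
import json
import math


def log_softmax(row):
    """Numerically stable log-softmax over a list of floats (pure Python)."""
    m = max(row)                          # -inf if every entry is -inf
    shifted = [x - m for x in row]        # (-inf) - (-inf) evaluates to NaN
    lse = math.log(sum(math.exp(s) for s in shifted))
    return [s - lse for s in shifted]


# A row the top-p filter masked entirely, like the captured bad row.
masked_row = [float("-inf")] * 4
logprobs = log_softmax(masked_row)
print(all(math.isnan(v) for v in logprobs))   # True

# A strict JSON encoder rejects NaN, mirroring the serving error.
error_message = None
try:
    json.dumps(logprobs, allow_nan=False)
except ValueError as exc:
    error_message = str(exc)
print(error_message)
```

A row with at least one finite logit never hits this path: its max is finite, so the shifted values are finite or -inf and the normaliser is well defined.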

The FlashInfer path already defends against this. vllm/v1/sample/ops/topk_topp_sampler.py::forward_cuda calls flashinfer_sample(logits.contiguous(), ...), with an inline comment noting that fp32 inference plus logits-processor slicing can produce non-contiguous logits. forward_native has no equivalent guard, and processed_logprobs forces the native path.
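The guard boils down to a stride check, and .contiguous() can be modelled as compacting each logical row into a dense buffer. This is a pure-Python sketch: is_row_major_contiguous and compact_rows are hypothetical helpers, and the 2-D check deliberately ignores PyTorch's special handling of size-1 dimensions.

```python
def is_row_major_contiguous(shape, strides):
    """2-D row-major contiguity: strides must be exactly (cols, 1)."""
    _rows, cols = shape
    return tuple(strides) == (cols, 1)


def compact_rows(buf, rows, cols, row_stride):
    """Copy a [rows, cols] strided view out of flat `buf` into a dense list.

    Models what logits.contiguous() produces: afterwards logical row r
    starts at flat offset r * cols, which is what the kernel assumes.
    """
    dense = []
    for r in range(rows):
        start = r * row_stride
        dense.extend(buf[start:start + cols])
    return dense


rows, cols, pad = 2, 3, 1
backing = [0.0, 1.0, 2.0, -1000.0,      # logical row 0 + one pad slot
           10.0, 11.0, 12.0, -1000.0]   # logical row 1 + one pad slot
print(is_row_major_contiguous((rows, cols), (cols + pad, 1)))   # False

# row_id * vocab on the padded buffer picks up the pad slot and shifts row 1.
print(backing[1 * cols:2 * cols])       # [-1000.0, 10.0, 11.0]

dense = compact_rows(backing, rows, cols, cols + pad)
print(is_row_major_contiguous((rows, cols), (cols, 1)))         # True
print(dense[1 * cols:2 * cols])         # [10.0, 11.0, 12.0]
```

The compaction costs one copy of the logits per call, which is the trade-off forward_cuda already accepts for FlashInfer.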

Minimal reproducer

import torch
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

assert torch.cuda.is_available()
device = "cuda"

batch, vocab, pad = 16, 1024, 8

# Legal non-contiguous [batch, vocab] view with stride (vocab + pad, 1).
# This is the same shape a sliced padded-vocab lm_head produces.
backing = torch.full((batch, vocab + pad), -1000.0, device=device, dtype=torch.float32)
logits = backing[:, :vocab]
assert logits.shape == (batch, vocab)
assert logits.stride() == (vocab + pad, 1)
assert not logits.is_contiguous()

# Descending logits per row (clear top-p semantics).
base = torch.linspace(10.0, -10.0, vocab, device=device)
logits.copy_(base[None, :] + torch.arange(batch, device=device)[:, None] / 1000.0)

p = torch.full((batch,), 0.95, device=device, dtype=torch.float32)

noncontig_out = apply_top_k_top_p_triton(logits, None, p)
# Rebuild backing because the filter writes in place.
backing2 = torch.full((batch, vocab + pad), -1000.0, device=device, dtype=torch.float32)
logits2 = backing2[:, :vocab]
logits2.copy_(base[None, :] + torch.arange(batch, device=device)[:, None] / 1000.0)
contig_out = apply_top_k_top_p_triton(logits2.contiguous(), None, p)

print("non-contiguous finite/row:", torch.isfinite(noncontig_out).sum(-1).tolist())
print("contiguous     finite/row:", torch.isfinite(contig_out).sum(-1).tolist())

assert torch.equal(torch.isfinite(noncontig_out), torch.isfinite(contig_out)), \
    "FAIL: non-contiguous Triton output differs from contiguous Triton."
print("PASS")

Observed output (vLLM 0.20.1)

non-contiguous finite/row: [154, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]
contiguous     finite/row: [154, 154, 154, 154, 154, 154, 154, 154, 154, 154, 154, 154, 154, 154, 154, 154]
Traceback (most recent call last):
  File "repro.py", line 33, in <module>
    assert torch.equal(torch.isfinite(noncontig_out), torch.isfinite(contig_out)), \
AssertionError: FAIL: non-contiguous Triton output differs from contiguous Triton.

A top_k=1 variant of the same script shows the same failure mode: exactly one token should remain finite per row, but with the non-contiguous view some rows retain up to pad (here 8) finite tokens. Full script: [gist link].

Expected output

Both calls should produce identical finite masks.
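To compare the finite masks against something independent of either Triton call, a small CPU reference filter helps. The semantics are an assumption: top-k keeps the k largest logits, then top-p keeps the shortest probability-sorted prefix of those survivors whose renormalised mass reaches p. Tie and boundary handling may differ from vLLM's Triton and PyTorch paths, so reference_mask and finite_counts (hypothetical helpers) are a sanity check, not a bit-exact oracle.

```python
import math


def reference_mask(row, k=None, p=None):
    """Return a copy of `row` with filtered-out logits replaced by -inf.

    Assumes k >= 1 when given. At least one token always survives.
    """
    kept = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
    if k is not None:
        kept = kept[:k]
    if p is not None:
        top = row[kept[0]]
        weights = [math.exp(row[i] - top) for i in kept]   # stable softmax
        total = sum(weights)
        cumulative = 0.0
        for n, w in enumerate(weights, start=1):
            cumulative += w / total
            if cumulative >= p:          # shortest prefix reaching mass p
                kept = kept[:n]
                break
    keep = set(kept)
    return [x if i in keep else float("-inf") for i, x in enumerate(row)]


def finite_counts(rows):
    """Finite entries per row, the same statistic the reproducer prints."""
    return [sum(math.isfinite(x) for x in row) for row in rows]


rows = [[float(v) for v in range(5, 0, -1)] for _ in range(3)]
print(finite_counts([reference_mask(r, k=1) for r in rows]))   # [1, 1, 1]

probs_row = [math.log(0.5), math.log(0.3), math.log(0.2)]
print(finite_counts([reference_mask(probs_row, p=0.7)]))        # [2]
```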

Kernel failure

The kernel iterates rows as LOGITS + row_id * VOCAB_SIZE, while the real row stride is vocab + pad, so the window the kernel treats as row N starts N * pad elements before logical row N. Row N's kernel-side max can therefore be the max of an earlier logical row, and the misalignment grows with the row index. Sample per-row capture from a failing production run (Olmo3, vocab=100278, padded=100288):

bad row 13: logical pre-row max 23.97648, kernel max 34.80531 (= prev logical row max)
bad row 14: logical pre-row max 23.21280, kernel max 35.03107 (= prev logical row max)
bad row 23: logical pre-row max 21.14634, kernel max 34.78697 (= prev logical row max)

Traceback

ValueError: Out of range float values are not JSON compliant: nan
  File "vllm/entrypoints/openai/serving_completion.py", line ..., in ...
    response = CompletionResponse(... logprobs=logprobs ...)
  File "pydantic_core/_pydantic_core.pyx", line ..., in ...
  ...

Pre-/post-filter capture for one bad row (production, fp32 logits):

{
  "pre_shape": [32, 100278],
  "pre_bad_finite": [100278],
  "pre_bad_nan": [false],
  "post_all_neginf": [true],
  "post_neginf_count": [100278],
  "p_bad": [0.949999988079071],
  "original_stride": [100288, 1],
  "original_is_contiguous": false
}

The pre-top-p row is fully finite (100278 finite values, no NaN). After top-p, all 100278 entries are -inf, and log_softmax over an all -inf row yields NaN.

Canaries

Config                                  "Out of range float" errors   BadRequest   OUTPUT_NAN
Default (Triton, non-contiguous input)  479                           478          many
Force PyTorch top-k/top-p               0                             0            0
logits.contiguous() before Triton       0                             0            0
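A cheap canary for the symptom in this table is to flag rows that leave the filter with no finite logit, since those are exactly the rows whose log_softmax turns into NaN. fully_masked_rows is a hypothetical pure-Python debugging helper, not part of vLLM.

```python
import math


def fully_masked_rows(rows):
    """Indices of rows with no finite logit left after top-k/top-p.

    log_softmax over such a row produces NaN, which would later surface
    as an "Out of range float" serialization failure.
    """
    return [i for i, row in enumerate(rows)
            if not any(math.isfinite(x) for x in row)]


NEG_INF = float("-inf")
post_filter = [
    [3.1, NEG_INF, NEG_INF],        # healthy: one token survives
    [NEG_INF, NEG_INF, NEG_INF],    # like the captured bad row: all -inf
]
print(fully_masked_rows(post_filter))   # [1]
```

On the filtered tensor itself, the equivalent PyTorch expression is ~torch.isfinite(logits).any(dim=-1), which yields a per-row boolean mask.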

Fix

Match what forward_cuda already does:

def apply_top_k_top_p_triton(logits, k, p):
    if not logits.is_contiguous():
        logits = logits.contiguous()
    ...
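One subtlety, since the filter writes in place: after logits.contiguous(), the kernel mutates the copy, not the caller's strided view. The sketch below assumes callers use the return value (as the reproducer does) and optionally copies the result back for callers that rely on the in-place update. filter_with_contiguous_guard and the toy keep_row_max filter are hypothetical; keep_row_max only mimics the kernel's row_id * vocab slicing on CPU and is not the Triton kernel.

```python
import torch


def filter_with_contiguous_guard(logits: torch.Tensor, k, p, filter_fn):
    """Run an in-place row filter that assumes stride(0) == vocab size.

    `filter_fn(logits, k, p)` is a hypothetical stand-in for the Triton
    filter: it is assumed to write in place and return the filtered tensor.
    """
    if logits.is_contiguous():
        return filter_fn(logits, k, p)
    # Dense copy: row r now starts at r * vocab, as the kernel assumes.
    out = filter_fn(logits.contiguous(), k, p)
    # Optional: propagate the in-place update back to the strided view for
    # callers that ignore the return value.
    logits.copy_(out)
    return out


def keep_row_max(logits: torch.Tensor, k, p):
    """Toy CPU filter (top_k=1-like) using row_id * vocab slicing.

    Unlike a raw Triton pointer, view(-1) raises on non-contiguous input
    instead of silently reading the wrong physical row.
    """
    batch, vocab = logits.shape
    flat = logits.view(-1)
    for row_id in range(batch):
        seg = flat[row_id * vocab:(row_id + 1) * vocab]
        seg.masked_fill_(seg < seg.max(), float("-inf"))
    return logits


backing = torch.full((4, 10), -1000.0)      # padded [batch, vocab + pad]
padded_view = backing[:, :8]                # non-contiguous [batch, vocab]
padded_view.copy_(torch.arange(32, dtype=torch.float32).reshape(4, 8))
out = filter_with_contiguous_guard(padded_view, None, None, keep_row_max)
print(torch.isfinite(out).sum(-1).tolist())   # [1, 1, 1, 1]
```

Calling keep_row_max directly on padded_view raises a RuntimeError from view(-1), which is the loud version of the silent misread the Triton kernel performs.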

Before submitting

  • I will set VLLM_LOGGING_LEVEL=DEBUG if asked to share more logs.
  • Repro is self-contained and requires no external data or weights.

