vllm - 💡(How to fix) Fix [Bug] V1 InputBatch condense can leak stale allowed_token_ids mask to recycled row

Official PRs (…)
ON THIS PAGE

Recommended Tools

×6

Utilities matched from this issue’s tags and category — try them while you read without losing context.

GitHub issue graph ai analysis

Paste a GitHub issue URL. We fetch that issue, discover linked issues from bodies/comments/timeline, collect linked pull requests, and produce a structured English report.

The report is written in English Markdown for sharing and archival.

Helpful · Quick feedback

Loading…

In current main (4bfa0f2b1458be320fa39c6fa54be5f83cef2444), V1 InputBatch.condense() can leave a stale allowed_token_ids_mask_cpu_tensor row behind after moving a constrained request down. A later unrestricted request can reuse that old row because add_request() only writes allowed_token_ids_mask_cpu_tensor[req_index] when sampling_params.allowed_token_ids is set.

When another active request still has allowed_token_ids, _make_sampling_metadata() copies the active prefix of the CPU mask to GPU and Sampler.apply_logits_processors() applies the stale whitelist to the unrestricted request.

Error Message

#!/usr/bin/env python3
"""Drive a mixed vLLM OpenAI server workload for allowed_token_ids row reuse.

The script labels requests with request_id values B, D, A, and C. vLLM's
OpenAI completion serving path turns those into engine ids like cmpl-B-0,
which match the MASKBUG trace lines emitted by the instrumentation patch.
"""

from __future__ import annotations

import argparse
import json
import queue
import threading
import time
from typing import Any

import requests
from transformers import AutoTokenizer

TOKEN_CANDIDATES = [ " zebra", " zucchini", " volcano", " skyscraper", " pineapple", " qqq", "!", ]

def wait_for_server(base_url: str, timeout_s: float) -> None:
    """Block until ``GET {base_url}/health`` answers with HTTP 200.

    The endpoint is polled about once per second, each probe with a
    2-second request timeout.

    Args:
        base_url: Server root, e.g. ``http://127.0.0.1:8000``.
        timeout_s: Maximum number of seconds to keep polling.

    Raises:
        RuntimeError: If no probe returned 200 before the deadline. The
            message carries the last failure seen (an exception or a
            non-200 status), or ``None`` if no probe was attempted.
    """
    # Monotonic clock: a wall-clock adjustment must not stretch or cut the wait.
    deadline = time.monotonic() + timeout_s
    last_error: Exception | None = None
    while time.monotonic() < deadline:
        try:
            response = requests.get(f"{base_url}/health", timeout=2)
            if response.status_code == 200:
                return
            # Remember non-200 answers so the final message explains the failure.
            last_error = RuntimeError(
                f"health check returned HTTP {response.status_code}"
            )
        except Exception as exc:  # noqa: BLE001
            # Deliberately broad: any failure while the server boots is retried.
            last_error = exc
        # Never sleep beyond the deadline.
        time.sleep(min(1.0, max(0.0, deadline - time.monotonic())))
    raise RuntimeError(
        f"server did not become healthy: {last_error}"
    ) from last_error

def choose_forced_token(model: str, requested: int | None) -> tuple[int, str]:
    """Pick the token id used as the single-entry whitelist.

    Args:
        model: Model name handed to ``AutoTokenizer.from_pretrained``.
        requested: Explicit token id to use, or ``None`` to search
            ``TOKEN_CANDIDATES``.

    Returns:
        ``(token_id, text)``. For an explicit id, ``text`` is its decoded
        form; otherwise it is the first candidate string that encodes to
        exactly one token.

    Raises:
        RuntimeError: If no candidate encodes to a single token.
    """
    tok = AutoTokenizer.from_pretrained(model)
    if requested is None:
        for candidate in TOKEN_CANDIDATES:
            encoded = tok.encode(candidate, add_special_tokens=False)
            if len(encoded) != 1:
                continue
            return encoded[0], candidate
        raise RuntimeError("could not find a one-token distinctive candidate")
    return requested, tok.decode([requested])

def completion_payload(
    *,
    model: str,
    request_id: str,
    prompt: str,
    max_tokens: int,
    allowed_token_ids: list[int] | None = None,
    stream: bool = True,
    logprobs: int | None = None,
) -> dict[str, Any]:
    """Build the JSON body for a ``/v1/completions`` request.

    Every payload uses ``temperature=0`` and ``ignore_eos=True``.
    ``allowed_token_ids`` is included only when it is not ``None``.
    When ``logprobs`` is not ``None`` it is included together with
    ``return_tokens_as_token_ids=True``.
    """
    # Optional keys are collected first and appended after the fixed ones,
    # which keeps the key order of the resulting dict stable.
    extras: dict[str, Any] = {}
    if allowed_token_ids is not None:
        extras["allowed_token_ids"] = allowed_token_ids
    if logprobs is not None:
        extras["logprobs"] = logprobs
        extras["return_tokens_as_token_ids"] = True
    return {
        "model": model,
        "request_id": request_id,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0,
        "ignore_eos": True,
        "stream": stream,
        **extras,
    }

def stream_completion(
    base_url: str,
    label: str,
    payload: dict[str, Any],
    done: queue.Queue[tuple[str, str]],
    stop_event: threading.Event,
) -> None:
    """POST a streaming completion and report the collected text on ``done``.

    Server-sent ``data: `` lines are parsed as JSON and the ``text`` of the
    first choice is accumulated until ``[DONE]`` arrives or the stream ends.
    Requests labelled ``A`` or ``D`` also stop reading once ``stop_event``
    is set. Exactly one ``(label, text)`` tuple is put on ``done``; on any
    exception the text is ``"ERROR: <repr of exception>"``.
    """
    marker = "data: "
    # Only the long-running requests may be cut short by the stop event.
    interruptible = label in {"A", "D"}
    pieces: list[str] = []
    try:
        with requests.post(
            f"{base_url}/v1/completions",
            json=payload,
            stream=True,
            timeout=(5, 300),
        ) as resp:
            resp.raise_for_status()
            for raw in resp.iter_lines():
                if stop_event.is_set() and interruptible:
                    break
                if not raw:
                    continue
                decoded = raw.decode("utf-8")
                if not decoded.startswith(marker):
                    continue
                event = decoded[len(marker):]
                if event == "[DONE]":
                    break
                first_choice = json.loads(event)["choices"][0]
                pieces.append(first_choice.get("text") or "")
        done.put((label, "".join(pieces)))
    except Exception as exc:  # noqa: BLE001
        done.put((label, f"ERROR: {exc!r}"))

def main() -> int:
    """Run the B/D/A-then-C workload against a live server and print JSON events.

    Steps:
      1. Wait for the server's health endpoint and pick the whitelist token.
      2. Start streaming requests B (short, unrestricted), D (long,
         unrestricted) and A (long, restricted to the forced token) on
         daemon threads, 20 ms apart.
      3. Wait up to 300 s for B's stream to end, then pause for
         ``--post-b-delay-s`` seconds.
      4. Send unrestricted, non-streaming request C with ``logprobs=5`` and
         print its text and logprobs.

    Returns:
        0 once C's result has been printed.

    Raises:
        RuntimeError: If the server never becomes healthy, no one-token
            candidate is found, or B does not finish within 300 seconds.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", default="http://127.0.0.1:8000")
    parser.add_argument("--model", default="facebook/opt-125m")
    parser.add_argument("--forced-token-id", type=int)
    parser.add_argument("--server-timeout-s", type=float, default=300)
    parser.add_argument("--long-tokens", type=int, default=1024)
    parser.add_argument("--short-tokens", type=int, default=32)
    parser.add_argument("--post-b-delay-s", type=float, default=0.25)
    args = parser.parse_args()

    wait_for_server(args.base_url, args.server_timeout_s)
    forced_id, forced_text = choose_forced_token(args.model, args.forced_token_id)
    print(
        json.dumps(
            {
                "event": "forced_token",
                "forced_token_id": forced_id,
                "forced_token_text": forced_text,
            },
            sort_keys=True,
        ),
        flush=True,
    )

    done: queue.Queue[tuple[str, str]] = queue.Queue()
    stop_event = threading.Event()

    # Start order matters for the repro: B, then D, then the constrained A.
    requests_to_start = [
        (
            "B",
            completion_payload(
                model=args.model,
                request_id="B",
                prompt="Short request B: count upward. 1,",
                max_tokens=args.short_tokens,
                stream=True,
            ),
        ),
        (
            "D",
            completion_payload(
                model=args.model,
                request_id="D",
                prompt="Long unrestricted filler D: repeat a neutral sentence.",
                max_tokens=args.long_tokens,
                stream=True,
            ),
        ),
        (
            "A",
            completion_payload(
                model=args.model,
                request_id="A",
                prompt="Long constrained request A:",
                max_tokens=args.long_tokens,
                allowed_token_ids=[forced_id],
                stream=True,
            ),
        ),
    ]

    for label, payload in requests_to_start:
        thread = threading.Thread(
            target=stream_completion,
            args=(args.base_url, label, payload, done, stop_event),
            daemon=True,
        )
        thread.start()
        print(json.dumps({"event": "started", "label": label}, sort_keys=True), flush=True)
        time.sleep(0.02)

    b_text = None
    # Monotonic deadline so wall-clock adjustments cannot distort the wait.
    deadline = time.monotonic() + 300
    while (remaining := deadline - time.monotonic()) > 0:
        try:
            label, text = done.get(timeout=remaining)
        except queue.Empty:
            # Nothing finished in time; report it via the RuntimeError below
            # instead of leaking queue.Empty to the caller.
            break
        print(
            json.dumps(
                {
                    "event": "stream_done",
                    "label": label,
                    "text_prefix": text[:120],
                    "text_len": len(text),
                },
                sort_keys=True,
            ),
            flush=True,
        )
        # A stream that ended with an "ERROR: ..." text also counts as done.
        if label == "B":
            b_text = text
            break
    if b_text is None:
        raise RuntimeError("B did not finish before timeout")

    # Pause after B ends before submitting C (presumably so the engine has
    # released B's batch row first -- verify against the server trace).
    time.sleep(args.post_b_delay_s)

    c_payload = completion_payload(
        model=args.model,
        request_id="C",
        prompt="Unrestricted request C: The capital of France is",
        max_tokens=1,
        stream=False,
        logprobs=5,
    )
    c_response = requests.post(
        f"{args.base_url}/v1/completions",
        json=c_payload,
        timeout=(5, 120),
    )
    c_response.raise_for_status()
    c_json = c_response.json()
    c_choice = c_json["choices"][0]
    # C has been answered; let the A and D reader threads stop.
    stop_event.set()
    print(
        json.dumps(
            {
                "event": "C_result",
                "text": c_choice.get("text"),
                "logprobs": c_choice.get("logprobs"),
                "forced_token_id": forced_id,
                "forced_token_text": forced_text,
                "full_response_id": c_json.get("id"),
            },
            sort_keys=True,
        ),
        flush=True,
    )
    return 0

if __name__ == "__main__":
    raise SystemExit(main())

Root Cause

In current main (4bfa0f2b1458be320fa39c6fa54be5f83cef2444), V1 InputBatch.condense() can leave a stale allowed_token_ids_mask_cpu_tensor row behind after moving a constrained request down. A later unrestricted request can reuse that old row because add_request() only writes allowed_token_ids_mask_cpu_tensor[req_index] when sampling_params.allowed_token_ids is set.

Fix Action

Fix / Workaround

Patched result for the same command with --expect fixed:

Patched behavior for the same workload:

Patched row-transition trace excerpt:

Code Example

python repro_allowed_token_ids_mask_state.py --device cuda:0 --expect stale

---

{
  "active_req_id_to_index": {"req1": 1, "req2": 0, "req3": 2},
  "req3_index": 2,
  "req3_metadata_row": {"false_count": 1, "false_ids": [13], "true_count": 15},
  "stale_mask_reachable": true
}

---

{
  "active_req_id_to_index": {"req1": 1, "req2": 0, "req3": 2},
  "req3_index": 2,
  "req3_metadata_row": {"false_count": 16, "false_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], "true_count": 0},
  "stale_mask_reachable": false
}

---

#!/usr/bin/env python3
"""Focused repro for stale allowed_token_ids rows in InputBatch.

Run from a vLLM checkout after installing vLLM:

    python repro_allowed_token_ids_mask_state.py --device cuda:0 --expect stale
    python repro_allowed_token_ids_mask_state.py --device cuda:0 --expect fixed
"""

from __future__ import annotations

import argparse
import hashlib
import json
import sys
from typing import Any

import torch

from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch


VOCAB_SIZE = 16
ALLOWED_TOKEN_ID = 13


def row_digest(row: torch.Tensor) -> dict[str, Any]:
    """Summarize one mask row as counts, False positions and a short hash.

    The row is detached, moved to CPU and cast to ``bool`` first. The result
    holds the number of True and False entries, the indices of the False
    entries, and the first 12 hex digits of the SHA-1 of the row's bytes.
    """
    mask = row.detach().cpu().to(torch.bool)
    inverted = ~mask
    fingerprint = hashlib.sha1(mask.numpy().tobytes()).hexdigest()
    return {
        "true_count": int(mask.sum().item()),
        "false_count": int(inverted.sum().item()),
        "false_ids": torch.nonzero(inverted, as_tuple=False).flatten().tolist(),
        "sha1": fingerprint[:12],
    }


def make_request(req_id: str, allowed_token_ids: list[int] | None) -> CachedRequestState:
    """Create a minimal one-token greedy request for the InputBatch repro.

    Args:
        req_id: Identifier stored on the request state.
        allowed_token_ids: Token whitelist passed to ``SamplingParams``, or
            ``None`` for an unrestricted request.
    """
    params = SamplingParams(
        temperature=0.0,
        max_tokens=1,
        allowed_token_ids=allowed_token_ids,
    )
    state_fields = {
        "req_id": req_id,
        "prompt_token_ids": [1],
        "mm_features": [],
        "sampling_params": params,
        "pooling_params": None,
        "block_ids": ([],),
        "generator": None,
        "num_computed_tokens": 0,
        "output_token_ids": [],
    }
    return CachedRequestState(**state_fields)


def main() -> int:
    """Replay the add/remove/condense/add sequence and inspect req3's mask row.

    Returns:
        0 when the observed behavior matches ``--expect`` (or ``--expect`` is
        ``either``), 1 when it does not, and 2 when the sampling metadata
        carries no ``allowed_token_ids_mask`` at all.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device",
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Device for InputBatch tensors.",
    )
    parser.add_argument(
        "--expect",
        choices=("stale", "fixed", "either"),
        default="either",
        help="Expected behavior for req3's unrestricted mask row.",
    )
    args = parser.parse_args()

    device = torch.device(args.device)
    # Deliberately tiny batch: 4 rows and a 16-token vocabulary keep every
    # mask row small enough to dump in full.
    batch = InputBatch(
        max_num_reqs=4,
        max_model_len=8,
        max_num_batched_tokens=16,
        device=device,
        pin_memory=False,
        vocab_size=VOCAB_SIZE,
        block_sizes=[1],
        kernel_block_sizes=[1],
    )

    # Snapshots of the batch state, one per step, included in the final JSON.
    transitions: list[dict[str, Any]] = []

    def record(event: str) -> None:
        """Append a snapshot of request bookkeeping and all CPU mask rows."""
        cpu_mask = batch.allowed_token_ids_mask_cpu_tensor
        rows = {}
        # The CPU mask may be None (e.g. before any restricted request is added).
        if cpu_mask is not None:
            for i in range(batch.max_num_reqs):
                rows[str(i)] = row_digest(cpu_mask[i])
        transitions.append(
            {
                "event": event,
                "req_ids": list(batch._req_ids),
                "req_id_to_index": dict(batch.req_id_to_index),
                "has_allowed_token_ids": sorted(batch.has_allowed_token_ids),
                "rows": rows,
            }
        )

    # Two unrestricted requests followed by one restricted to a single token.
    for req_id, allowed in (
        ("req0", None),
        ("req1", None),
        ("req2", [ALLOWED_TOKEN_ID]),
    ):
        req_index = batch.add_request(make_request(req_id, allowed))
        record(f"add {req_id} -> row {req_index}")

    # Removing req0 leaves a hole that condense() fills with the last request.
    removed_index = batch.remove_request("req0")
    record(f"remove req0 from row {removed_index}")

    batch.condense()
    record("condense")

    # The new unrestricted request is the one that may inherit a stale row.
    req3_index = batch.add_request(make_request("req3", None))
    record(f"add req3 -> row {req3_index}")

    metadata = batch._make_sampling_metadata()
    if metadata.allowed_token_ids_mask is None:
        print("allowed_token_ids_mask unexpectedly absent", file=sys.stderr)
        return 2
    if device.type == "cuda":
        # Wait for outstanding GPU work before reading the mask back.
        torch.cuda.synchronize(device)

    metadata_mask = metadata.allowed_token_ids_mask.detach().cpu()
    req3_mask = metadata_mask[req3_index]
    req3_digest = row_digest(req3_mask)
    # Any True entry in req3's row means the unrestricted request is masked.
    stale = bool(req3_mask.any().item())

    result = {
        "device": str(device),
        "allowed_token_id": ALLOWED_TOKEN_ID,
        "req3_index": req3_index,
        "req3_metadata_row": req3_digest,
        "stale_mask_reachable": stale,
        "active_req_id_to_index": dict(batch.req_id_to_index),
        "transitions": transitions,
    }
    print(json.dumps(result, indent=2, sort_keys=True))

    if args.expect == "stale" and not stale:
        print("expected stale mask, but req3 row was unrestricted", file=sys.stderr)
        return 1
    if args.expect == "fixed" and stale:
        print("expected unrestricted req3 row, but stale mask is present", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

---

VLLM_MASKBUG_TRACE=1 CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server   --model facebook/opt-125m   --host 127.0.0.1   --port 8000   --dtype float16   --max-num-seqs 8   --max-model-len 2048   --gpu-memory-utilization 0.25   --enforce-eager

---

python repro_server_allowed_token_ids_mask.py   --base-url http://127.0.0.1:8000   --model facebook/opt-125m   --long-tokens 1024   --short-tokens 32

---

{
  "event": "C_result",
  "text": " volcano",
  "logprobs": {
    "tokens": ["token_id:17321"],
    "token_logprobs": [-12.824044227600098],
    "top_logprobs": [{
      "token_id:5": -2.6677942276000977,
      "token_id:10": -3.1717004776000977,
      "token_id:1470": -3.3006067276000977,
      "token_id:17321": -12.824044227600098
    }]
  }
}

---

MASKBUG add req_id=cmpl-B... req_index=0 has_allowed=[] has_allowed_param=False mask=None
MASKBUG add req_id=cmpl-D... req_index=1 has_allowed=[] has_allowed_param=False mask=None
MASKBUG add req_id=cmpl-A... req_index=2 has_allowed=['cmpl-A...'] has_allowed_param=True rows=... 2:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535
MASKBUG remove req_id=cmpl-B... req_index=0 has_allowed=['cmpl-A...'] rows=0:None:true=0:false=50272 ... 2:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535
MASKBUG condense_move req_id=cmpl-A... req_index=0 has_allowed=['cmpl-A...'] from=2 to=0 rows=0:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535 | 1:cmpl-D...:true=0:false=50272 | 2:None:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535
MASKBUG add req_id=cmpl-C... req_index=2 has_allowed=['cmpl-A...'] has_allowed_param=False rows=0:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535 | 1:cmpl-D...:true=0:false=50272 | 2:cmpl-C...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535
MASKBUG make_sampling_metadata ... num_reqs=3 rows=... 2:cmpl-C...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535

---

{
  "event": "C_result",
  "text": " the",
  "logprobs": {
    "tokens": ["token_id:5"],
    "token_logprobs": [-2.6677942276000977],
    "top_logprobs": [{
      "token_id:5": -2.6677942276000977,
      "token_id:10": -3.1717004776000977,
      "token_id:1470": -3.3006067276000977
    }]
  }
}

---

MASKBUG add req_id=cmpl-C... req_index=2 has_allowed=['cmpl-A...'] has_allowed_param=False rows=0:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535 | 1:cmpl-D...:true=0:false=50272 | 2:cmpl-C...:true=0:false=50272:false_ids=[0, 1, 2, 3, 4, 5, 6, 7]:true_idx_sum=0
MASKBUG make_sampling_metadata ... num_reqs=3 rows=... 2:cmpl-C...:true=0:false=50272:false_ids=[0, 1, 2, 3, 4, 5, 6, 7]:true_idx_sum=0

---

#!/usr/bin/env python3
"""Drive a mixed vLLM OpenAI server workload for allowed_token_ids row reuse.

The script labels requests with request_id values B, D, A, and C. vLLM's
OpenAI completion serving path turns those into engine ids like cmpl-B-0,
which match the MASKBUG trace lines emitted by the instrumentation patch.
"""

from __future__ import annotations

import argparse
import json
import queue
import threading
import time
from typing import Any

import requests
from transformers import AutoTokenizer


TOKEN_CANDIDATES = [
    " zebra",
    " zucchini",
    " volcano",
    " skyscraper",
    " pineapple",
    " qqq",
    "!",
]


def wait_for_server(base_url: str, timeout_s: float) -> None:
    """Block until ``GET {base_url}/health`` answers with HTTP 200.

    The endpoint is polled about once per second, each probe with a
    2-second request timeout.

    Args:
        base_url: Server root, e.g. ``http://127.0.0.1:8000``.
        timeout_s: Maximum number of seconds to keep polling.

    Raises:
        RuntimeError: If no probe returned 200 before the deadline. The
            message carries the last failure seen (an exception or a
            non-200 status), or ``None`` if no probe was attempted.
    """
    # Monotonic clock: a wall-clock adjustment must not stretch or cut the wait.
    deadline = time.monotonic() + timeout_s
    last_error: Exception | None = None
    while time.monotonic() < deadline:
        try:
            response = requests.get(f"{base_url}/health", timeout=2)
            if response.status_code == 200:
                return
            # Remember non-200 answers so the final message explains the failure.
            last_error = RuntimeError(
                f"health check returned HTTP {response.status_code}"
            )
        except Exception as exc:  # noqa: BLE001
            # Deliberately broad: any failure while the server boots is retried.
            last_error = exc
        # Never sleep beyond the deadline.
        time.sleep(min(1.0, max(0.0, deadline - time.monotonic())))
    raise RuntimeError(
        f"server did not become healthy: {last_error}"
    ) from last_error


def choose_forced_token(model: str, requested: int | None) -> tuple[int, str]:
    """Pick the token id used as the single-entry whitelist.

    Args:
        model: Model name handed to ``AutoTokenizer.from_pretrained``.
        requested: Explicit token id to use, or ``None`` to search
            ``TOKEN_CANDIDATES``.

    Returns:
        ``(token_id, text)``. For an explicit id, ``text`` is its decoded
        form; otherwise it is the first candidate string that encodes to
        exactly one token.

    Raises:
        RuntimeError: If no candidate encodes to a single token.
    """
    tok = AutoTokenizer.from_pretrained(model)
    if requested is None:
        for candidate in TOKEN_CANDIDATES:
            encoded = tok.encode(candidate, add_special_tokens=False)
            if len(encoded) != 1:
                continue
            return encoded[0], candidate
        raise RuntimeError("could not find a one-token distinctive candidate")
    return requested, tok.decode([requested])


def completion_payload(
    *,
    model: str,
    request_id: str,
    prompt: str,
    max_tokens: int,
    allowed_token_ids: list[int] | None = None,
    stream: bool = True,
    logprobs: int | None = None,
) -> dict[str, Any]:
    """Build the JSON body for a ``/v1/completions`` request.

    Every payload uses ``temperature=0`` and ``ignore_eos=True``.
    ``allowed_token_ids`` is included only when it is not ``None``.
    When ``logprobs`` is not ``None`` it is included together with
    ``return_tokens_as_token_ids=True``.
    """
    # Optional keys are collected first and appended after the fixed ones,
    # which keeps the key order of the resulting dict stable.
    extras: dict[str, Any] = {}
    if allowed_token_ids is not None:
        extras["allowed_token_ids"] = allowed_token_ids
    if logprobs is not None:
        extras["logprobs"] = logprobs
        extras["return_tokens_as_token_ids"] = True
    return {
        "model": model,
        "request_id": request_id,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0,
        "ignore_eos": True,
        "stream": stream,
        **extras,
    }


def stream_completion(
    base_url: str,
    label: str,
    payload: dict[str, Any],
    done: queue.Queue[tuple[str, str]],
    stop_event: threading.Event,
) -> None:
    """POST a streaming completion and report the collected text on ``done``.

    Server-sent ``data: `` lines are parsed as JSON and the ``text`` of the
    first choice is accumulated until ``[DONE]`` arrives or the stream ends.
    Requests labelled ``A`` or ``D`` also stop reading once ``stop_event``
    is set. Exactly one ``(label, text)`` tuple is put on ``done``; on any
    exception the text is ``"ERROR: <repr of exception>"``.
    """
    marker = "data: "
    # Only the long-running requests may be cut short by the stop event.
    interruptible = label in {"A", "D"}
    pieces: list[str] = []
    try:
        with requests.post(
            f"{base_url}/v1/completions",
            json=payload,
            stream=True,
            timeout=(5, 300),
        ) as resp:
            resp.raise_for_status()
            for raw in resp.iter_lines():
                if stop_event.is_set() and interruptible:
                    break
                if not raw:
                    continue
                decoded = raw.decode("utf-8")
                if not decoded.startswith(marker):
                    continue
                event = decoded[len(marker):]
                if event == "[DONE]":
                    break
                first_choice = json.loads(event)["choices"][0]
                pieces.append(first_choice.get("text") or "")
        done.put((label, "".join(pieces)))
    except Exception as exc:  # noqa: BLE001
        done.put((label, f"ERROR: {exc!r}"))


def main() -> int:
    """Run the B/D/A-then-C workload against a live server and print JSON events.

    Steps:
      1. Wait for the server's health endpoint and pick the whitelist token.
      2. Start streaming requests B (short, unrestricted), D (long,
         unrestricted) and A (long, restricted to the forced token) on
         daemon threads, 20 ms apart.
      3. Wait up to 300 s for B's stream to end, then pause for
         ``--post-b-delay-s`` seconds.
      4. Send unrestricted, non-streaming request C with ``logprobs=5`` and
         print its text and logprobs.

    Returns:
        0 once C's result has been printed.

    Raises:
        RuntimeError: If the server never becomes healthy, no one-token
            candidate is found, or B does not finish within 300 seconds.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", default="http://127.0.0.1:8000")
    parser.add_argument("--model", default="facebook/opt-125m")
    parser.add_argument("--forced-token-id", type=int)
    parser.add_argument("--server-timeout-s", type=float, default=300)
    parser.add_argument("--long-tokens", type=int, default=1024)
    parser.add_argument("--short-tokens", type=int, default=32)
    parser.add_argument("--post-b-delay-s", type=float, default=0.25)
    args = parser.parse_args()

    wait_for_server(args.base_url, args.server_timeout_s)
    forced_id, forced_text = choose_forced_token(args.model, args.forced_token_id)
    print(
        json.dumps(
            {
                "event": "forced_token",
                "forced_token_id": forced_id,
                "forced_token_text": forced_text,
            },
            sort_keys=True,
        ),
        flush=True,
    )

    done: queue.Queue[tuple[str, str]] = queue.Queue()
    stop_event = threading.Event()

    # Start order matters for the repro: B, then D, then the constrained A.
    requests_to_start = [
        (
            "B",
            completion_payload(
                model=args.model,
                request_id="B",
                prompt="Short request B: count upward. 1,",
                max_tokens=args.short_tokens,
                stream=True,
            ),
        ),
        (
            "D",
            completion_payload(
                model=args.model,
                request_id="D",
                prompt="Long unrestricted filler D: repeat a neutral sentence.",
                max_tokens=args.long_tokens,
                stream=True,
            ),
        ),
        (
            "A",
            completion_payload(
                model=args.model,
                request_id="A",
                prompt="Long constrained request A:",
                max_tokens=args.long_tokens,
                allowed_token_ids=[forced_id],
                stream=True,
            ),
        ),
    ]

    for label, payload in requests_to_start:
        thread = threading.Thread(
            target=stream_completion,
            args=(args.base_url, label, payload, done, stop_event),
            daemon=True,
        )
        thread.start()
        print(json.dumps({"event": "started", "label": label}, sort_keys=True), flush=True)
        time.sleep(0.02)

    b_text = None
    # Monotonic deadline so wall-clock adjustments cannot distort the wait.
    deadline = time.monotonic() + 300
    while (remaining := deadline - time.monotonic()) > 0:
        try:
            label, text = done.get(timeout=remaining)
        except queue.Empty:
            # Nothing finished in time; report it via the RuntimeError below
            # instead of leaking queue.Empty to the caller.
            break
        print(
            json.dumps(
                {
                    "event": "stream_done",
                    "label": label,
                    "text_prefix": text[:120],
                    "text_len": len(text),
                },
                sort_keys=True,
            ),
            flush=True,
        )
        # A stream that ended with an "ERROR: ..." text also counts as done.
        if label == "B":
            b_text = text
            break
    if b_text is None:
        raise RuntimeError("B did not finish before timeout")

    # Pause after B ends before submitting C (presumably so the engine has
    # released B's batch row first -- verify against the server trace).
    time.sleep(args.post_b_delay_s)

    c_payload = completion_payload(
        model=args.model,
        request_id="C",
        prompt="Unrestricted request C: The capital of France is",
        max_tokens=1,
        stream=False,
        logprobs=5,
    )
    c_response = requests.post(
        f"{args.base_url}/v1/completions",
        json=c_payload,
        timeout=(5, 120),
    )
    c_response.raise_for_status()
    c_json = c_response.json()
    c_choice = c_json["choices"][0]
    # C has been answered; let the A and D reader threads stop.
    stop_event.set()
    print(
        json.dumps(
            {
                "event": "C_result",
                "text": c_choice.get("text"),
                "logprobs": c_choice.get("logprobs"),
                "forced_token_id": forced_id,
                "forced_token_text": forced_text,
                "full_response_id": c_json.get("id"),
            },
            sort_keys=True,
        ),
        flush=True,
    )
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

---

diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -446,6 +446,9 @@ class InputBatch:
                 self.allowed_token_ids_mask_cpu_tensor[req_index][
                     sampling_params.allowed_token_ids
                 ] = False
+            elif self.allowed_token_ids_mask_cpu_tensor is not None:
+                # False means unrestricted: do not fill any logits with -inf.
+                self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
 
             if sampling_params.bad_words_token_ids:
                 self.bad_words_token_ids[req_index] = (
@@ -794,8 +797,10 @@ class InputBatch:
             if self.allowed_token_ids_mask_cpu_tensor is not None:
                 self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                     self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                 )
+                # The source row is now empty and may be reused by a new request.
+                self.allowed_token_ids_mask_cpu_tensor[last_req_index].fill_(False)
 
             bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
             if bad_words_token_ids is not None:
RAW_BUFFER (click to expand / collapse)

Summary

In current main (4bfa0f2b1458be320fa39c6fa54be5f83cef2444), V1 InputBatch.condense() can leave a stale allowed_token_ids_mask_cpu_tensor row behind after moving a constrained request down. A later unrestricted request can reuse that old row because add_request() only writes allowed_token_ids_mask_cpu_tensor[req_index] when sampling_params.allowed_token_ids is set.

When another active request still has allowed_token_ids, _make_sampling_metadata() copies the active prefix of the CPU mask to GPU and Sampler.apply_logits_processors() applies the stale whitelist to the unrestricted request.

Duplicate search

I searched open issues and PRs in vllm-project/vllm for:

  • allowed_token_ids_mask condense
  • allowed_token_ids stale InputBatch
  • allowed_token_ids dynamic batching
  • allowed_token_ids recycled row

No open issue or PR hits were returned by GitHub search for those queries.

Environment

  • GCP VM: a2-highgpu-1g, 1x NVIDIA A100-SXM4-40GB
  • Zone: us-central1-a
  • Driver: 580.159.03
  • vLLM source: 4bfa0f2b1458be320fa39c6fa54be5f83cef2444
  • vLLM version: 0.1.dev1+g4bfa0f2b1
  • PyTorch: 2.11.0+cu130
  • CUDA visible: True, GPU name NVIDIA A100-SXM4-40GB
  • Install mode: editable source install with precompiled native artifacts from latest available main/nightly wheel. The code under test is the checked-out Python source at the SHA above.

Focused state-machine repro

Command:

python repro_allowed_token_ids_mask_state.py --device cuda:0 --expect stale

Current result:

{
  "active_req_id_to_index": {"req1": 1, "req2": 0, "req3": 2},
  "req3_index": 2,
  "req3_metadata_row": {"false_count": 1, "false_ids": [13], "true_count": 15},
  "stale_mask_reachable": true
}

Important transition from the focused run:

  • req2 starts at row 2 with allowed_token_ids=[13].
  • remove req0 clears row 0.
  • condense moves req2 from row 2 to row 0, but row 2 remains true_count=15, false_ids=[13].
  • add req3 unrestricted reuses row 2, and row 2 is still true_count=15, false_ids=[13].

Patched result for the same command with --expect fixed:

{
  "active_req_id_to_index": {"req1": 1, "req2": 0, "req3": 2},
  "req3_index": 2,
  "req3_metadata_row": {"false_count": 16, "false_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], "true_count": 0},
  "stale_mask_reachable": false
}
<details> <summary>Focused repro script</summary>
#!/usr/bin/env python3
"""Focused repro for stale allowed_token_ids rows in InputBatch.

Run from a vLLM checkout after installing vLLM:

    python repro_allowed_token_ids_mask_state.py --device cuda:0 --expect stale
    python repro_allowed_token_ids_mask_state.py --device cuda:0 --expect fixed
"""

from __future__ import annotations

import argparse
import hashlib
import json
import sys
from typing import Any

import torch

from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch


VOCAB_SIZE = 16
ALLOWED_TOKEN_ID = 13


def row_digest(row: torch.Tensor) -> dict[str, Any]:
    """Summarize one mask row as counts, False positions and a short hash.

    The row is detached, moved to CPU and cast to ``bool`` first. The result
    holds the number of True and False entries, the indices of the False
    entries, and the first 12 hex digits of the SHA-1 of the row's bytes.
    """
    mask = row.detach().cpu().to(torch.bool)
    inverted = ~mask
    fingerprint = hashlib.sha1(mask.numpy().tobytes()).hexdigest()
    return {
        "true_count": int(mask.sum().item()),
        "false_count": int(inverted.sum().item()),
        "false_ids": torch.nonzero(inverted, as_tuple=False).flatten().tolist(),
        "sha1": fingerprint[:12],
    }


def make_request(req_id: str, allowed_token_ids: list[int] | None) -> CachedRequestState:
    """Create a minimal one-token greedy request for the InputBatch repro.

    Args:
        req_id: Identifier stored on the request state.
        allowed_token_ids: Token whitelist passed to ``SamplingParams``, or
            ``None`` for an unrestricted request.
    """
    params = SamplingParams(
        temperature=0.0,
        max_tokens=1,
        allowed_token_ids=allowed_token_ids,
    )
    state_fields = {
        "req_id": req_id,
        "prompt_token_ids": [1],
        "mm_features": [],
        "sampling_params": params,
        "pooling_params": None,
        "block_ids": ([],),
        "generator": None,
        "num_computed_tokens": 0,
        "output_token_ids": [],
    }
    return CachedRequestState(**state_fields)


def main() -> int:
    """Replay the add/remove/condense/add sequence and inspect req3's mask row.

    Returns:
        0 when the observed behavior matches ``--expect`` (or ``--expect`` is
        ``either``), 1 when it does not, and 2 when the sampling metadata
        carries no ``allowed_token_ids_mask`` at all.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device",
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Device for InputBatch tensors.",
    )
    parser.add_argument(
        "--expect",
        choices=("stale", "fixed", "either"),
        default="either",
        help="Expected behavior for req3's unrestricted mask row.",
    )
    args = parser.parse_args()

    device = torch.device(args.device)
    # Deliberately tiny batch: 4 rows and a 16-token vocabulary keep every
    # mask row small enough to dump in full.
    batch = InputBatch(
        max_num_reqs=4,
        max_model_len=8,
        max_num_batched_tokens=16,
        device=device,
        pin_memory=False,
        vocab_size=VOCAB_SIZE,
        block_sizes=[1],
        kernel_block_sizes=[1],
    )

    # Snapshots of the batch state, one per step, included in the final JSON.
    transitions: list[dict[str, Any]] = []

    def record(event: str) -> None:
        """Append a snapshot of request bookkeeping and all CPU mask rows."""
        cpu_mask = batch.allowed_token_ids_mask_cpu_tensor
        rows = {}
        # The CPU mask may be None (e.g. before any restricted request is added).
        if cpu_mask is not None:
            for i in range(batch.max_num_reqs):
                rows[str(i)] = row_digest(cpu_mask[i])
        transitions.append(
            {
                "event": event,
                "req_ids": list(batch._req_ids),
                "req_id_to_index": dict(batch.req_id_to_index),
                "has_allowed_token_ids": sorted(batch.has_allowed_token_ids),
                "rows": rows,
            }
        )

    # Two unrestricted requests followed by one restricted to a single token.
    for req_id, allowed in (
        ("req0", None),
        ("req1", None),
        ("req2", [ALLOWED_TOKEN_ID]),
    ):
        req_index = batch.add_request(make_request(req_id, allowed))
        record(f"add {req_id} -> row {req_index}")

    # Removing req0 leaves a hole that condense() fills with the last request.
    removed_index = batch.remove_request("req0")
    record(f"remove req0 from row {removed_index}")

    batch.condense()
    record("condense")

    # The new unrestricted request is the one that may inherit a stale row.
    req3_index = batch.add_request(make_request("req3", None))
    record(f"add req3 -> row {req3_index}")

    metadata = batch._make_sampling_metadata()
    if metadata.allowed_token_ids_mask is None:
        print("allowed_token_ids_mask unexpectedly absent", file=sys.stderr)
        return 2
    if device.type == "cuda":
        # Wait for outstanding GPU work before reading the mask back.
        torch.cuda.synchronize(device)

    metadata_mask = metadata.allowed_token_ids_mask.detach().cpu()
    req3_mask = metadata_mask[req3_index]
    req3_digest = row_digest(req3_mask)
    # Any True entry in req3's row means the unrestricted request is masked.
    stale = bool(req3_mask.any().item())

    result = {
        "device": str(device),
        "allowed_token_id": ALLOWED_TOKEN_ID,
        "req3_index": req3_index,
        "req3_metadata_row": req3_digest,
        "stale_mask_reachable": stale,
        "active_req_id_to_index": dict(batch.req_id_to_index),
        "transitions": transitions,
    }
    print(json.dumps(result, indent=2, sort_keys=True))

    if args.expect == "stale" and not stale:
        print("expected stale mask, but req3 row was unrestricted", file=sys.stderr)
        return 1
    if args.expect == "fixed" and stale:
        print("expected unrestricted req3 row, but stale mask is present", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
</details>

End-to-end server repro

Server command:

VLLM_MASKBUG_TRACE=1 CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server   --model facebook/opt-125m   --host 127.0.0.1   --port 8000   --dtype float16   --max-num-seqs 8   --max-model-len 2048   --gpu-memory-utilization 0.25   --enforce-eager

Client command:

python repro_server_allowed_token_ids_mask.py   --base-url http://127.0.0.1:8000   --model facebook/opt-125m   --long-tokens 1024   --short-tokens 32

The driver chooses a one-token distinctive whitelist token. In this run it selected token_id:17321, text " volcano".

Current behavior:

{
  "event": "C_result",
  "text": " volcano",
  "logprobs": {
    "tokens": ["token_id:17321"],
    "token_logprobs": [-12.824044227600098],
    "top_logprobs": [{
      "token_id:5": -2.6677942276000977,
      "token_id:10": -3.1717004776000977,
      "token_id:1470": -3.3006067276000977,
      "token_id:17321": -12.824044227600098
    }]
  }
}

Even though token 17321 is much worse than token 5 in the raw top logprobs, unrestricted request C sampled token 17321 because its row inherited A's stale whitelist.

Current row-transition trace excerpts:

MASKBUG add req_id=cmpl-B... req_index=0 has_allowed=[] has_allowed_param=False mask=None
MASKBUG add req_id=cmpl-D... req_index=1 has_allowed=[] has_allowed_param=False mask=None
MASKBUG add req_id=cmpl-A... req_index=2 has_allowed=['cmpl-A...'] has_allowed_param=True rows=... 2:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535
MASKBUG remove req_id=cmpl-B... req_index=0 has_allowed=['cmpl-A...'] rows=0:None:true=0:false=50272 ... 2:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535
MASKBUG condense_move req_id=cmpl-A... req_index=0 has_allowed=['cmpl-A...'] from=2 to=0 rows=0:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535 | 1:cmpl-D...:true=0:false=50272 | 2:None:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535
MASKBUG add req_id=cmpl-C... req_index=2 has_allowed=['cmpl-A...'] has_allowed_param=False rows=0:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535 | 1:cmpl-D...:true=0:false=50272 | 2:cmpl-C...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535
MASKBUG make_sampling_metadata ... num_reqs=3 rows=... 2:cmpl-C...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535

Patched behavior for the same workload:

{
  "event": "C_result",
  "text": " the",
  "logprobs": {
    "tokens": ["token_id:5"],
    "token_logprobs": [-2.6677942276000977],
    "top_logprobs": [{
      "token_id:5": -2.6677942276000977,
      "token_id:10": -3.1717004776000977,
      "token_id:1470": -3.3006067276000977
    }]
  }
}

Patched row-transition trace excerpt:

MASKBUG add req_id=cmpl-C... req_index=2 has_allowed=['cmpl-A...'] has_allowed_param=False rows=0:cmpl-A...:true=50271:false=1:false_ids=[17321]:true_idx_sum=1263594535 | 1:cmpl-D...:true=0:false=50272 | 2:cmpl-C...:true=0:false=50272:false_ids=[0, 1, 2, 3, 4, 5, 6, 7]:true_idx_sum=0
MASKBUG make_sampling_metadata ... num_reqs=3 rows=... 2:cmpl-C...:true=0:false=50272:false_ids=[0, 1, 2, 3, 4, 5, 6, 7]:true_idx_sum=0

Note: in the patched instrumentation, the condense_move trace line is emitted immediately before the added source-row clear, so the more relevant patched evidence is the subsequent add C and make_sampling_metadata rows.

<details> <summary>Server workload script</summary>
#!/usr/bin/env python3
"""Drive a mixed vLLM OpenAI server workload for allowed_token_ids row reuse.

The script labels requests with request_id values B, D, A, and C. vLLM's
OpenAI completion serving path turns those into engine ids like cmpl-B-0,
which match the MASKBUG trace lines emitted by the instrumentation patch.
"""

from __future__ import annotations

import argparse
import json
import queue
import threading
import time
from typing import Any

import requests
from transformers import AutoTokenizer


# Candidate texts for the single allowed ("forced") token. choose_forced_token()
# tries them in order and uses the first one that the model's tokenizer encodes
# as exactly one token (without special tokens).
TOKEN_CANDIDATES: list[str] = [
    " zebra",
    " zucchini",
    " volcano",
    " skyscraper",
    " pineapple",
    " qqq",
    "!",
]


def wait_for_server(base_url: str, timeout_s: float) -> None:
    """Block until ``GET {base_url}/health`` answers 200 or the timeout expires.

    The endpoint is polled with a 2-second request timeout and a 1-second
    pause between attempts.

    Args:
        base_url: Server root URL; ``/health`` is appended to it verbatim.
        timeout_s: Maximum number of seconds to keep polling.

    Raises:
        RuntimeError: If no 200 response was seen before the deadline. The
            message carries the last failure observed (an exception or a
            non-200 status), or ``None`` if no attempt was made at all.
    """
    # Monotonic clock: a wall-clock adjustment must not stretch or cut the wait.
    deadline = time.monotonic() + timeout_s
    last_error: Exception | str | None = None
    while time.monotonic() < deadline:
        try:
            response = requests.get(f"{base_url}/health", timeout=2)
            if response.status_code == 200:
                return
            # Remember unhealthy replies too, so the final message does not
            # read "None" when the server answered but was not ready.
            last_error = f"HTTP {response.status_code} from /health"
        except Exception as exc:  # noqa: BLE001
            # Deliberately broad: any failure here just means "not up yet".
            last_error = exc
        time.sleep(1)
    raise RuntimeError(f"server did not become healthy: {last_error}")


def choose_forced_token(model: str, requested: int | None) -> tuple[int, str]:
    """Pick the token id that the constrained request will be limited to.

    Args:
        model: Model name or path handed to ``AutoTokenizer.from_pretrained``.
        requested: Explicit token id to use, or ``None`` to search
            ``TOKEN_CANDIDATES``.

    Returns:
        ``(token_id, text)``. For an explicit id, ``text`` is the tokenizer's
        decoding of that id; otherwise it is the matching candidate string.

    Raises:
        RuntimeError: If ``requested`` is ``None`` and no candidate encodes to
            exactly one token.
    """
    tok = AutoTokenizer.from_pretrained(model)

    if requested is None:
        for candidate in TOKEN_CANDIDATES:
            encoded = tok.encode(candidate, add_special_tokens=False)
            if len(encoded) != 1:
                # Multi-token (or empty) encodings cannot serve as one forced id.
                continue
            return encoded[0], candidate
        raise RuntimeError("could not find a one-token distinctive candidate")

    return requested, tok.decode([requested])


def completion_payload(
    *,
    model: str,
    request_id: str,
    prompt: str,
    max_tokens: int,
    allowed_token_ids: list[int] | None = None,
    stream: bool = True,
    logprobs: int | None = None,
) -> dict[str, Any]:
    payload: dict[str, Any] = {
        "model": model,
        "request_id": request_id,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0,
        "ignore_eos": True,
        "stream": stream,
    }
    if allowed_token_ids is not None:
        payload["allowed_token_ids"] = allowed_token_ids
    if logprobs is not None:
        payload["logprobs"] = logprobs
        payload["return_tokens_as_token_ids"] = True
    return payload


def stream_completion(
    base_url: str,
    label: str,
    payload: dict[str, Any],
    done: queue.Queue[tuple[str, str]],
    stop_event: threading.Event,
) -> None:
    """POST a streaming completion and report the collected text on ``done``.

    Exactly one ``(label, text)`` tuple is queued per call on the normal path:
    the concatenated streamed text on success, or ``"ERROR: <repr>"`` if
    anything raised.

    Args:
        base_url: Server root URL; ``/v1/completions`` is appended.
        label: Short request label echoed back in the result tuple.
        payload: JSON body for the request.
        done: Queue that receives the result tuple.
        stop_event: When set, streams labelled "A" or "D" stop reading at the
            next received line; other labels ignore it.
    """
    sse_prefix = "data: "
    # Only the long-running requests are cut short once the event is set.
    interruptible = label in {"A", "D"}
    pieces: list[str] = []
    try:
        with requests.post(
            f"{base_url}/v1/completions",
            json=payload,
            stream=True,
            timeout=(5, 300),
        ) as response:
            response.raise_for_status()
            for raw_line in response.iter_lines():
                if interruptible and stop_event.is_set():
                    break
                if not raw_line:
                    continue
                decoded = raw_line.decode("utf-8")
                if not decoded.startswith(sse_prefix):
                    continue
                body = decoded[len(sse_prefix):]
                if body == "[DONE]":
                    break
                first_choice = json.loads(body)["choices"][0]
                # A missing or null "text" contributes nothing to the result.
                pieces.append(first_choice.get("text") or "")
        done.put((label, "".join(pieces)))
    except Exception as exc:  # noqa: BLE001
        # Failures are reported through the queue so the consumer sees them.
        done.put((label, f"ERROR: {exc!r}"))


def main() -> int:
    """Run the B/D/A/C workload against a live server and print JSON events.

    Sequence:
      1. Parse CLI options and wait for the server's health endpoint.
      2. Choose the forced token and print a ``forced_token`` event.
      3. Start streaming requests B (short, unrestricted), D (long,
         unrestricted) and A (long, restricted to the forced token), each on
         its own daemon thread.
      4. Wait up to 300 seconds for B's stream to finish, then sleep for
         ``--post-b-delay-s``.
      5. Send the unrestricted, non-streaming request C (one token, with
         logprobs) and print a ``C_result`` event that also carries the forced
         token, so the two can be compared.

    Returns:
        0 on success.

    Raises:
        RuntimeError: If B does not finish within the 300-second window (also
            propagated from ``wait_for_server`` / ``choose_forced_token``).
        Exception: Whatever ``requests.post`` / ``raise_for_status`` raise for
            request C is propagated unchanged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", default="http://127.0.0.1:8000")
    parser.add_argument("--model", default="facebook/opt-125m")
    parser.add_argument("--forced-token-id", type=int)
    parser.add_argument("--server-timeout-s", type=float, default=300)
    parser.add_argument("--long-tokens", type=int, default=1024)
    parser.add_argument("--short-tokens", type=int, default=32)
    parser.add_argument("--post-b-delay-s", type=float, default=0.25)
    args = parser.parse_args()

    wait_for_server(args.base_url, args.server_timeout_s)
    forced_id, forced_text = choose_forced_token(args.model, args.forced_token_id)
    print(
        json.dumps(
            {
                "event": "forced_token",
                "forced_token_id": forced_id,
                "forced_token_text": forced_text,
            },
            sort_keys=True,
        ),
        flush=True,
    )

    done: queue.Queue[tuple[str, str]] = queue.Queue()
    stop_event = threading.Event()

    # Start order is B, D, A. NOTE(review): presumably chosen so the requests
    # land in consecutive batch rows on the server -- confirm against the trace.
    requests_to_start = [
        (
            "B",
            completion_payload(
                model=args.model,
                request_id="B",
                prompt="Short request B: count upward. 1,",
                max_tokens=args.short_tokens,
                stream=True,
            ),
        ),
        (
            "D",
            completion_payload(
                model=args.model,
                request_id="D",
                prompt="Long unrestricted filler D: repeat a neutral sentence.",
                max_tokens=args.long_tokens,
                stream=True,
            ),
        ),
        (
            "A",
            completion_payload(
                model=args.model,
                request_id="A",
                prompt="Long constrained request A:",
                max_tokens=args.long_tokens,
                allowed_token_ids=[forced_id],
                stream=True,
            ),
        ),
    ]

    for label, payload in requests_to_start:
        thread = threading.Thread(
            target=stream_completion,
            args=(args.base_url, label, payload, done, stop_event),
            daemon=True,
        )
        thread.start()
        print(json.dumps({"event": "started", "label": label}, sort_keys=True), flush=True)
        # Small stagger between thread starts.
        time.sleep(0.02)

    b_text = None
    # Monotonic deadline, and only ever wait for a strictly positive remainder:
    # Queue.get() rejects negative timeouts with ValueError.
    deadline = time.monotonic() + 300
    while (remaining := deadline - time.monotonic()) > 0:
        try:
            label, text = done.get(timeout=remaining)
        except queue.Empty:
            # Nothing finished in time; report it via the RuntimeError below
            # instead of leaking queue.Empty to the caller.
            break
        print(
            json.dumps(
                {
                    "event": "stream_done",
                    "label": label,
                    "text_prefix": text[:120],
                    "text_len": len(text),
                },
                sort_keys=True,
            ),
            flush=True,
        )
        if label == "B":
            b_text = text
            break
    if b_text is None:
        raise RuntimeError("B did not finish before timeout")

    # Pause after B completes and before C is submitted.
    time.sleep(args.post_b_delay_s)

    c_payload = completion_payload(
        model=args.model,
        request_id="C",
        prompt="Unrestricted request C: The capital of France is",
        max_tokens=1,
        stream=False,
        logprobs=5,
    )
    c_response = requests.post(
        f"{args.base_url}/v1/completions",
        json=c_payload,
        timeout=(5, 120),
    )
    c_response.raise_for_status()
    c_json = c_response.json()
    c_choice = c_json["choices"][0]
    # C has been answered; let the long-running A and D readers stop.
    stop_event.set()
    print(
        json.dumps(
            {
                "event": "C_result",
                "text": c_choice.get("text"),
                "logprobs": c_choice.get("logprobs"),
                "forced_token_id": forced_id,
                "forced_token_text": forced_text,
                "full_response_id": c_json.get("id"),
            },
            sort_keys=True,
        ),
        flush=True,
    )
    return 0


# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
</details>

Minimal patch

diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -446,6 +446,9 @@ class InputBatch:
                 self.allowed_token_ids_mask_cpu_tensor[req_index][
                     sampling_params.allowed_token_ids
                 ] = False
+            elif self.allowed_token_ids_mask_cpu_tensor is not None:
+                # False means unrestricted: do not fill any logits with -inf.
+                self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
 
             if sampling_params.bad_words_token_ids:
                 self.bad_words_token_ids[req_index] = (
@@ -794,8 +797,10 @@ class InputBatch:
             if self.allowed_token_ids_mask_cpu_tensor is not None:
                 self.allowed_token_ids_mask_cpu_tensor[empty_index] = (
                     self.allowed_token_ids_mask_cpu_tensor[last_req_index]
                 )
+                # The source row is now empty and may be reused by a new request.
+                self.allowed_token_ids_mask_cpu_tensor[last_req_index].fill_(False)
 
             bad_words_token_ids = self.bad_words_token_ids.pop(last_req_index, None)
             if bad_words_token_ids is not None:

Vote matrix · Quick signals

Works
Did the solution work? Tap to confirm.
Easy Fix
Was it a quick fix?
Time Saver
Did it save you time?
Blocking
Was it severely blocking?
Common Issue
Are others likely hitting this too?
Flaky / Intermittent
Is it intermittent?
Verified / Reproducible
Can you reproduce it reliably?
Loading…

Still need to ship something?

×6

Another batch ranked right after the header list — different links, same matching logic.

Back to top recommendations

TRENDING