vllm - ✅ (Solved) [Bug]: UBatch CUDA graph capture stores graph under first-two-microbatch token count when ubatch_size > 2 [1 pull request, 1 participant]

vllm-project/vllm#43145 (fetched 2026-05-20 03:39:41)

When ubatch_size > 2 and FULL CUDA graph capture is active, UBatchWrapper.__call__() uses the total token count across all ubatches as the graph cache lookup key, but _capture_ubatches() stores the captured graph under the sum of only ubatch_metadata[0] and ubatch_metadata[1]. This causes captured graphs to be stored under the wrong key for 3+ microbatches.

Error Message

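ValueError: Non-MoE models do not support external data parallel mode.

This error was raised while trying to reproduce the bug through vllm serve with a dense model under external DP. The key mismatch itself raises no error; it shows up as repeated recaptures and as replay of a graph captured for a different shape.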

Root Cause

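_capture_ubatches() computes the storage key as ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens, while __call__() looks the graph up under sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices). With three ubatches of 4 tokens each, the graph is stored under key 8 but looked up under key 12, so the same shape recaptures on every call because key 12 is never present.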

Fix Action

Fixed

PR fix notes

PR #43161: [Bugfix] Fix UBatchWrapper CUDA graph key to sum all ubatches, not just first two

Description (problem / solution / changelog)

When ubatch_size > 2, _capture_ubatches() stored the captured graph under ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens (first two ubatches only), but __call__ looked it up with sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices) (all ubatches). The key mismatch caused the cached graph to never be found, triggering unnecessary re-captures on every call.

Fixes #43145

Purpose

Fix CUDA graph cache miss in UBatchWrapper._capture_ubatches() when ubatch_size > 2 (#43145).

Test Plan

A repro script extracts the actual key formula from the installed source and evaluates it against the test cases listed under Test Result.

Test Result

Before fix (stock):

Key formula: num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens

[BUG] 3 ubatches × 4 tokens  (total=12)  stored_key=8   lookup_key=12  ✗ MISMATCH
[BUG] 3 ubatches mixed        (total=10)  stored_key=8   lookup_key=10  ✗ MISMATCH
[BUG] 4 ubatches × 2 tokens  (total=8)   stored_key=4   lookup_key=8   ✗ MISMATCH
[OK ] 2 ubatches              (total=16)  stored_key=16  lookup_key=16  ✓

After fix:

Key formula: num_tokens = sum(m.num_tokens for m in ubatch_metadata)

[OK ] 3 ubatches × 4 tokens  (total=12)  stored_key=12  lookup_key=12  ✓
[OK ] 3 ubatches mixed        (total=10)  stored_key=10  lookup_key=10  ✓
[OK ] 4 ubatches × 2 tokens  (total=8)   stored_key=8   lookup_key=8   ✓
[OK ] 2 ubatches              (total=16)  stored_key=16  lookup_key=16  ✓
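
The table above can be recomputed with a few lines of Python. This is a minimal sketch, not the PR's repro script: FakeUBatchMetadata is a hypothetical stand-in that carries only num_tokens, and the per-ubatch splits for the "mixed" and "2 ubatches" rows are assumptions, since the output above reports only their totals and keys.

```python
from dataclasses import dataclass


@dataclass
class FakeUBatchMetadata:
    """Hypothetical stand-in for vLLM's ubatch metadata; only num_tokens is modeled."""
    num_tokens: int


def stored_key_before(ubatch_metadata):
    # Pre-fix formula quoted above: only the first two ubatches contribute.
    return ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens


def stored_key_after(ubatch_metadata):
    # Post-fix formula quoted above: every ubatch contributes.
    return sum(m.num_tokens for m in ubatch_metadata)


def lookup_key(ubatch_metadata):
    # __call__ looks up by the total token count across all ubatches.
    return sum(m.num_tokens for m in ubatch_metadata)


CASES = {
    "3 ubatches x 4 tokens": [4, 4, 4],
    "3 ubatches mixed": [4, 4, 2],       # assumed split: first two sum to 8, total 10
    "4 ubatches x 2 tokens": [2, 2, 2, 2],
    "2 ubatches": [8, 8],                # assumed split: total 16
}


def evaluate(key_fn):
    """Return {case name: (stored_key, lookup_key)} for the given storage formula."""
    rows = {}
    for name, counts in CASES.items():
        metadata = [FakeUBatchMetadata(n) for n in counts]
        rows[name] = (key_fn(metadata), lookup_key(metadata))
    return rows


before = evaluate(stored_key_before)
after = evaluate(stored_key_after)

for label, rows in (("before", before), ("after", after)):
    for name, (stored, lookup) in rows.items():
        status = "OK" if stored == lookup else "MISMATCH"
        print(f"{label:6s} {name:24s} stored_key={stored:<3d} lookup_key={lookup:<3d} {status}")
```

Only the rows with more than two ubatches disagree under the old formula, which matches the before/after output listed above.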

Changed files

  • vllm/v1/worker/gpu_ubatch_wrapper.py (modified, +3/-3)


Summary

When ubatch_size > 2 and FULL CUDA graph capture is active, UBatchWrapper.__call__() uses the total token count across all ubatches as the graph cache lookup key, but _capture_ubatches() stores the captured graph under the sum of only ubatch_metadata[0] and ubatch_metadata[1]. This causes captured graphs to be stored under the wrong key for 3+ microbatches.

Environment

  • vLLM commit: 12421962955ac28b6f80a0307f554fad939174dd
  • vLLM version: 0.1.dev1+g124219629.d20260519
  • OS: Ubuntu 24.04.4 LTS x86_64
  • Python: 3.12.3
  • PyTorch: 2.11.0+cu130
  • CUDA used to build PyTorch: 13.0
  • NVIDIA driver: 580.126.20
  • GPU: NVIDIA A100-SXM4-40GB

vllm collect-env was run. The A100 was later split into MIG 3g.20gb instances only to run an isolated wrapper harness after the full server path was blocked.

Reproducer

I first tried to reproduce through vllm serve with current main plus logging instrumentation in vllm/v1/worker/gpu_ubatch_wrapper.py.

--ubatch-size is user-visible:

vllm serve --help=all | grep -i ubatch

Single-DP dense model attempt:

VLLM_LOGGING_LEVEL=DEBUG vllm serve facebook/opt-125m \
  --dtype half \
  --max-model-len 1024 \
  --max-num-seqs 64 \
  --max-num-batched-tokens 1024 \
  --ubatch-size 3 \
  --all2all-backend deepep_high_throughput \
  --compilation-config '{"cudagraph_mode":"FULL","cudagraph_capture_sizes":[8,12,16,24,32,48,64]}'

This wrapped the model but did not create ubatch slices with data_parallel_size=1:

ubatch_slices: None, ubatch_slices_padded: None

External DP with a dense model is rejected:

ValueError: Non-MoE models do not support external data parallel mode.

MoE + deepep_high_throughput disables CUDA graphs:

DeepEP: Disabling CUDA Graphs since DeepEP high-throughput kernels are optimized for prefill and are incompatible with CUDA Graphs.

MoE + deepep_low_latency preserved cudagraph_mode=FULL, but the precompiled source environment did not have DeepEP kernels:

AssertionError: DeepEP kernels not found.

To isolate the suspected wrapper behavior, I then used a small in-process harness that exercises the real UBatchWrapper.__call__() and _capture_ubatches() path with three ubatches, real CUDA streams, and real FULL torch.cuda.graph capture. The harness uses a toy CUDA callable and a fake minimal vLLM config with data_parallel_size=2/MoE metadata so the wrapper path is exercised without depending on DeepEP startup.

Instrumentation added:

# __call__
num_tokens = sum(ubatch_slice.num_tokens for ubatch_slice in ubatch_slices)
logger.warning(
    "UBATCH_DEBUG lookup lookup_key=%s ubatch_slice_tokens=%s cache_keys=%s cudagraph_runtime_mode=%s",
    num_tokens,
    [ubatch_slice.num_tokens for ubatch_slice in ubatch_slices],
    list(self.cudagraphs.keys()),
    cudagraph_runtime_mode,
)

# _capture_ubatches, behavior unchanged
capture_key_current = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens
capture_key_expected = sum(metadata.num_tokens for metadata in ubatch_metadata)
logger.warning(
    "UBATCH_DEBUG capture len_ubatches=%s ubatch_metadata_tokens=%s capture_key_current=%s capture_key_expected=%s mismatch=%s",
    len(ubatch_metadata),
    [metadata.num_tokens for metadata in ubatch_metadata],
    capture_key_current,
    capture_key_expected,
    capture_key_current != capture_key_expected,
)
num_tokens = capture_key_current

Observed Behavior

For a three-ubatch shape [4, 4, 4], lookup uses total 12, but capture stores under 8:

UBATCH_DEBUG lookup lookup_key=12 ubatch_slice_tokens=[4, 4, 4] cache_keys=[] cudagraph_runtime_mode=FULL
UBATCH_DEBUG capture len_ubatches=3 ubatch_metadata_tokens=[4, 4, 4] capture_key_current=8 capture_key_expected=12 mismatch=True
HARNESS_RESULT {'total_tokens': 12, 'token_counts': [4, 4, 4], 'output_shape': (12, 1), 'cache_keys': [8]}

The same real shape recaptures because key 12 is still absent:

UBATCH_DEBUG lookup lookup_key=12 ubatch_slice_tokens=[4, 4, 4] cache_keys=[8] cudagraph_runtime_mode=FULL
UBATCH_DEBUG capture len_ubatches=3 ubatch_metadata_tokens=[4, 4, 4] capture_key_current=8 capture_key_expected=12 mismatch=True
HARNESS_RESULT {'total_tokens': 12, 'token_counts': [4, 4, 4], 'output_shape': (12, 1), 'cache_keys': [8]}

Then a total-8 shape [3, 3, 2] finds key 8 and replays the graph captured for [4, 4, 4], returning the captured graph's 12-token output shape:

UBATCH_DEBUG lookup lookup_key=8 ubatch_slice_tokens=[3, 3, 2] cache_keys=[8] cudagraph_runtime_mode=FULL
HARNESS_RESULT {'total_tokens': 8, 'token_counts': [3, 3, 2], 'output_shape': (12, 1), 'cache_keys': [8]}
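
The three observations can be mimicked without a GPU using a plain dictionary. This is a toy model written for illustration, not vLLM's UBatchWrapper: ToyUBatchCache and its methods are hypothetical names, and the stored value is just the captured per-ubatch layout rather than a CUDA graph.

```python
class ToyUBatchCache:
    """Toy model of a graph cache keyed by token count (not vLLM code)."""

    def __init__(self, fixed):
        self.fixed = fixed    # True applies the corrected key formula
        self.cudagraphs = {}  # key -> per-ubatch layout recorded at capture time
        self.captures = 0

    def _capture_key(self, token_counts):
        if self.fixed:
            return sum(token_counts)
        # Buggy formula: only the first two ubatches contribute.
        return token_counts[0] + token_counts[1]

    def call(self, token_counts):
        lookup = sum(token_counts)  # lookup always uses the full total
        if lookup in self.cudagraphs:
            return "replay", self.cudagraphs[lookup]
        self.captures += 1
        self.cudagraphs[self._capture_key(token_counts)] = list(token_counts)
        return "capture", list(token_counts)


buggy = ToyUBatchCache(fixed=False)
first = buggy.call([4, 4, 4])   # miss on key 12, stored under key 8
second = buggy.call([4, 4, 4])  # key 12 still absent, so it captures again
third = buggy.call([3, 3, 2])   # total 8 hits the graph captured for [4, 4, 4]

fixed = ToyUBatchCache(fixed=True)
fixed_results = [fixed.call(c) for c in ([4, 4, 4], [4, 4, 4], [3, 3, 2])]

print(first, second, third, sorted(buggy.cudagraphs))
print(fixed_results, sorted(fixed.cudagraphs))
```

In the buggy variant the third call replays a layout of [4, 4, 4] for a request of [3, 3, 2], which corresponds to the 12-token output shape reported above for an 8-token input.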

Expected Behavior

The graph should be stored under the same key used by __call__() for lookup. For ubatch_size > 2, that should be:

sum(metadata.num_tokens for metadata in ubatch_metadata)

or a richer key if total token count alone is insufficient to distinguish valid CUDA graph replay shapes.

Suspected Fix

- num_tokens = ubatch_metadata[0].num_tokens + ubatch_metadata[1].num_tokens
+ num_tokens = sum(metadata.num_tokens for metadata in ubatch_metadata)

Also consider asserting that the stored graph metadata layout matches the replay layout if total token count alone is not a safe key.
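
One way to picture the two options (a layout assertion on replay, or a richer key) is sketched below. Everything here is hypothetical: CapturedEntry, layout_key, and check_replay_layout do not exist in vLLM, and whether the total token count is already a safe key depends on how vLLM pads ubatch slices, which this issue does not settle.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CapturedEntry:
    """Hypothetical cache entry that remembers the layout it was captured with."""
    layout: tuple[int, ...]


def layout_key(token_counts):
    # Richer key: the per-ubatch token counts rather than only their sum.
    return tuple(token_counts)


def check_replay_layout(entry, token_counts):
    # Guard for a total-keyed cache: refuse to replay a mismatched layout.
    if entry.layout != tuple(token_counts):
        raise ValueError(
            f"captured layout {entry.layout} does not match replay layout {tuple(token_counts)}"
        )


# Option 1: keep the total as the key, but verify the layout on a hit.
total_keyed = {12: CapturedEntry(layout=(4, 4, 4))}
request = [6, 3, 3]  # same total of 12, different per-ubatch split
hit = total_keyed[sum(request)]
try:
    check_replay_layout(hit, request)
    guard_tripped = False
except ValueError:
    guard_tripped = True

# Option 2: key by layout, so a different split is simply a cache miss.
layout_keyed = {layout_key([4, 4, 4]): CapturedEntry(layout=(4, 4, 4))}
layout_hit = layout_key(request) in layout_keyed

print(guard_tripped, layout_hit)
```

The guard keeps the cache small and turns a silent wrong replay into a loud failure, while the layout key trades extra captures for never matching a different split.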

Controls

Applying the one-line key fix to the same harness stores [4, 4, 4] under key 12; the second total-12 call replays instead of recapturing; total 8 captures a separate key and returns an 8-token output shape.

UBATCH_DEBUG lookup lookup_key=12 ubatch_slice_tokens=[4, 4, 4] cache_keys=[] cudagraph_runtime_mode=FULL
UBATCH_DEBUG capture len_ubatches=3 ubatch_metadata_tokens=[4, 4, 4] capture_key_current=12 capture_key_expected=12 mismatch=False
HARNESS_RESULT {'total_tokens': 12, 'token_counts': [4, 4, 4], 'output_shape': (12, 1), 'cache_keys': [12]}

UBATCH_DEBUG lookup lookup_key=12 ubatch_slice_tokens=[4, 4, 4] cache_keys=[12] cudagraph_runtime_mode=FULL
HARNESS_RESULT {'total_tokens': 12, 'token_counts': [4, 4, 4], 'output_shape': (12, 1), 'cache_keys': [12]}

UBATCH_DEBUG lookup lookup_key=8 ubatch_slice_tokens=[3, 3, 2] cache_keys=[12] cudagraph_runtime_mode=FULL
UBATCH_DEBUG capture len_ubatches=3 ubatch_metadata_tokens=[3, 3, 2] capture_key_current=8 capture_key_expected=8 mismatch=False
HARNESS_RESULT {'total_tokens': 8, 'token_counts': [3, 3, 2], 'output_shape': (8, 1), 'cache_keys': [8, 12]}

ubatch_size=2 is not affected by this specific arithmetic bug because the sum of the first two ubatches equals the sum of all ubatches.

Eager/no-FULL mode does not exercise this graph cache path.

Duplicate Search

Searched GitHub issues and PRs in vllm-project/vllm for:

  • ubatch_metadata num_tokens cudagraph ubatch_size
  • ubatch_size cudagraph num_tokens
  • ubatch_metadata[0] ubatch_metadata[1]
  • UBatchWrapper _capture_ubatches num_tokens

No direct duplicate was found.
