vllm - ✅(Solved) Fix [Bug]: Qwen3.5-35B-A3B compile cache miss 100% on subgraphs. [1 pull requests, 5 comments, 2 participants]

GitHub stats
vllm-project/vllm#37919, fetched 2026-04-08 01:22:40
Comments: 5 · Participants: 2 · Timeline events: 14 · Reactions: 0
Timeline (top): commented ×5, labeled ×2, mentioned ×2, subscribed ×2

PR fix notes

PR #37901: [compile] Add a fast serializer for aot save/load.

Description (problem / solution / changelog)

Summary:

In AOT compilation mode we used to serialize FX graphs with GraphPickler, which by design is slow for our use case because it tries to save all the metadata, including fake tensors and symbolic ints. For our purpose we don't need anything other than nodes and submodules, so we can get away without serializing the metadata by implementing our own serializer, which is lighter and faster. As an implementation detail we use JSON, but this can be changed to anything.
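As a rough illustration of the idea (not the code in vllm/compilation/graph_serialization.py), the sketch below serializes only node-level structure to JSON and drops the tracing metadata. The `Node` dataclass, its fields, and the `dump_graph` / `load_graph` helpers are hypothetical names chosen for this example.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class Node:
    """Hypothetical, stripped-down stand-in for an FX graph node."""
    op: str                                    # e.g. "placeholder", "call_function", "output"
    name: str                                  # unique node name within the graph
    target: str                                # qualified name of the callee
    args: list = field(default_factory=list)   # input node names or literals
    meta: dict = field(default_factory=dict)   # tracing metadata, deliberately not saved


def dump_graph(nodes):
    """Serialize structure only; each node's `meta` dict is dropped."""
    records = []
    for node in nodes:
        record = asdict(node)
        del record["meta"]
        records.append(record)
    return json.dumps(records)


def load_graph(payload):
    """Rebuild nodes from JSON; `meta` comes back empty."""
    return [Node(**record) for record in json.loads(payload)]


graph = [
    Node("placeholder", "x", "x", meta={"example_value": "fake tensor"}),
    Node("call_function", "to", "torch.Tensor.to", ["x", "torch.bfloat16"]),
    Node("output", "output", "output", ["to"]),
]
restored = load_graph(dump_graph(graph))
```

The trade-off in this sketch is that anything stored only in `meta` is lost on reload, which is acceptable only if, as the summary states, nodes and submodules are all that the warm start needs.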

Test Plan:

pytest tests/compile

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: zhxchen17 [email protected]

Purpose

Test Plan

Test Result

Before this change (w/ torch 2.12), torch.compile warm start time:

Model                   Cold Start (s)  Warm Start Avg (s)
----------------------  --------------  ------------------
DeepSeek-V3.2           63.92           6.06
GLM-4.7-FP8             61.43           7.55
gpt-oss-120b            14.63           1.95
Llama-3.3-70B-Instruct  26.01           3.90
Qwen3.5-35B-A3B         60.78           3.47

After the change (w/ torch 2.12), torch.compile warm start time:

Model                   Cold Start (s)  Warm Start Avg (s)
----------------------  --------------  ------------------
DeepSeek-V3.2           53.66           1.18 (5.14x)
GLM-4.7-FP8             49.80           1.52 (4.97x)
gpt-oss-120b            11.89           0.41 (4.76x)
Llama-3.3-70B-Instruct  19.62           0.76 (5.13x)
Qwen3.5-35B-A3B         57.86           2.40 (1.45x)
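The factors in parentheses appear to be the warm-start average before the change divided by the warm-start average after it. A short check under that assumption, using only the numbers from the two tables:

```python
# Warm-start averages in seconds, copied from the "before" and "after" tables.
warm_before = {
    "DeepSeek-V3.2": 6.06,
    "GLM-4.7-FP8": 7.55,
    "gpt-oss-120b": 1.95,
    "Llama-3.3-70B-Instruct": 3.90,
    "Qwen3.5-35B-A3B": 3.47,
}
warm_after = {
    "DeepSeek-V3.2": 1.18,
    "GLM-4.7-FP8": 1.52,
    "gpt-oss-120b": 0.41,
    "Llama-3.3-70B-Instruct": 0.76,
    "Qwen3.5-35B-A3B": 2.40,
}

# Speedup factor = warm start before / warm start after, rounded to 2 decimals.
speedup = {
    model: round(warm_before[model] / warm_after[model], 2)
    for model in warm_before
}

for model, factor in speedup.items():
    print(f"{model:<24}{factor:.2f}x")
```

The computed factors match the parenthesized values in the "after" table, including the much smaller 1.45x for Qwen3.5-35B-A3B.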

Note that Qwen3.5 has a regression on trunk, tracked in https://github.com/vllm-project/vllm/issues/37919. In theory its warm start should be under 1 second, but we will fix that separately.

Testing script here: https://github.com/zhxchen17/scripts/blob/main/vllm/compile_time_bench.py



Changed files

  • tests/compile/test_aot_compile.py (modified, +1/-1)
  • vllm/compilation/caching.py (modified, +28/-12)
  • vllm/compilation/graph_serialization.py (added, +230/-0)


Your current environment

vllm trunk + torch trunk

🐛 Describe the bug

repro

vllm serve Qwen/Qwen3.5-35B-A3B

On vLLM trunk, the cold compilation time of Qwen3.5-35B-A3B went from 25 s to 60 s. Looking deeper, I found that we are recompiling every subgraph instead of compiling only the 3 distinct subgraphs as in the usual case. Graph 1 looks like:

    mul_6 = mul_5 * add_7;  mul_5 = add_7 = None
    to_2 = mul_6.to(torch.bfloat16);  mul_6 = None
    empty_like = torch.empty_like(to_2)
    gdn_in_proj = torch.ops.vllm.gdn_in_proj(to_2, 12288, 64, 'language_model.model.layers.10.linear_attn');  to_2 = None
    getitem_4 = gdn_in_proj[0]
    getitem_5 = gdn_in_proj[1];  gdn_in_proj = None
    split = getitem_4.split([8192, 4096], dim = -1);  getitem_4 = None
    getitem_6 = split[0]
    getitem_7 = split[1];  split = None
    reshape_3 = getitem_7.reshape(s18, -1, 128);  getitem_7 = None
    sym_size_int_21 = torch.ops.aten.sym_size.int(reshape_3, 0)
    chunk = getitem_5.chunk(2, dim = -1);  getitem_5 = None
    getitem_8 = chunk[0]
    getitem_9 = chunk[1];  chunk = None
    contiguous = getitem_8.contiguous();  getitem_8 = None
    contiguous_1 = getitem_9.contiguous();  getitem_9 = None
    zeros = torch.zeros((s18, 32, 128), dtype = torch.bfloat16, device = device(type='cuda', index=0));  s18 = None
    return (getitem_6, contiguous, contiguous_1, zeros, reshape_3, sym_size_int_21, empty_like, add_5)

And Graph 2:

    mul_6 = mul_5 * add_7;  mul_5 = add_7 = None
    to_2 = mul_6.to(torch.bfloat16);  mul_6 = None
    empty_like = torch.empty_like(to_2)
    gdn_in_proj = torch.ops.vllm.gdn_in_proj(to_2, 12288, 64, 'language_model.model.layers.13.linear_attn');  to_2 = None
    getitem_4 = gdn_in_proj[0]
    getitem_5 = gdn_in_proj[1];  gdn_in_proj = None
    split = getitem_4.split([8192, 4096], dim = -1);  getitem_4 = None
    getitem_6 = split[0]
    getitem_7 = split[1];  split = None
    reshape_3 = getitem_7.reshape(s18, -1, 128);  getitem_7 = None
    sym_size_int_28 = torch.ops.aten.sym_size.int(reshape_3, 0)
    chunk = getitem_5.chunk(2, dim = -1);  getitem_5 = None
    getitem_8 = chunk[0]
    getitem_9 = chunk[1];  chunk = None
    contiguous = getitem_8.contiguous();  getitem_8 = None
    contiguous_1 = getitem_9.contiguous();  getitem_9 = None
    zeros = torch.zeros((s18, 32, 128), dtype = torch.bfloat16, device = device(type='cuda', index=0));  s18 = None
    return (getitem_6, contiguous, contiguous_1, zeros, reshape_3, sym_size_int_28, empty_like, add_5)

So it seems we have a baked-in string argument in calls like torch.ops.vllm.gdn_in_proj(to_2, 12288, 64, 'language_model.model.layers.13.linear_attn'), which makes each subgraph look different and results in a compiler cache miss.
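The mechanism can be shown with a toy model. The sketch below assumes a cache key that is simply a hash of the graph's printed code; this is a simplification for illustration, not how vLLM or Inductor actually computes its keys, and `toy_cache_key` and `mask_layer_name` are hypothetical helpers.

```python
import hashlib
import re

# The call line from the graphs above, with the layer name left as a slot.
CALL = "gdn_in_proj = torch.ops.vllm.gdn_in_proj(to_2, 12288, 64, '{name}')"


def toy_cache_key(graph_code: str) -> str:
    """Toy cache key: a hash of the graph's printed code."""
    return hashlib.sha256(graph_code.encode()).hexdigest()


def mask_layer_name(graph_code: str) -> str:
    """Replace the per-layer string literal with a fixed placeholder."""
    pattern = r"'language_model\.model\.layers\.\d+\.linear_attn'"
    return re.sub(pattern, "'<layer>'", graph_code)


layer_10 = CALL.format(name="language_model.model.layers.10.linear_attn")
layer_13 = CALL.format(name="language_model.model.layers.13.linear_attn")

# With the literal baked in, every layer gets its own key.
keys_raw = {toy_cache_key(layer_10), toy_cache_key(layer_13)}

# With the literal masked out, both layers share one key.
keys_masked = {
    toy_cache_key(mask_layer_name(layer_10)),
    toy_cache_key(mask_layer_name(layer_13)),
}

print(len(keys_raw), len(keys_masked))
```

The two graphs in the report also differ in one node name (sym_size_int_21 vs sym_size_int_28); whether that affects the real cache key is not addressed in the issue.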


extent analysis

Fix Plan

To stop every subgraph from being recompiled (instead of only the 3 distinct subgraphs), the per-layer string must not appear as a literal argument in the traced torch.ops.vllm.gdn_in_proj call, because that literal is what makes each subgraph look different.

Here are the steps:

  • Locate the torch.ops.vllm.gdn_in_proj call sites and the layer-name string passed at each one.
  • Store the layer-name strings in a dictionary or similar structure, keyed by an index or other identifier.
  • Change the call sites to pass the index or identifier instead of the string.
  • Resolve the identifier back to the string only where it is needed. If the identifier is itself traced as a per-layer constant, the subgraphs will likely still differ, so the lookup has to stay out of the traced graph.

Example code:

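import torch

# Hypothetical lookup table from layer index to layer name (not part of vLLM).
string_args = {
    10: 'language_model.model.layers.10.linear_attn',
    13: 'language_model.model.layers.13.linear_attn',
}

# Hypothetical wrapper that takes a layer index instead of the string.
def gdn_in_proj(to_2, num, dim, layer_idx):
    # Caveat: if torch.compile traces through this wrapper, the lookup likely
    # runs at trace time and the string (or the index) is still a per-layer
    # constant in the graph. The lookup must stay opaque to the compiler.
    string_arg = string_args[layer_idx]
    return torch.ops.vllm.gdn_in_proj(to_2, num, dim, string_arg)

# Usage (needs vLLM's custom ops registered; to_2 is the bfloat16 input tensor)
# gdn_in_proj(to_2, 12288, 64, 10)
# gdn_in_proj(to_2, 12288, 64, 13)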

Verification

To verify the fix, measure the cold compilation time of Qwen3.5-35B-A3B and check that it returns to roughly 25 seconds. Also confirm that only the 3 distinct subgraphs are compiled rather than one per layer, and inspect the traced graphs to check that the torch.ops.vllm.gdn_in_proj calls no longer differ by layer name.

Extra Tips

  • Make sure to update the string_args dictionary whenever new string arguments are added or modified.
  • Consider using a more robust data structure, such as a database or a configuration file, to store the string arguments and their corresponding indices or identifiers.
  • Use debugging tools and logging to monitor the performance of the torch.ops.vllm.gdn_in_proj calls and the compiler cache utilization.
