vllm - ✅(Solved) Fix [Bug]: VLLM Gemma4 output repeated token [1 pull requests, 1 comments, 1 participants]

GitHub stats
vllm-project/vllm#39827 (fetched 2026-04-16 06:36:22)
Comments: 1 · Participants: 1 · Timeline events: 8 · Reactions: 0
Timeline (top): closed ×2, commented ×1, cross-referenced ×1, labeled ×1

Fix Action

Fixed

PR fix notes

PR #39842: [Model] Fix Gemma 4 token repetition by dynamic BOS injection for PT models

Description (problem / solution / changelog)

Purpose

This PR fixes the token repetition issue (e.g., "is is is is...") observed in Gemma 4 Pre-Trained (PT) models when running in completion mode (without a chat template).

The issue was caused by Gemma4ProcessingInfo.get_default_tok_params explicitly forcing add_special_tokens=False to prevent double-BOS sequences in Instruction-Tuned (IT) models, whose chat template already includes a literal <bos>. However, this also left PT models loaded without a chat template without the required <bos> token at position 0, which degraded generation.

This fix dynamically checks for the presence of a chat_template on the tokenizer:

  • For IT models: Keeps add_special_tokens=False to avoid double-BOS artifacts.
  • For PT models: Allows the default add_special_tokens=True so that the <bos> token is injected for raw prompts.

Fixes https://github.com/vllm-project/vllm/issues/39827

The change does not affect other models or Gemma 4 IT models; it targets Gemma 4 PT checkpoints only.
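The selection rule described above can be illustrated with a toy model. The sketch below is a minimal, hypothetical reconstruction, not vLLM's code: ToyTokenizer stands in for a Hugging Face tokenizer (whitespace split over a three-entry vocabulary), default_tok_params is an illustrative helper mirroring the PR's described check, the chat-template string is made up, and BOS_ID = 2 is an assumption read off the fixed run's prompt_token_ids=[2, 50429, 563].

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

BOS_ID = 2  # assumed <bos> id, read off prompt_token_ids=[2, 50429, 563] in the fixed run
VOCAB = {"<bos>": BOS_ID, "Paris": 50429, "is": 563}  # toy vocabulary using IDs from the logs


@dataclass
class ToyTokenizer:
    """Hypothetical stand-in for a HF tokenizer: whitespace split, fixed vocab."""
    vocab: Dict[str, int]
    chat_template: Optional[str] = None

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        ids = [self.vocab[piece] for piece in text.split()]
        # The only "special token" this toy adds is a leading <bos>.
        return ([BOS_ID] + ids) if add_special_tokens else ids


def default_tok_params(tokenizer) -> Dict[str, bool]:
    """Hypothetical helper mirroring the PR's rule: suppress special tokens only
    when a chat template exists (IT models render a literal <bos> themselves);
    otherwise (PT models) keep add_special_tokens=True so <bos> lands at position 0."""
    has_template = getattr(tokenizer, "chat_template", None) is not None
    return {"add_special_tokens": not has_template}


pt = ToyTokenizer(VOCAB)  # PT checkpoint: no chat template
it = ToyTokenizer(VOCAB, chat_template="<bos>{{ prompt }}")  # IT: toy template with literal <bos>

print(pt.encode("Paris is", add_special_tokens=False))        # pre-fix PT path: [50429, 563]
print(pt.encode("Paris is", **default_tok_params(pt)))        # fixed PT path: [2, 50429, 563]
print(it.encode("<bos> Paris is", **default_tok_params(it)))  # IT path: [2, 50429, 563]
print(it.encode("<bos> Paris is", add_special_tokens=True))   # double BOS avoided: [2, 2, 50429, 563]
```

The pre-fix PT encoding reproduces the [50429, 563] seen in the failing log, while the last line shows the double-BOS sequence the original override was meant to prevent for IT models.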

Test Plan & Results

The test is straightforward. Run the following script:

from vllm import LLM, SamplingParams
VLLM_MODEL_ID = 'google/gemma-4-26B-A4B'
llm = LLM(
  VLLM_MODEL_ID,
  max_model_len=16,
  tensor_parallel_size=1,
  gpu_memory_utilization=0.65,
  async_scheduling=False,
  enforce_eager=True,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=10)
# Raw completion prompt; no chat template is applied
prompt = 'Paris is'
print('\n' + '='*80)
print('Generation test after weight transfer:')
print(llm.generate(prompt, sampling_params=sampling_params))

Without the fix, the output loops on a single token:

Rendering prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 22.40it/s]
Processed prompts: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.77it/s, est. speed input: 15.59 toks/s, output: 77.91 toks/s]
[RequestOutput(request_id=0, prompt='Paris is', prompt_token_ids=[50429, 563], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' is is is is is is is is is is', token_ids=[563, 563, 563, 563, 563, 563, 563, 563, 563, 563], routed_experts=None, cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0)]
(EngineCore pid=892199) INFO 04-14 21:00:27 [core.py:1210] Shutdown initiated (timeout=0)
(EngineCore pid=892199) INFO 04-14 21:00:27 [core.py:1233] Shutdown complete

With the fix, the prompt starts with token ID 2 (the injected <bos>) and the output no longer loops:

Rendering prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 22.47it/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.11it/s, est. speed input: 3.33 toks/s, output: 11.12 toks/s]
[RequestOutput(request_id=0, prompt='Paris is', prompt_token_ids=[2, 50429, 563], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' a city of romance, and there’s no', token_ids=[496, 3207, 529, 30875, 236764, 532, 993, 236858, 236751, 951], routed_experts=None, cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0)]
(EngineCore pid=970951) INFO 04-14 23:22:39 [core.py:1234] Shutdown initiated (timeout=0)
(EngineCore pid=970951) INFO 04-14 23:22:39 [core.py:1257] Shutdown complete
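The key differences between the two logged runs are the leading token ID 2 in prompt_token_ids and whether the output repeats token 563. The sketch below is a hypothetical heuristic (diagnose_generation is not part of vLLM) that checks both symptoms on plain token-ID lists. The default bos_id=2 is an assumption taken from these logs and min_run=4 is an arbitrary threshold; in real use, prefer the tokenizer's own bos_token_id.

```python
from typing import Optional, Sequence


def diagnose_generation(
    prompt_token_ids: Sequence[int],
    output_token_ids: Sequence[int],
    bos_id: Optional[int] = 2,  # assumed Gemma <bos> id, taken from the logs above
    min_run: int = 4,           # arbitrary threshold for calling a run a "loop"
) -> dict:
    """Flag a missing leading <bos> and long runs of one repeated output token."""
    missing_bos = bool(
        bos_id is not None
        and (not prompt_token_ids or prompt_token_ids[0] != bos_id)
    )
    # Track the longest run of identical consecutive output token IDs.
    longest, run, prev = 0, 0, None
    for tok in output_token_ids:
        run = run + 1 if tok == prev else 1
        prev = tok
        longest = max(longest, run)
    return {"missing_bos": missing_bos, "longest_repeat": longest, "looping": longest >= min_run}
```

To apply it to a real run, pass out.prompt_token_ids and out.outputs[0].token_ids from a RequestOutput such as those printed above (both fields appear in the logs).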

Changed files

  • vllm/model_executor/models/gemma4_mm.py (modified, +7/-2)


Your current environment

vLLM and tpu-inference were installed from source following https://docs.vllm.ai/projects/tpu/en/latest/getting_started/installation/#install-from-source

vLLM commit: 08bfedc152f064d8e84f85c4f42b810e5a564229
tpu-inference commit: 6a488560b005aa71bbc2af6a32a2e9f814c71298 (from Apr 13)

🐛 Describe the bug

Bug: Gemma 4 repeats the same output token, producing incorrect output.

Example:

Gemma 4 output: [RequestOutput(request_id=0, prompt='Paris is', prompt_token_ids=[50429, 563], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' is is is is is is is is is is', token_ids=[563, 563, 563, 563, 563, 563, 563, 563, 563, 563], routed_experts=None, cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0)]

With VLLM_MODEL_ID changed to "Qwen/Qwen3-30B-A3B", the output is “the capital of France. It is a city of”.

To reproduce the bug, run the following script:

import gc
import logging
import os
from vllm import LLM, SamplingParams
import jax
from jax import config as jax_config

VLLM_MODEL_ID="google/gemma-4-26B-A4B"
#VLLM_MODEL_ID="Qwen/Qwen3-30B-A3B"

_JAX_COMPILATION_CACHE_DIR = "/tmp/jax_cache"
def _setup_jax_compilation_cache():
  jax_config.update("jax_compilation_cache_dir", _JAX_COMPILATION_CACHE_DIR)
  jax_config.update("jax_persistent_cache_min_entry_size_bytes", -1)
  jax_config.update("jax_persistent_cache_min_compile_time_secs", 0)
  jax_config.update("jax_enable_compilation_cache", True)

def _setup_vllm():
  # for vLLM we can skip JAX precompilation with this flag, it makes startup faster
  os.environ["SKIP_JAX_PRECOMPILE"] = "1"
  os.environ["JAX_RANDOM_WEIGHTS"] = "False"
  os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
  
def _clean_device_memory():
  """Forces Python garbage collection, then deletes every live JAX array."""
  logging.info("Cleaning JAX device memory...")
  # Run Python's garbage collector to free Python-level references
  gc.collect()
  # Explicitly delete every array JAX still tracks so that its device
  # buffers are released before the engine allocates memory.
  for x in jax.live_arrays():
    x.delete()
  logging.info("Device memory cleanup complete.")

print(f"JAX devices: {jax.devices()}")  
_setup_jax_compilation_cache()
_setup_vllm()
_clean_device_memory()

llm = LLM(
  VLLM_MODEL_ID,
  max_model_len=16,
  tensor_parallel_size=2,
  data_parallel_size=-1,
  gpu_memory_utilization=0.65,
  async_scheduling=False,
  enforce_eager=True,
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=10)
print("\n" + "="*80)
print("Generation test after weight transfer:")
print(llm.generate("Paris is", sampling_params=sampling_params))


extent analysis

TL;DR

The repeated output tokens from the Gemma 4 model can potentially be avoided by adjusting the model configuration or by switching to a different model, such as "Qwen/Qwen3-30B-A3B", which does not exhibit this behavior.

Guidance

  • The repeated output token issue seems model-specific, as switching to "Qwen/Qwen3-30B-A3B" resolves the problem, suggesting the issue might be with the "google/gemma-4-26B-A4B" model configuration or training data.
  • To verify if the issue is indeed model-specific, try running the same script with different models to see if the problem persists.
  • Consider checking the documentation or reaching out to the model developers for "google/gemma-4-26B-A4B" to report the issue and seek guidance on potential fixes or workarounds.
  • Review the sampling_params and LLM initialization parameters to ensure they are appropriately set for the model being used, as some models might have specific requirements for these parameters.

Example

No specific code changes are suggested without further investigation, but the provided script can be used as a basis for testing different models and parameters.

Notes

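The cause of the repeated-token issue is not clear from the report alone and may require further investigation into the model's training data, its configuration, or its interaction with the vllm and tpu-inference libraries. PR #39842 later traced it to a missing <bos> token: for PT checkpoints without a chat template, Gemma4ProcessingInfo.get_default_tok_params forced add_special_tokens=False, so the prompt was encoded as [50429, 563] instead of [2, 50429, 563].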

Recommendation

As a workaround, use a different model such as "Qwen/Qwen3-30B-A3B" until the issue with "google/gemma-4-26B-A4B" is understood and resolved.
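Switching models is not the only stopgap. Given the root cause identified in PR #39842 (no <bos> at position 0 for raw prompts), a narrower workaround on a build without that PR is to tokenize the prompt yourself and pass token IDs with <bos> prepended. The sketch below assumes vLLM's TokensPrompt input and LLM.get_tokenizer(), and further assumes that token-ID prompts bypass the text tokenization step where add_special_tokens is overridden; ensure_bos and run_workaround are hypothetical helpers, not vLLM APIs, and the engine arguments are a subset of the PR's test script.

```python
from typing import List, Optional


def ensure_bos(token_ids: List[int], bos_id: Optional[int]) -> List[int]:
    """Return token_ids with bos_id at position 0, never doubling an existing <bos>."""
    if bos_id is None or (token_ids and token_ids[0] == bos_id):
        return list(token_ids)
    return [bos_id] + list(token_ids)


def run_workaround(model_id: str = "google/gemma-4-26B-A4B") -> None:
    """Hypothetical driver: generate from pre-tokenized IDs with <bos> forced in.

    vLLM is imported lazily so that ensure_bos() works without vLLM installed.
    Adjust the engine arguments (parallelism, memory) for your hardware.
    """
    from vllm import LLM, SamplingParams
    from vllm.inputs import TokensPrompt

    llm = LLM(model_id, max_model_len=16, enforce_eager=True)
    tokenizer = llm.get_tokenizer()
    ids = tokenizer.encode("Paris is", add_special_tokens=False)
    prompt = TokensPrompt(prompt_token_ids=ensure_bos(ids, tokenizer.bos_token_id))
    params = SamplingParams(temperature=0.0, max_tokens=10)
    outputs = llm.generate(prompt, sampling_params=params)
    print(outputs[0].prompt_token_ids, repr(outputs[0].outputs[0].text))
```

Keeping ensure_bos idempotent means the same call is safe on builds that already include the fix, where the tokenizer may have injected <bos> itself.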
