vllm - ✅(Solved) Fix [Bug]: Gemma 4 (31B/26B-A4B) vision outputs only <pad> under fp16: vision_tower standardize overflows [1 pull request, 1 participant]

GitHub stats
vllm-project/vllm#40290, fetched 2026-04-20 11:59:34
Comments: 0, Participants: 1, Timeline events: 3, Reactions: 0
Timeline (top): cross-referenced ×1, referenced ×1, subscribed ×1


Fix Action


PR fix notes

PR #40185: gemma4: Enable mm prefix-lm masking for vision bidirectional attention

Description (problem / solution / changelog)

Purpose

Fix Gemma4 multimodal attention masking when the HF text config enables vision-only bidirectional attention (use_bidirectional_attention="vision").

This PR sets hf_config.is_mm_prefix_lm = True during Gemma4Config.verify_and_update_config when use_bidirectional_attention == "vision", which activates vLLM’s existing multimodal prefix-LM masking path (via mm_prefix_range) so that vision tokens can attend bidirectionally within the multimodal prefix while preserving causal behavior elsewhere.

Fixes: https://github.com/vllm-project/vllm/issues/40106
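The masking rule this PR enables (causal everywhere, bidirectional inside the multimodal prefix) can be pictured with a toy boolean mask. This is a hypothetical sketch of the semantics only; `prefix_lm_mask` and the half-open `(start, end)` range format are illustrative assumptions, not vLLM's `mm_prefix_range` implementation.

```python
from typing import List, Sequence, Tuple


def prefix_lm_mask(seq_len: int, mm_ranges: Sequence[Tuple[int, int]]) -> List[List[bool]]:
    """mask[q][k] is True when query position q may attend to key position k.

    Causal everywhere (k <= q), plus full bidirectional attention between
    positions that fall inside the same multimodal range [start, end).
    """
    mask = [[k <= q for k in range(seq_len)] for q in range(seq_len)]
    for start, end in mm_ranges:
        for q in range(start, end):
            for k in range(start, end):
                mask[q][k] = True  # image tokens may also look "ahead" within the image
    return mask


# Hypothetical layout: position 0 = text, 1-2 = image tokens, 3-4 = text.
mask = prefix_lm_mask(5, [(1, 3)])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# x....
# xxx..
# xxx..
# xxxx.
# xxxxx
```

Text tokens stay strictly causal; only the two image positions gain access to each other.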

Test Plan

Run the unit test: pytest -q --noconftest tests/test_gemma4_mm_prefix_lm.py

Test Result

Ran in Docker (clean env): docker run --rm -v ${PWD}:/repo -w /repo python:3.10-slim bash -lc "python -m pip install -q --upgrade pip && pip install -q pytest numpy && pip install -q torch --index-url https://download.pytorch.org/whl/cpu && PYTHONPATH=/repo pytest -q --noconftest tests/test_gemma4_mm_prefix_lm.py"

Output: 2 passed, 1 warning in 7.13s

Changed files

  • tests/test_gemma4_mm_prefix_lm.py (added, +45/-0)
  • vllm/model_executor/models/config.py (modified, +7/-0)


Your current environment

<details> <summary>The output of <code>python collect_env.py</code></summary>
Collecting environment information...
uv is set
==============================
        System Info
==============================
OS                           : Ubuntu 24.04.3 LTS (x86_64)
GCC version                  : (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version                : Could not collect
CMake version                : version 3.28.3
Libc version                 : glibc-2.39

==============================
       PyTorch Info
==============================
PyTorch version              : 2.10.0+cu128
Is debug build               : False
CUDA used to build PyTorch   : 12.8
ROCM used to build PyTorch   : N/A
XPU used to build PyTorch    : N/A

==============================
      Python Environment
==============================
Python version               : 3.12.3 (main, Mar  3 2026, 12:15:18) [GCC 13.3.0] (64-bit runtime)
Python platform              : Linux-6.17.0-19-generic-x86_64-with-glibc2.39
    
==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : 12.0.140
CUDA_MODULE_LOADING set to   : 
GPU models and configuration : GPU 0: NVIDIA GeForce RTX 5090
Nvidia driver version        : 580.126.09
cuDNN version                : Could not collect
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                            x86_64
CPU op-mode(s):                          32-bit, 64-bit
Address sizes:                           46 bits physical, 48 bits virtual
Byte Order:                              Little Endian
CPU(s):                                  24
On-line CPU(s) list:                     0-23
Vendor ID:                               GenuineIntel
Model name:                              Intel(R) Core(TM) Ultra 9 285K
CPU family:                              6
Model:                                   198
Thread(s) per core:                      1
Core(s) per socket:                      24
Socket(s):                               1
Stepping:                                2
CPU(s) scaling MHz:                      111%
CPU max MHz:                             4600.0000
CPU min MHz:                             800.0000
BogoMIPS:                                7372.80
Flags:                                   fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdt_a rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect user_shstk avx_vnni lam wbnoinvd dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi vnmi umip pku ospke waitpkg gfni vaes vpclmulqdq rdpid bus_lock_detect movdiri movdir64b fsrm md_clear serialize arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                          VT-x
L1d cache:                               768 KiB (20 instances)
L1i cache:                               1.3 MiB (20 instances)
L2 cache:                                40 MiB (12 instances)
L3 cache:                                36 MiB (1 instance)
NUMA node(s):                            1
NUMA node0 CPU(s):                       0-23
Vulnerability Gather data sampling:      Not affected
Vulnerability Ghostwrite:                Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit:             Not affected
Vulnerability L1tf:                      Not affected
Vulnerability Mds:                       Not affected
Vulnerability Meltdown:                  Not affected
Vulnerability Mmio stale data:           Not affected
Vulnerability Old microcode:             Not affected
Vulnerability Reg file data sampling:    Not affected
Vulnerability Retbleed:                  Not affected
Vulnerability Spec rstack overflow:      Not affected
Vulnerability Spec store bypass:         Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:                Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:                Mitigation; Enhanced / Automatic IBRS; IBPB conditional; PBRSB-eIBRS Not affected; BHI BHI_DIS_S
Vulnerability Srbds:                     Not affected
Vulnerability Tsa:                       Not affected
Vulnerability Tsx async abort:           Not affected
Vulnerability Vmscape:                   Mitigation; IBPB before exit to userspace

==============================
Versions of relevant libraries
==============================
[pip3] No relevant packages
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.19.0
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled; XPU: Disabled
GPU Topology:
  	GPU0	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	0-23	0		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

==============================
     Environment Variables
==============================
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_wenqiangli
</details>

🐛 Describe the bug

Related issues (not duplicates)

  • #40095: Gemma4MultimodalEmbedder norm/linear order swap causes audio ASR hallucination. Same file, but a different layer and a different symptom; both bugs can coexist.
  • #40106 / PR #40185: missing prefix-LM mask for use_bidirectional_attention="vision". Confirmed not the cause of this pad-token bug (details in Ruled-out hypotheses below).
  • #40247 / #40286: Gemma 4 26B-A4B AWQ fails to load on v0.19.1. Here the model loads fine (engine starts, vision_tower is instantiated) and fails only at inference, so the failure surface is different.

TL;DR

vLLM's Gemma 4 multimodal loader casts vision_tower to the engine dtype. With --dtype float16 (the default for AWQ), the SigLIP final standardization step (h - std_bias) * std_scale runs on activations that live on the same scale as std_bias, whose magnitude reaches ~5.4e4, leaving almost no headroom below the fp16 maximum of 6.55e4. vision_tower.last_hidden_state ends up containing -inf, which turns into NaN after the multimodal RMSNorm, and the language model samples <pad> tokens for every image request.

The std_bias and std_scale buffers are stored in bf16 in the checkpoint. Their values fit in fp16, but the activations they standardize can exceed the fp16 range, whereas bf16 holds them easily. Keeping vision_tower at bf16 (or fp32) while the rest of the engine runs fp16 fixes the bug.

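Affects gemma-4-31B-it and gemma-4-26B-A4B-it whenever the engine runs at dtype=float16 (e.g. AWQ or an explicit --dtype float16). Only 31B was reproduced directly; 26B-A4B is expected to fail as well because its vision config also sets standardize=True. gemma-4-E4B-it is not affected because its vision config has standardize=False.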

Reproduction

Minimal, direct POST to /v1/chat/completions:

import base64, json, urllib.request
with open('test.jpg', 'rb') as f:
    b64 = base64.b64encode(f.read()).decode()

payload = {
    "model": "/path/to/gemma4-31b-awq",
    "messages": [{"role": "user", "content": [
        {"type": "image_url",
         "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
        {"type": "text", "text": "What is in this image?"}
    ]}],
    "max_tokens": 50, "temperature": 0.1,
    "skip_special_tokens": False,  # needed to see pad tokens
}
r = urllib.request.urlopen(urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"}
))
print(json.loads(r.read())['choices'][0]['message']['content'])

Observed: '<pad><pad><pad>...' × 50 (finish_reason: "length"). Expected: a description of the image.

Text-only requests to the same server work fine (a control prompt returns "hello world"), so the failure is isolated to the image path.

Evidence: vision_tower output is -inf

I added a small diagnostic patch to vllm/model_executor/models/gemma4_mm.py that logs tensor stats at three points in the image pipeline: after vision_tower(...), inside Gemma4MultimodalEmbedder.forward, and after _merge_multimodal_embeddings. Full diff available on request. Logs from a single image request:

[diag vt] pixel_values (input)
    shape=(1, 2520, 768) dtype=torch.float16
    norm=768.000 mean=0.30476 std=0.46031 max=1.000 min=0.000 nan=False inf=False
[diag vt] vision_tower.last_hidden_state
    shape=(256, 1152) dtype=torch.float16
    norm=inf mean=-inf std=nan max=7.238 min=-inf nan=False inf=True          ← OVERFLOW
[diag embed] inputs_embeds
    shape=(1, 256, 1152) dtype=torch.float16
    norm=inf mean=-inf std=nan max=7.238 min=-inf nan=False inf=True
[diag embed] projection.weight
    shape=(5376, 1152) dtype=torch.float16
    norm=80.440 mean=-0.00000 std=0.03232 max=0.459 min=-0.496 nan=False inf=False
[diag embed] after_projection
    shape=(1, 256, 5376) dtype=torch.float16
    norm=inf mean=nan std=nan max=inf min=-inf nan=False inf=True             ← Inf × finite = Inf
[diag embed] after_norm (final)
    shape=(1, 256, 5376) dtype=torch.float16
    norm=nan mean=nan std=nan max=nan min=nan nan=True inf=False              ← RMSNorm(Inf) = NaN
[diag merge] merged[mm positions]
    shape=(266, 5376) dtype=torch.float16
    norm=nan mean=nan std=nan max=nan min=nan nan=True inf=False              ← LM sees NaN at image positions
[diag merge] merged[text positions]
    shape=(22, 5376) dtype=torch.float16
    norm=381.840 mean=0.00472 std=1.11030 max=60.438 min=-15.320 nan=False inf=False

Pixel values enter fine. The vision_tower's final output is -inf. That propagates through the multimodal projector and RMSNorm to NaN at every image position in the merged embeddings, so the language model's first layer receives NaN for 266 of 288 token positions.
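The propagation can be reproduced with plain Python floats. In this toy example, the 3-channel token and the 2x3 projection weight are made-up values, not real Gemma 4 shapes or weights. It shows why a single -inf channel is enough: a mixed-sign projection turns it into both -inf and +inf outputs, and an RMSNorm without a learnable scale then divides inf by inf, giving NaN for every element of the token.

```python
import math


def project(x, weight):
    """Linear projection without bias: out[r] = sum_c weight[r][c] * x[c]."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def rms_norm(x, eps=1e-6):
    """RMSNorm without a learnable scale: x / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]


# One image token: two finite channels and one channel that overflowed to -inf.
token = [1.5, -math.inf, 0.25]
# Toy projection weight with mixed signs, as in any trained projector.
weight = [[0.1, 0.2, -0.3],
          [0.4, -0.5, 0.6]]

projected = project(token, weight)
normed = rms_norm(projected)
print(projected)  # [-inf, inf]
print(normed)     # [nan, nan]
```

This mirrors the after_projection (inf, mean=nan) and after_norm (all NaN) lines in the log above.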

Root cause

In HF transformers modeling_gemma4.py (tested against transformers==4.X), Gemma4VisionModel.forward ends with:

# modeling_gemma4.py:1943-1944
if self.config.standardize:
    hidden_states = (hidden_states - self.std_bias) * self.std_scale

Checkpoint weights for gemma-4-31B-it (inspected directly from the safetensors shards):

key                           dtype     shape    max      min       mean
model.vision_tower.std_bias   bfloat16  (1152,)  36352.0  -53760.0  -36.52
model.vision_tower.std_scale  bfloat16  (1152,)  0.0210   0.0001    0.0013

fp16's maximum representable magnitude is 65,504. The ~[-7, +7] range in our diagnostic is the finite part of the standardized output, not the input to this block. Because std_bias holds a per-channel mean at the raw SigLIP scale, the hidden_states entering the standardize step (shape (num_tokens, 1152)) must sit on roughly the same scale, i.e. tens of thousands in the channels where |std_bias| is largest. That leaves almost no headroom in fp16: any activation, or any intermediate h - std_bias, beyond ±65,504 rounds to ±inf, which is consistent with the diagnostic (max=7.238, min=-inf). The subsequent multiply by std_scale cannot recover: -inf * small_positive = -inf.

The weights themselves fit in fp16 (53760 < 65504) — storage is fine. It is the arithmetic in fp16 that overflows.

bf16 is immune because it has the same 8-bit exponent range as fp32 (max ~3.4e38). The checkpoint stores these buffers in bf16, and they carry a trained per-channel mean and scale at the raw SigLIP activation scale.
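To make the arithmetic concrete, here is a pure-Python emulation that rounds every input and every intermediate to fp16 or bf16, roughly the way an elementwise GPU kernel rounds each op's result. The channel uses std_bias = -53760 from the table above. The std_scale of 1e-4 and the activation of -70000 are hypothetical values chosen for illustration, not read from the checkpoint or the diagnostic.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 binary16 value (round half to even).

    struct's 'e' format raises OverflowError for out-of-range values, so map
    that case to a signed infinity, which is what fp16 arithmetic produces.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def to_bf16(x: float) -> float:
    """Round x to bfloat16: round the float32 bit pattern to its top 16 bits."""
    if math.isnan(x):
        return x
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000  # round half to even
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def standardize(h: float, bias: float, scale: float, rnd) -> float:
    """(h - bias) * scale, rounding the inputs and each intermediate with `rnd`."""
    h, bias, scale = rnd(h), rnd(bias), rnd(scale)
    diff = rnd(h - bias)
    return rnd(diff * scale)


BIAS = -53760.0  # most negative std_bias entry for gemma-4-31B-it (table above)
SCALE = 1e-4     # hypothetical std_scale, inside the observed 0.0001-0.0210 range
H = -70000.0     # hypothetical raw-scale activation for that channel

fp16_out = standardize(H, BIAS, SCALE, to_fp16)
bf16_out = standardize(H, BIAS, SCALE, to_bf16)
print("fp16:", fp16_out)  # fp16: -inf
print("bf16:", bf16_out)  # bf16: -1.640625
```

Both buffers survive either rounding unchanged; only the activation leaves the fp16 range, and once it is -inf no later step can bring it back.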

Why only 31B / 26B-A4B are affected

Only some Gemma 4 variants set vision_config.standardize = True:

  • gemma-4-31B-it → standardize=True → affected
  • gemma-4-26B-A4B-it → standardize=True → affected (expected; not personally reproduced here)
  • gemma-4-E4B-it → standardize=False → the overflowing branch is skipped, no bug
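The condition in the list above reduces to a single check. The sketch below assumes the checkpoint's parsed config.json is a plain dict with the flag at config["vision_config"]["standardize"], mirroring getattr(config.vision_config, "standardize", False) in the proposed fix. vision_tower_needs_bf16 is a hypothetical helper, not a vLLM function, and the variant dicts are minimal stand-ins rather than real configs.

```python
def vision_tower_needs_bf16(config: dict, engine_dtype: str) -> bool:
    """True when the fp16 standardize overflow described in this report applies.

    Hypothetical helper. Assumes `config` is the parsed config.json of a Gemma 4
    checkpoint with the flag at config["vision_config"]["standardize"].
    """
    standardize = config.get("vision_config", {}).get("standardize", False)
    return bool(standardize) and engine_dtype in ("float16", "half")


# Minimal stand-ins for the three variants listed above.
variants = {
    "gemma-4-31B-it": {"vision_config": {"standardize": True}},
    "gemma-4-26B-A4B-it": {"vision_config": {"standardize": True}},
    "gemma-4-E4B-it": {"vision_config": {"standardize": False}},
}
for name, cfg in variants.items():
    print(name, vision_tower_needs_bf16(cfg, "float16"))
# gemma-4-31B-it True
# gemma-4-26B-A4B-it True
# gemma-4-E4B-it False
```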

Proposed fix

Keep vision_tower at bf16 whenever vision_config.standardize is set, regardless of the engine dtype, and cast its input to match. The current code at vllm/model_executor/models/gemma4_mm.py:~888:

# ---- Vision tower (shared by image and video) ----
with self._mark_tower_model(vllm_config, {"image", "video"}):
    self.vision_tower = AutoModel.from_config(config=config.vision_config)
    self.embed_vision = Gemma4MultimodalEmbedder(
        config.vision_config, config.text_config
    )

becomes:

with self._mark_tower_model(vllm_config, {"image", "video"}):
    self.vision_tower = AutoModel.from_config(config=config.vision_config)
    if getattr(config.vision_config, "standardize", False):
        # SigLIP standardize computes (h - std_bias) * std_scale on raw-scale
        # activations (|std_bias| up to ~5.4e4), which can exceed the fp16
        # range (max ±6.55e4) and become -inf. Keep vision_tower in bf16 so
        # the arithmetic is safe; _process_image_input already casts the
        # tower output to embedding_projection.weight.dtype before projecting.
        self.vision_tower = self.vision_tower.to(torch.bfloat16)
    self.embed_vision = Gemma4MultimodalEmbedder(
        config.vision_config, config.text_config
    )

And in _process_image_input (~line 1057), cast pixel_values to the vision_tower's dtype before the call:

vt_dtype = next(vt.parameters()).dtype
for i in range(pixel_values.shape[0]):
    pv = pixel_values[i].unsqueeze(0).to(vt_dtype)
    pp = pixel_position_ids[i].unsqueeze(0)
    ...

The existing target_dtype = self.embed_vision.embedding_projection.weight.dtype cast at line ~1085 already handles the bf16 → fp16 downcast for the projector, so no downstream changes are required.

Fix verification

After applying both edits and restarting vLLM (same server, same prompt, same image):

  • image 1 (2048×1536 paint cans): 32 tokens, 0.8s, finish="stop": "This image shows three cans of paint sitting on a wooden shelf. The cans are from the brands Zinsser, Farrow & Ball, and Dulux."
  • image 2 (same scene, different prompt asking for OCR): 200 tokens, 3.0s, detailed transcription of all three can labels

Post-fix diagnostic confirms the downstream tensors are clean:

merged[mm positions]   norm=1195.832 max=14.000 min=-14.594  nan=False inf=False
merged[text positions] norm=438.421  max=60.438 min=-15.320  nan=False inf=False

Ruled-out hypotheses

  1. Missing prefix-LM attention mask (the use_bidirectional_attention="vision" config flag). This was the suspicion in issue #40106 and PR #40185. Applying the PR's is_mm_prefix_lm = True plumbing did not change the pad-token output at all. The KL divergence reported in #40106 (0.03–0.09 vs HF) is far too small to explain pure <pad> output; that observation is consistent with the prefix-LM difference being a real but minor accuracy hit on top of a working model.

  2. Missing checkpoint weight embed_vision.embedding_post_projection_norm.weight. Checked: vLLM's Gemma4MultimodalEmbedder declares this RMSNorm with has_weight=False, so there is no learnable scale and no weight is expected in the checkpoint. Architecture comment at gemma4_mm.py:801-803 confirms.

  3. AWQ quantization of vision weights. The checkpoint's quantization_config.modules_to_not_convert lists vision_tower, so those weights are stored unquantized (plain .weight tensors, bf16). The .linear.weight suffix on vision layers is just the AWQ packaging convention; the weights load correctly and have sane norms (verified by opening safetensors directly).

  4. bf16→fp16 precision loss on the std_bias storage itself. Checked: 53760 is exactly representable in fp16 (below the 65504 max). Storage is fine; the overflow is at runtime during h - std_bias.
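Hypothesis 4 can be checked with nothing but the standard library. fp16 keeps 10 fraction bits, so for magnitudes in [32768, 65536) adjacent fp16 values are 2^(15-10) = 32 apart. Both extreme std_bias entries are multiples of 32 and survive an fp16 round trip unchanged. The helper names are illustrative.

```python
import math
import struct


def fp16_ulp(x: float) -> float:
    """Gap between adjacent fp16 values near |x| (normal fp16 range only)."""
    _, exp = math.frexp(abs(x))  # abs(x) = m * 2**exp with 0.5 <= m < 1
    return 2.0 ** (exp - 1 - 10)  # 2**(unbiased exponent - 10 fraction bits)


def fp16_roundtrip(x: float) -> float:
    """Pack x as IEEE 754 binary16 and unpack it again."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Extreme std_bias entries from the gemma-4-31B-it table above.
for value in (-53760.0, 36352.0):
    ulp = fp16_ulp(value)
    print(value, ulp, value % ulp == 0, fp16_roundtrip(value) == value)
# -53760.0 32.0 True True
# 36352.0 32.0 True True
```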

Blast radius

Anyone running Gemma 4 31B or 26B-A4B on vLLM with --dtype float16, including users of AWQ/GPTQ quantizations of the 31B, since those typically run in fp16. The bug is silent from vLLM's perspective: there is no warning and no NaN check, so users see pad-token output with no log signal.
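A finiteness check on the image embeddings would turn this silent failure into a loud one. The sketch below is hypothetical and framework-free (flat lists of floats stand in for tensors); count_nonfinite and check_finite are illustrative names, not existing vLLM functions, and where such a check would live in gemma4_mm.py is left open.

```python
import math
from typing import Iterable, Tuple


def count_nonfinite(values: Iterable[float]) -> Tuple[int, int]:
    """Return (num_nan, num_inf) for a flat iterable of floats."""
    num_nan = num_inf = 0
    for v in values:
        if math.isnan(v):
            num_nan += 1
        elif math.isinf(v):
            num_inf += 1
    return num_nan, num_inf


def check_finite(name: str, values: Iterable[float]) -> None:
    """Raise instead of letting NaN/inf image embeddings reach the LM silently."""
    num_nan, num_inf = count_nonfinite(values)
    if num_nan or num_inf:
        raise FloatingPointError(
            f"{name}: {num_nan} NaN and {num_inf} inf values; "
            "if the engine runs in float16, check for fp16 overflow"
        )


check_finite("image_embeds", [0.5, -1.25, 14.0])  # finite: passes silently
try:
    check_finite("image_embeds", [0.5, float("nan"), float("-inf")])
except FloatingPointError as err:
    print(err)
```

Scanning every element costs a pass over the embeddings, so a real implementation would more likely run such a check only in a debug mode.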

References

  • Issue #40106: original report about the missing prefix-LM mask (related, but NOT this bug)
  • PR #40185: prefix-LM fix; closure requested, not required for this bug
  • HF source: transformers/models/gemma4/modeling_gemma4.py:1943-1944 (Gemma4VisionModel.forward standardize block)
  • HF source: transformers/models/gemma4/modeling_gemma4.py:1899-1901 (std_bias / std_scale buffer registration)

Extent analysis

TL;DR

Keeping vision_tower in bf16 regardless of the engine dtype (for checkpoints whose vision_config sets standardize=True) prevents the fp16 overflow at the standardization step.

Guidance

  • Identify the vision_tower model and its dtype in the code.
  • Check if the standardize flag is set to True in the vision_config.
  • If standardize is True, keep the vision_tower at bf16 dtype to prevent overflow.
  • Cast the input to the vision_tower to match its dtype.
  • Verify the fix by checking that the vision_tower output and the downstream tensors (projection, RMSNorm, merged embeddings) contain no inf or NaN values.

Example

The proposed fix involves modifying the code to keep the vision_tower at bf16 dtype:

if getattr(config.vision_config, "standardize", False):
    self.vision_tower = self.vision_tower.to(torch.bfloat16)

And casting the input to the vision_tower to match its dtype:

vt_dtype = next(vt.parameters()).dtype
for i in range(pixel_values.shape[0]):
    pv = pixel_values[i].unsqueeze(0).to(vt_dtype)
    pp = pixel_position_ids[i].unsqueeze(0)
    ...

Notes

The fix only applies to Gemma 4 models with standardize=True in the vision_config, which includes the 31B and 26B-A4B variants. gemma-4-E4B-it has standardize=False and is not affected.

Recommendation

Apply the proposed fix to keep the vision_tower at bf16 dtype, as it prevents overflow during the standardization step and fixes the issue.
