transformers - Fix Gemma4-E2B/E4B: passing `inputs_embeds` triggers an extremely expensive reverse embedding lookup [2 comments, 2 participants]

huggingface/transformers#45874 (fetched 2026-05-11 03:12:57)



System Info

When inputs_embeds is passed to Gemma4TextModel or Gemma4Model, the model attempts to recover input_ids via a brute-force reverse lookup against the full embedding weight matrix (L1731, modeling_gemma4.py). This allocates a huge intermediate tensor and causes an out-of-memory error even for short sequences. Two complementary fixes are proposed: relaxing the mutual-exclusion check between input_ids and inputs_embeds, and exposing per_layer_inputs in Gemma4Model.forward.

See the reproduction code below.

Root cause

Gemma 4 E2B/E4B requires input_ids to look up Per-Layer Embeddings (PLE). When only inputs_embeds is provided, get_per_layer_inputs (L1710) attempts to recover input_ids by comparing every embedding vector against every row of the full embedding weight matrix (L1731):

if input_ids is None:
    with torch.no_grad():
        input_ids = (
            (
                inputs_embeds[:, :, None, :]
                == self.embed_tokens.weight[None, None, :, :] * self.config.hidden_size**0.5
            )
            .all(dim=3)
            .nonzero()[:, 2]
        )

This materialises a huge boolean tensor, almost inevitably causing an OOM.
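The cost follows directly from the broadcast shapes: inputs_embeds[:, :, None, :] is (batch, seq_len, 1, hidden_size) and the weight view is (1, 1, vocab_size, hidden_size), so the `==` yields a (batch, seq_len, vocab_size, hidden_size) boolean tensor at one byte per element. A minimal sketch that estimates this intermediate alone (it ignores the scaled copy of the weight matrix and every other allocation). The vocabulary and hidden sizes are hypothetical placeholders chosen for round numbers, not values read from the Gemma 4 config, so the figures need not match the 135 GiB / 425.25 GiB reported here.

```python
def reverse_lookup_bool_bytes(batch: int, seq_len: int, vocab_size: int, hidden_size: int) -> int:
    """Bytes for the boolean result of the broadcast `==` in get_per_layer_inputs.

    (batch, seq_len, 1, hidden) == (1, 1, vocab, hidden) -> (batch, seq_len, vocab, hidden),
    stored as torch.bool, i.e. one byte per element.
    """
    return batch * seq_len * vocab_size * hidden_size


def to_gib(n_bytes: int) -> float:
    """Convert bytes to GiB (2**30 bytes)."""
    return n_bytes / 2**30


# Hypothetical placeholder sizes, NOT taken from the Gemma 4 config.
VOCAB_SIZE = 262_144
HIDDEN_SIZE = 1_536

for seq_len in (1, 32, 256):
    n = reverse_lookup_bool_bytes(1, seq_len, VOCAB_SIZE, HIDDEN_SIZE)
    print(f"seq_len={seq_len:>3}: {to_gib(n):8.3f} GiB")
```

With these placeholder sizes every token adds 0.375 GiB, so a 32-token prompt already needs 12 GiB for the comparison result, well beyond the 6 GiB GPU used in this report. The growth is linear in sequence length, which is why avoiding the reverse lookup altogether is preferable to trying to fit it in memory.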

Current flow

When the model is called as in the reproduction code below, Gemma4ForConditionalGeneration.forward calls Gemma4Model.forward, which in turn delegates parts of its work to Gemma4TextModel. Note that Gemma4TextModel.forward has a different signature from Gemma4Model.forward:

Gemma4TextModel.forward

Input | Behaviour
--- | ---
only input_ids | embed (L1647) -> PLE lookup (L1651) -> decode (L1652/L1691)
only inputs_embeds | reverse lookup (L1731), very expensive! -> PLE lookup -> decode
both inputs_embeds and input_ids | ValueError at L1644
both inputs_embeds and per_layer_inputs | decode immediately (L1652/L1691)

Gemma4Model.forward

Input | Behaviour
--- | ---
only input_ids | embed (L2231) -> PLE lookup (L2237) -> Gemma4TextModel.forward
only inputs_embeds | reverse lookup delegated (via L2237), again very expensive! -> Gemma4TextModel.forward
both inputs_embeds and input_ids | ValueError at L2221
both inputs_embeds and per_layer_inputs | per_layer_inputs silently ignored, falls back to expensive reverse lookup (via L2237)

The last case in Gemma4Model is particularly unexpected. Gemma4TextModel.forward already accepts per_layer_inputs as an explicit parameter to bypass the lookup entirely, with a docstring explaining the design intent. However, Gemma4Model.forward does not expose this parameter, making it impossible to reach the cheap path through the public-facing model.
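To make the routing in the Gemma4Model.forward table explicit, here is a minimal sketch that encodes it as a plain function. It is a hypothetical model of the behaviour reported for transformers 5.7.0, not library code; the "neither input given" branch is an assumption, since the table does not cover that case.

```python
def gemma4_model_path(has_input_ids: bool, has_inputs_embeds: bool, has_per_layer_inputs: bool = False) -> str:
    """Hypothetical summary of how Gemma4Model.forward obtains per-layer inputs.

    Mirrors the Gemma4Model.forward table in this issue; not transformers code.
    """
    if has_input_ids == has_inputs_embeds:
        # Both given: ValueError at L2221. Neither given: assumed to raise as well.
        return "ValueError"
    if has_input_ids:
        return "embed -> cheap PLE lookup -> Gemma4TextModel.forward"
    # Gemma4Model.forward has no per_layer_inputs parameter, so this flag changes nothing.
    _ = has_per_layer_inputs
    return "expensive reverse lookup -> PLE lookup -> Gemma4TextModel.forward"


for args in [(True, False), (False, True), (True, True), (False, True, True)]:
    print(args, "->", gemma4_model_path(*args))
```

The sketch shows why the last row is surprising: supplying per-layer inputs collapses onto the same expensive path as passing inputs_embeds alone.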

Proposed fixes

I would love to create a pull request. There are two independent options:

Option A: allow inputs_embeds and input_ids together

Relax the mutual-exclusion ValueError at L1644 and/or L2221. When both are provided, inputs_embeds is used for the forward pass and input_ids is used only for the PLE lookup. This is consistent with how many other models in the Transformers library handle this pair, and requires minimal code changes. Note: in the method get_placeholder_mask, inputs_embeds would then be ignored.
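A minimal sketch of the relaxed check Option A describes, written as a standalone hypothetical helper (select_input_sources is an illustrative name, not a transformers function). It only decides which input feeds the decoder and which feeds the PLE lookup; the error raised when neither input is given is an assumption.

```python
from typing import Any, Optional, Tuple


def select_input_sources(input_ids: Optional[Any], inputs_embeds: Optional[Any]) -> Tuple[str, str]:
    """Hypothetical relaxed input check for Option A (not current transformers code).

    Returns (source used for the decoder, source used for the PLE lookup).
    """
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You must specify at least one of input_ids or inputs_embeds")
    # The decoder prefers caller-supplied embeddings, so custom manipulations survive.
    decoder_source = "inputs_embeds" if inputs_embeds is not None else "input_ids"
    # The PLE lookup is cheap only when token ids are available.
    ple_source = "input_ids" if input_ids is not None else "reverse lookup from inputs_embeds"
    return decoder_source, ple_source


print(select_input_sources([2, 3], [[0.1], [0.2]]))  # ('inputs_embeds', 'input_ids')
print(select_input_sources(None, [[0.1], [0.2]]))    # ('inputs_embeds', 'reverse lookup from inputs_embeds')
```

The design choice is that passing both inputs stops being an error and instead becomes the cheap path: the embeddings drive decoding while the ids are used only to index the per-layer embedding table.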

Option B: expose per_layer_inputs in Gemma4Model.forward

Add per_layer_inputs as a parameter to Gemma4Model.forward and wrap the (reverse) lookup at L2237 in an if-statement:

if self.config.get_text_config().hidden_size_per_layer_input:
    if per_layer_inputs is None:  # <-- added line
        pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :]
        multimodal_mask = multimodal_mask.to(inputs_embeds.device)
        llm_inputs_embeds = torch.where(multimodal_mask[..., None], pad_embedding.view(1, 1, -1), inputs_embeds)
        per_layer_inputs = self.language_model.get_per_layer_inputs(llm_input_ids, llm_inputs_embeds)

Apart from the new parameter, this is a one-line addition. As mentioned, Gemma4TextModel already correctly handles a user-provided per_layer_inputs, so the gap is only in Gemma4Model.
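The guard in Option B relies on short-circuiting: the expensive lookup only runs when the caller did not supply per-layer inputs. A minimal, framework-free sketch of that pattern; resolve_per_layer_inputs and expensive_lookup are hypothetical stand-ins for the guarded block and get_per_layer_inputs, not transformers APIs.

```python
from typing import Callable, List, Optional, TypeVar

T = TypeVar("T")


def resolve_per_layer_inputs(per_layer_inputs: Optional[T], lookup: Callable[[], T]) -> T:
    """Hypothetical helper mirroring the proposed `if per_layer_inputs is None:` guard."""
    if per_layer_inputs is None:
        return lookup()  # only reached when the caller passed nothing
    return per_layer_inputs


calls: List[str] = []


def expensive_lookup() -> str:
    # Stand-in for self.language_model.get_per_layer_inputs(...)
    calls.append("lookup")
    return "computed PLE"


print(resolve_per_layer_inputs("user PLE", expensive_lookup))  # user PLE
print(resolve_per_layer_inputs(None, expensive_lookup))        # computed PLE
print(calls)                                                   # ['lookup']
```

Passing the lookup as a callable keeps it lazy, which is the property the one-line guard gives the real forward method.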

Usage pattern under either fix:

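# fetch embeddings and do with them what you want
# (`tokens` is the dict returned by processor.apply_chat_template(..., return_dict=True, return_tensors="pt"))
inputs_embeds = model.get_input_embeddings()(tokens["input_ids"])

# Option A: pass both `input_ids` and `inputs_embeds`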
outputs = model(inputs_embeds=inputs_embeds, input_ids=tokens["input_ids"], ...)

# Option B: fetch PLE cheaply from `input_ids`, note that the second argument is then entirely redundant
per_layer_inputs = model.model.language_model.get_per_layer_inputs(tokens["input_ids"], None)
# subsequently pass both `per_layer_inputs` and `inputs_embeds`
outputs = model(inputs_embeds=inputs_embeds, per_layer_inputs=per_layer_inputs, ...)

Motivation

The use case is custom embedding manipulation before decoding, e.g., injecting state representations into the input embedding for finetuning with TRL. This requires passing modified inputs_embeds directly to the decoder, which is relatively standard practice across transformer models.

Notably, the official ONNX export of Gemma 4 on Kaggle (https://www.kaggle.com/models/google/gemma-4/onnx) already supports this workflow through a decoupled "embed_session", explicitly requiring the user to fetch (per-layer) embeddings before passing them to the decoder. The HF path does not offer equivalent flexibility.

The developers' intent for this functionality is also documented in the per_layer_inputs docstring of Gemma4TextModel.forward. The proposed changes extend this existing mechanism one level up to Gemma4Model, where it is currently absent.

System information

  • transformers version: 5.7.0
  • Platform: Windows-11-10.0.26200-SP0
  • Python version: 3.13.13
  • Huggingface_hub version: 1.13.0
  • Safetensors version: 0.7.0
  • Accelerate version: 1.13.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.11.0+cu130 (CUDA)
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA RTX 1000 Ada Generation Laptop GPU

Who can help?

@zucchini-nlp, @Cyrilvallez, is this worth working on? I would be happy to open a PR.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

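import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

model_id = "google/gemma-4-E2B-it"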
gpu_available = torch.cuda.is_available()
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16 if (gpu_available and torch.cuda.is_bf16_supported()) else torch.float16,
    bnb_4bit_use_double_quant=True
)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": i for i in range(torch.cuda.device_count())} if gpu_available else "auto",
    low_cpu_mem_usage=True,
    attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {"role": "system", "content": "You are a helpful chatbot."},
    {"role": "user", "content": [{"type": "text", "text": "What are the three laws of thermodynamics?"}]}
]
inputs = processor.apply_chat_template(
    messages, 
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
    add_generation_prompt=True
).to(model.device)

inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])

with torch.no_grad():
    outputs = model( # Used under the hood of `model.generate`
        inputs_embeds=inputs_embeds,
        attention_mask=inputs["attention_mask"],
        use_cache=True,
        logits_to_keep=1,
        past_key_values=None
    )
# OutOfMemoryError: CUDA out of memory. Tried to allocate 135.00 GiB.

Traceback

---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[6], line 46
     42 
     43 inputs_embeds = model.get_input_embeddings()(inputs["input_ids"])
     44 
     45 with torch.no_grad():
---> 46     outputs = model(
     47         inputs_embeds=inputs_embeds,
     48         attention_mask=inputs["attention_mask"],
     49         use_cache=True,

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\torch\nn\modules\module.py:1779, in Module._wrapped_call_impl(self, *args, **kwargs)
   1777     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1778 else:
-> 1779     return self._call_impl(*args, **kwargs)

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\torch\nn\modules\module.py:1790, in Module._call_impl(self, *args, **kwargs)
   1785 # If we don't have any hooks, we want to skip the rest of the logic in
   1786 # this function, and just call forward.
   1787 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1788         or _global_backward_pre_hooks or _global_backward_hooks
   1789         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790     return forward_call(*args, **kwargs)
   1792 result = None
   1793 called_always_called_hooks = set()

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\transformers\utils\generic.py:900, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    898 if return_dict_passed is not None:
    899     return_dict = return_dict_passed
--> 900 output = func(self, *args, **kwargs)
    901 if not return_dict and not isinstance(output, tuple):
    902     output = output.to_tuple()

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\transformers\models\gemma4\modeling_gemma4.py:2515, in Gemma4ForConditionalGeneration.forward(self, input_ids, pixel_values, pixel_values_videos, input_features, attention_mask, input_features_mask, position_ids, image_position_ids, video_position_ids, past_key_values, mm_token_type_ids, inputs_embeds, labels, use_cache, logits_to_keep, **kwargs)
   2484 @can_return_tuple
   2485 @auto_docstring
   2486 def forward(
   (...)   2503     **kwargs: Unpack[TransformersKwargs],
   2504 ) -> Gemma4CausalLMOutputWithPast:
   2505     r"""
   2506     input_features_mask (`torch.FloatTensor]` of shape `(num_images, seq_length)`):
   2507         The attention mask for the input audio.
   (...)   2513         Passed through to the vision encoder for positional embedding computation.
   2514     """
-> 2515     outputs = self.model(
   2516         input_ids=input_ids,
   2517         pixel_values=pixel_values,
   2518         pixel_values_videos=pixel_values_videos,
   2519         input_features=input_features,
   2520         attention_mask=attention_mask,
   2521         input_features_mask=input_features_mask,
   2522         position_ids=position_ids,
   2523         past_key_values=past_key_values,
   2524         mm_token_type_ids=mm_token_type_ids,
   2525         inputs_embeds=inputs_embeds,
   2526         labels=labels,
   2527         use_cache=use_cache,
   2528         image_position_ids=image_position_ids,
   2529         video_position_ids=video_position_ids,
   2530         return_dict=True,
   2531         **kwargs,
   2532     )
   2534     hidden_states = outputs.last_hidden_state
   2535     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\torch\nn\modules\module.py:1779, in Module._wrapped_call_impl(self, *args, **kwargs)
   1777     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1778 else:
-> 1779     return self._call_impl(*args, **kwargs)

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\torch\nn\modules\module.py:1790, in Module._call_impl(self, *args, **kwargs)
   1785 # If we don't have any hooks, we want to skip the rest of the logic in
   1786 # this function, and just call forward.
   1787 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1788         or _global_backward_pre_hooks or _global_backward_hooks
   1789         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1790     return forward_call(*args, **kwargs)
   1792 result = None
   1793 called_always_called_hooks = set()

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\transformers\utils\generic.py:976, in merge_with_config_defaults.<locals>.wrapper(self, *args, **kwargs)
    974             output = func(self, *args, **kwargs)
    975     else:
--> 976         output = func(self, *args, **kwargs)
    977 # Restore original config value
    978 finally:
    979     if is_causal is not None:

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\transformers\utils\generic.py:900, in can_return_tuple.<locals>.wrapper(self, *args, **kwargs)
    898 if return_dict_passed is not None:
    899     return_dict = return_dict_passed
--> 900 output = func(self, *args, **kwargs)
    901 if not return_dict and not isinstance(output, tuple):
    902     output = output.to_tuple()

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\transformers\models\gemma4\modeling_gemma4.py:2282, in Gemma4Model.forward(self, input_ids, pixel_values, pixel_values_videos, input_features, attention_mask, input_features_mask, position_ids, past_key_values, mm_token_type_ids, inputs_embeds, use_cache, image_position_ids, video_position_ids, **kwargs)
   2280     pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :]
   2281     llm_inputs_embeds = torch.where(multimodal_mask[..., None], pad_embedding.view(1, 1, -1), inputs_embeds)
-> 2282     per_layer_inputs = self.language_model.get_per_layer_inputs(llm_input_ids, llm_inputs_embeds)
   2283 else:
   2284     per_layer_inputs = None

File c:\Users\thijs\.conda\envs\onnx\Lib\site-packages\transformers\models\gemma4\modeling_gemma4.py:1715, in Gemma4TextModel.get_per_layer_inputs(self, input_ids, inputs_embeds)
   1711 if input_ids is None:
   1712     with torch.no_grad():
   1713         input_ids = (
   1714             (
-> 1715                 inputs_embeds[:, :, None, :]
   1716                 == self.embed_tokens.weight[None, None, :, :] * self.config.hidden_size**0.5
   1717             )
   1718             .all(dim=3)
   1719             .nonzero()[:, 2]
   1720         )
   1721         try:
   1722             input_ids = input_ids.view(inputs_embeds.shape[:2])

OutOfMemoryError: CUDA out of memory. Tried to allocate 425.25 GiB. GPU 0 has a total capacity of 6.00 GiB of which 0 bytes is free. Of the allocated memory 7.04 GiB is allocated by PyTorch, and 2.26 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://docs.pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf)

Expected behavior

As mentioned, I would expect to be able to pass our own embeddings/PLEs to the model, consistent with the ONNX implementation, which allows more flexibility. This has also been motivated previously. However, as described, it is not properly implemented for Gemma 4 E2B/E4B.
