transformers - ✅(Solved) Fix Gemma4: PLE (Per-Layer Embeddings) implementation is underdocumented and config is misleading [1 pull requests, 1 participants]

GitHub stats
huggingface/transformers#45206 · Fetched 2026-04-08 02:33:10
Comments: 0
Participants: 1
Timeline: 2
Reactions: 4
Timeline (top): cross-referenced ×1, subscribed ×1

I was implementing Gemma4 inference from scratch (in Rust) and the Per-Layer Embeddings (PLE) system was by far the hardest part to get right. The config fields are misleading, the embedding type is non-obvious, and the full pipeline involves several undocumented steps. Sharing this in case it helps others and in case you want to improve the docs.

Root Cause

The config field hidden_size_per_layer_input: 256 reads like the embedding dimension, so the __init__ in Gemma4TextModel seems like it should create nn.Embedding(vocab, 256). Loading the pretrained weight of shape [vocab, 8960] would then fail. (It doesn't fail because from_pretrained handles the resize, but that isn't obvious from reading the code.)

Fix Action

Fixed

PR fix notes

PR #45207: [Gemma4] Add docstrings for Per-Layer Embeddings (PLE) pipeline

Description (problem / solution / changelog)

Fixes #45206

What does this PR do?

Adds documentation for the Gemma4 Per-Layer Embeddings (PLE) system, which is currently pretty hard to reverse-engineer from the code alone.

I ran into this while implementing Gemma4 inference from scratch in Rust. The PLE system has several non-obvious aspects that aren't documented anywhere:

  1. hidden_size_per_layer_input (256) is the per-layer dimension, but the actual embedding weight is [vocab, num_layers * 256] = [262144, 8960] because all layers are packed
  2. The embedding is a Gemma4TextScaledWordEmbedding that silently multiplies by sqrt(256) = 16 - this took me a while to track down
  3. The full pipeline has a context-aware projection step (per_layer_model_projection + scale + RMSNorm) that combines with the token lookup before being passed to layers, with specific scale factors (1/sqrt(hidden_size) and 1/sqrt(2))

This PR adds:

  • Expanded config docstring for hidden_size_per_layer_input explaining the packed layout, scaling, and full pipeline
  • Docstrings for get_per_layer_inputs() and project_per_layer_inputs()
  • A comment on the PLE init block pointing to the pipeline methods

Hopefully this saves some pain for others implementing Gemma4 outside of transformers.

Changed files

  • src/transformers/models/gemma4/configuration_gemma4.py (modified, +15/-4)
  • src/transformers/models/gemma4/modeling_gemma4.py (modified, +24/-0)
RAW_BUFFER

Description

I was implementing Gemma4 inference from scratch (in Rust) and the Per-Layer Embeddings (PLE) system was by far the hardest part to get right. The config fields are misleading, the embedding type is non-obvious, and the full pipeline involves several undocumented steps. Sharing this in case it helps others and in case you want to improve the docs.

Problem 1: hidden_size_per_layer_input is ambiguous

The config says hidden_size_per_layer_input: 256, which sounds like it's the embedding dimension. But embed_tokens_per_layer.weight has shape [262144, 8960] where 8960 = 35 layers * 256. The actual embedding dimension is num_hidden_layers * hidden_size_per_layer_input, not hidden_size_per_layer_input alone.

This confused me because the __init__ in Gemma4TextModel seems like it should create nn.Embedding(vocab, 256) but then loading the pretrained weight of shape [vocab, 8960] would fail. (It doesn't fail because from_pretrained handles the resize, but it's not obvious from reading the code.)
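The shape arithmetic above can be sketched in a few lines of plain Python. This is a minimal illustration, not code from transformers: the helper names packed_ple_shape and layer_slice are hypothetical, and the contiguous per-layer layout is an assumption inferred from the reshape to [B, S, num_layers, ple_dim] described later in this issue.

```python
# Values quoted in this issue for google/gemma-4-E2B-it.
NUM_HIDDEN_LAYERS = 35
HIDDEN_SIZE_PER_LAYER_INPUT = 256
VOCAB_SIZE = 262144


def packed_ple_shape(vocab_size, num_layers, ple_dim):
    """Shape of the packed embed_tokens_per_layer weight (hypothetical helper)."""
    return (vocab_size, num_layers * ple_dim)


def layer_slice(row, layer_idx, ple_dim):
    """Return the ple_dim-wide chunk of one packed row that belongs to layer_idx.

    Assumption: layers are stored contiguously within a row, which is what a
    reshape of the last axis to (num_layers, ple_dim) implies.
    """
    start = layer_idx * ple_dim
    return row[start:start + ple_dim]


shape = packed_ple_shape(VOCAB_SIZE, NUM_HIDDEN_LAYERS, HIDDEN_SIZE_PER_LAYER_INPUT)
print(shape)  # (262144, 8960)
```

When porting, comparing this expected shape against the shape stored in the checkpoint is a cheap way to catch the 256 vs. 8960 confusion before running any inference.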

Problem 2: embed_tokens_per_layer is secretly a Gemma4TextScaledWordEmbedding

The PLE embedding isn't a plain nn.Embedding. It's a Gemma4TextScaledWordEmbedding that multiplies the lookup result by sqrt(hidden_size_per_layer_input) = sqrt(256) = 16.0.

This isn't mentioned anywhere in the config, the docstrings, or the model card. I only found it by inspecting type(lm.embed_tokens_per_layer).__name__ after my outputs were 16x too small.
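For anyone porting this, the behaviour described above amounts to "look up, then multiply by a constant". The sketch below is a plain-Python stand-in, not the transformers class: ScaledEmbedding is a hypothetical name, and the tiny 2-column weight table is made up purely to show the 16x factor.

```python
import math


class ScaledEmbedding:
    """Hypothetical stand-in for a scaled word embedding: lookup, then scale."""

    def __init__(self, weight, embed_scale):
        self.weight = weight            # list of rows, one row per token id
        self.embed_scale = embed_scale  # constant multiplier applied after lookup

    def __call__(self, token_ids):
        return [[v * self.embed_scale for v in self.weight[t]] for t in token_ids]


ple_dim = 256
embed_scale = math.sqrt(ple_dim)  # 16.0, per the issue

# Toy weight table (2 tokens x 2 columns); real rows are 8960 wide.
weight = [[0.5, -1.0], [0.25, 2.0]]
emb = ScaledEmbedding(weight, embed_scale)

out = emb([1])
print(out)  # [[4.0, 32.0]]
```

If a port uses a plain lookup here, every PLE value comes out 16x too small, which matches the symptom reported above.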

Problem 3: The full PLE pipeline has undocumented steps

The actual PLE computation involves:

  1. Token-identity: embed_tokens_per_layer(input_ids) (scaled by sqrt(256)) -> reshape to [B, S, num_layers, ple_dim]
  2. Context-aware projection: per_layer_model_projection(inputs_embeds) (a Linear) -> scale by 1/sqrt(hidden_size) -> reshape to [B, S, num_layers, ple_dim] -> RMSNorm (per_layer_projection_norm)
  3. Combine: (context_projection + token_identity) * (1/sqrt(2))
  4. Each layer i gets per_layer_inputs[:, :, i, :]

This involves weights that aren't mentioned in the config at all:

  • per_layer_model_projection (Linear, hidden_size -> num_layers * ple_dim)
  • per_layer_projection_norm (RMSNorm, dim=ple_dim)
  • Two hardcoded scale factors: 1/sqrt(hidden_size) and 1/sqrt(2)

The get_per_layer_inputs() and project_per_layer_inputs() methods implement this, but there are no docstrings explaining the overall pipeline or the scale factors.
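The four steps listed above can be sketched for a single token position in plain Python. This is a rough sketch under stated assumptions, not the transformers implementation: the function names are hypothetical, the RMSNorm here is the textbook form (normalize by root-mean-square, then multiply by a weight) with an assumed eps of 1e-6, and the projection weight is assumed to use the [out_features, in_features] layout of a Linear layer without bias.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Textbook RMSNorm over one vector (assumed form and eps)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv * w for v, w in zip(x, weight)]


def matvec(matrix, vec):
    """Multiply an [out, in] matrix by a length-in vector."""
    return [sum(w * v for w, v in zip(row, vec)) for row in matrix]


def chunk(flat, num_layers, ple_dim):
    """Split a packed vector into num_layers chunks of ple_dim values."""
    return [flat[i * ple_dim:(i + 1) * ple_dim] for i in range(num_layers)]


def per_layer_inputs_for_token(token_id, hidden, ple_table, proj_weight,
                               norm_weight, num_layers, ple_dim):
    hidden_size = len(hidden)
    # 1. Token identity: packed lookup scaled by sqrt(ple_dim), split per layer.
    scale = math.sqrt(ple_dim)
    token_identity = chunk([v * scale for v in ple_table[token_id]],
                           num_layers, ple_dim)
    # 2. Context projection: Linear -> scale by 1/sqrt(hidden_size)
    #    -> split per layer -> RMSNorm on each ple_dim chunk.
    projected = [v / math.sqrt(hidden_size) for v in matvec(proj_weight, hidden)]
    context = [rms_norm(c, norm_weight)
               for c in chunk(projected, num_layers, ple_dim)]
    # 3. Combine with the 1/sqrt(2) factor.
    inv_sqrt2 = 1.0 / math.sqrt(2.0)
    # 4. Element i of the result is the input handed to layer i.
    return [[(c + t) * inv_sqrt2 for c, t in zip(ctx, tok)]
            for ctx, tok in zip(context, token_identity)]


# Toy sizes: 2 layers, ple_dim 2, hidden_size 4, vocab of 1 token.
num_layers, ple_dim = 2, 2
ple_table = [[1.0, 2.0, 3.0, 4.0]]
hidden = [0.5, -0.5, 1.0, 0.0]
proj_weight = [[0.1 * (r + c) for c in range(4)] for r in range(4)]
norm_weight = [1.0, 1.0]

per_layer = per_layer_inputs_for_token(0, hidden, ple_table, proj_weight,
                                       norm_weight, num_layers, ple_dim)
print(len(per_layer), len(per_layer[0]))  # 2 2
```

Working one token at a time keeps the sketch short; a real port would batch this over [B, S] and should check the RMSNorm variant and eps against the actual model code.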

Suggestion

Adding a docstring to Gemma4TextModel (or the config class) explaining:

  1. That hidden_size_per_layer_input is the per-layer dimension, and the total embedding dim is num_hidden_layers * hidden_size_per_layer_input
  2. That the PLE embedding is scaled by sqrt(hidden_size_per_layer_input)
  3. A brief description of the full PLE pipeline (token lookup + context projection + norm + combine with scale factors)

This would save a lot of pain for anyone implementing Gemma4 outside of HuggingFace transformers (e.g. llama.cpp, candle, mlx, etc.).

Environment

  • transformers 5.5.0
  • Model: google/gemma-4-E2B-it

extent analysis

TL;DR

Gemma4's Per-Layer Embeddings (PLE) system is hard to reimplement because the config field is ambiguous, the embedding type is non-obvious, and several pipeline steps are undocumented. The fix is documentation: clarify the config field and embedding type, and describe the full pipeline.

Guidance

  • Verify the actual embedding dimension is num_hidden_layers * hidden_size_per_layer_input, not hidden_size_per_layer_input alone, to avoid confusion.
  • Recognize that the PLE embedding is a Gemma4TextScaledWordEmbedding that multiplies the lookup result by sqrt(hidden_size_per_layer_input), which is not explicitly mentioned in the config or docstrings.
  • Document the full PLE pipeline, including token-identity, context-aware projection, and combination steps, to provide a clear understanding of the process.
  • Consider adding docstrings to Gemma4TextModel or the config class to explain the PLE system, including the per-layer dimension, scaling factors, and pipeline steps.

Notes

The provided information is specific to the Gemma4 model and the transformers library version 5.5.0, so the guidance may not apply to other models or versions.

Recommendation

Document the PLE system carefully, covering the config fields, embedding type, and full pipeline, as PR #45207 does. This will help others implementing Gemma4 outside of HuggingFace transformers.
