transformers: [Gemma4] `Gemma4VisionPatchEmbedder._position_embeddings` materializes a ~19 GiB one-hot tensor that's mathematically a 2-row embedding lookup



System Info

  • transformers version: 5.6.2
  • Platform: Linux 5.15.0-41-generic (x86_64)
  • Python: 3.12
  • PyTorch: 2.10.0+cu128
  • GPU: 8 × NVIDIA H800 (80GB)
  • Model: google/gemma-4-E4B-it (any Gemma 4 multimodal variant)
  • Trainer: ZeRO-3 (DeepSpeed) full-parameter SFT, torch_dtype: bfloat16

Who can help?

@yonigozlan @molbap (vision models) @zucchini-nlp (multimodal models)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The current implementation of Gemma4VisionPatchEmbedder._position_embeddings in src/transformers/models/gemma4/modeling_gemma4.py (line 561 in 5.6.2) computes 2D patch position embeddings via one-hot encoding followed by a batched matmul:

def _position_embeddings(self, pixel_position_ids, padding_positions):
    """Prepare patch positions map for matmul with positon embedding table."""
    clamped_positions = pixel_position_ids.clamp(min=0)
    one_hot = F.one_hot(clamped_positions, num_classes=self.position_embedding_size)
    one_hot = one_hot.permute(0, 2, 1, 3).to(self.position_embedding_table)
    position_embeddings = one_hot @ self.position_embedding_table
    position_embeddings = position_embeddings.sum(dim=1)
    position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
    return position_embeddings

For Gemma-4-E4B-it (position_embedding_size=10240) with batch_size=4, 10 images per sample, and 2520 patches per image (the default max_soft_tokens=280 times pooling_kernel_size^2=9), pixel_position_ids covers 40 images and this materializes:

  • one_hot: [40, 2520, 2, 10240] int64 → 15.38 GiB
  • one_hot.to(bf16) cast copy: same shape, bf16 → 3.85 GiB

That is ~19 GiB in total for what amounts to two embedding lookups. Peak breakdown captured via torch.cuda.memory._record_memory_history:

[1] alloc peak 36.00 GiB 
      19.57 GiB  (54.4%)  x10    vision tower (vision_model)
      14.09 GiB  (39.1%)        ZeRO-3 
       ...

Drilling into the 10 vision_model allocations:

one_hot @ line 581:   15.38 GiB (single int64 tensor, [40, 2520, 2, 10240])
one_hot.to bf16 @ 582: 3.85 GiB (bf16 cast of the same shape)

Stack tops:

_position_embeddings @ modeling_gemma4.py:581
forward @ modeling_gemma4.py:596
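The two tensor sizes reported above follow directly from the one-hot shape. As a quick sanity check, the arithmetic can be reproduced in plain Python; the variable names are illustrative, and pooling_kernel_size=3 is inferred from the pooling_kernel_size^2=9 figure quoted earlier:

```python
# One-hot temporary shape: [images, patches, 2 axes, position_embedding_size]
batch_size = 4
images_per_sample = 10
max_soft_tokens = 280
pooling_kernel_size = 3  # inferred from pooling_kernel_size^2 = 9
position_embedding_size = 10240

patches_per_image = max_soft_tokens * pooling_kernel_size**2  # 2520
num_images = batch_size * images_per_sample                   # 40
elements = num_images * patches_per_image * 2 * position_embedding_size

GIB = 1024**3
int64_gib = elements * 8 / GIB  # F.one_hot returns int64 (8 bytes/element)
bf16_gib = elements * 2 / GIB   # cast to the bf16 table dtype (2 bytes/element)

print(f"int64 one-hot: {int64_gib:.2f} GiB, bf16 cast: {bf16_gib:.2f} GiB")
```

Both copies scale linearly with position_embedding_size, which is the dimension the embedding lookup never has to materialize.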

The one_hot @ table pattern is mathematically a pair of embedding lookups, one per spatial axis: row clamped[..., 0] of position_embedding_table[0] plus row clamped[..., 1] of position_embedding_table[1]. It can be replaced with F.embedding without materializing any of the [..., position_embedding_size] intermediates.

Expected behavior

_position_embeddings should produce the same result without the intermediate one-hot tensor. The following replacement is numerically equivalent (bit-exact outside autocast; matches the original under autocast to within bf16-matmul roundoff ~1.5e-2 absolute) and eliminates both the 15.38 GiB int64 tensor and its 3.85 GiB bf16 cast:

def _position_embeddings(self, pixel_position_ids, padding_positions):
    """Prepare patch positions map via direct embedding lookup.

    Mathematically equivalent to the original one_hot @ table + sum,
    but avoids the [batch, num_patches, 2, position_embedding_size]
    one-hot temporary (~15 GiB int64 + ~3.85 GiB bf16 cast at bs=4,
    n_img=10, position_embedding_size=10240).
    """
    # Original .clamp(min=0) only bounds the lower side; one_hot() did
    # an implicit upper-bound check. F.embedding would read OOB on
    # overflow, so clamp both sides.
    clamped = pixel_position_ids.clamp(0, self.position_embedding_size - 1)
    x_emb = F.embedding(clamped[..., 0], self.position_embedding_table[0])
    y_emb = F.embedding(clamped[..., 1], self.position_embedding_table[1])
    position_embeddings = x_emb + y_emb
    # The original's trailing .sum(dim=1) promotes bf16 -> fp32 under
    # autocast (reduction-dtype rule); plain `+` doesn't, so cast
    # explicitly only inside autocast to keep the downstream
    # `hidden_states + pos` dtype contract.
    if torch.is_autocast_enabled() and position_embeddings.dtype != torch.float32:
        position_embeddings = position_embeddings.to(torch.float32)
    position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
    return position_embeddings
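
The equivalence claim can be checked on CPU with toy sizes. The sketch below is not the library code: one_hot_path and embedding_path are hypothetical standalone functions that mirror the two method bodies above, the table shape [2, position_embedding_size, hidden] is assumed from how the original indexes it, and the padding mask and autocast cast are left out because they are identical in both versions.

```python
import torch
import torch.nn.functional as F


def one_hot_path(position_ids, table):
    """Original pattern: one-hot, batched matmul, then sum over the two axes."""
    one_hot = F.one_hot(position_ids.clamp(min=0), num_classes=table.shape[1])
    one_hot = one_hot.permute(0, 2, 1, 3).to(table)  # [B, 2, N, P]
    return (one_hot @ table).sum(dim=1)              # [B, N, H]


def embedding_path(position_ids, table):
    """Proposed pattern: one embedding lookup per axis, then add."""
    clamped = position_ids.clamp(0, table.shape[1] - 1)
    x_emb = F.embedding(clamped[..., 0], table[0])
    y_emb = F.embedding(clamped[..., 1], table[1])
    return x_emb + y_emb                             # [B, N, H]


torch.manual_seed(0)
B, N, P, H = 2, 5, 16, 8  # toy sizes, not the real model dimensions
table = torch.randn(2, P, H)
position_ids = torch.randint(0, P, (B, N, 2))
position_ids[0, -1] = -1  # stand-in for a padding patch; clamped to 0 in both paths

ref = one_hot_path(position_ids, table)
new = embedding_path(position_ids, table)
print("max abs diff:", (ref - new).abs().max().item())
```

This only exercises the fp32 CPU case; the bf16-under-autocast tolerance quoted above would need a CUDA run to confirm.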
