transformers - ✅(Solved) Fix Remove unnecessary `expand_as` in `get_placeholder_mask` across VLMs [1 pull request, 1 participant]

GitHub stats
huggingface/transformers#44906 (fetched 2026-04-08 01:07:54)
Comments: 0 · Participants: 1 · Timeline: 2 · Reactions: 1
Timeline (top): cross-referenced ×1, labeled ×1

Root Cause

The get_placeholder_mask function (and equivalent inline patterns) across ~70 multimodal model files expands a boolean placeholder mask from shape (B, S, 1) to (B, S, H) via .expand_as(inputs_embeds) before passing it to masked_scatter. This expansion is unnecessary because masked_scatter natively supports broadcasting.
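To make the broadcasting behavior concrete, here is a small sketch with toy shapes (the sizes and feature values are illustrative assumptions). A (B, S, 1) mask selects every hidden element of each placeholder row, and masked_scatter consumes the source in row-major order, so feature row k lands at the k-th placeholder position.

```python
import torch

# Toy shapes: batch B, sequence S, hidden size H (illustrative values only).
B, S, H = 2, 6, 4
inputs_embeds = torch.zeros(B, S, H)

# Three placeholder positions: (0, 1), (0, 4), (1, 2).
mask_2d = torch.zeros(B, S, dtype=torch.bool)
mask_2d[0, 1] = mask_2d[0, 4] = mask_2d[1, 2] = True

# One feature row per placeholder token; rows are filled with 1.0, 2.0, 3.0.
features = torch.arange(1, 4, dtype=torch.float32).unsqueeze(-1).repeat(1, H)

# The (B, S, 1) mask broadcasts over the hidden dimension, so each True
# position is replaced by a full row of H source elements.
out = inputs_embeds.masked_scatter(mask_2d.unsqueeze(-1), features)
```

Rows that are not placeholders keep their original (here zero) embeddings.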

Fix Action

Fixed

PR fix notes

PR #44907: Remove unnecessary expand_as in get_placeholder_mask across VLMs

Description (problem / solution / changelog)

Fixes #44906

Summary

  • Remove .expand_as(inputs_embeds) from placeholder mask creation in get_placeholder_mask and equivalent inline patterns across all VLM models. masked_scatter natively broadcasts (B, S, 1) to (B, S, H), making the expansion unnecessary.
  • Replace inputs_embeds[special_image_mask].numel() == image_features.numel() validation with equivalent arithmetic n_tokens * inputs_embeds.shape[-1] == image_features.numel(), which avoids data-dependent boolean indexing and is more torch.compile-friendly.
  • 71 files changed across llava, qwen2_vl, paligemma, gemma3n, chameleon, video_llava, idefics2/3, instructblip, blip_2, and many more.
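The two validation forms in the second bullet can be compared side by side. Both helper names below are hypothetical and exist only for this sketch; the new form counts placeholder tokens on the (B, S) mask and multiplies by the hidden size instead of boolean-indexing the embeddings.

```python
import torch

def placeholder_count_matches(inputs_embeds, mask_2d, features):
    # Hypothetical helper (new form): count placeholder tokens on the (B, S)
    # mask and compare n_tokens * H with the number of feature elements.
    n_tokens = mask_2d.sum()
    return bool(n_tokens * inputs_embeds.shape[-1] == features.numel())

def placeholder_count_matches_old(inputs_embeds, mask_2d, features):
    # Hypothetical helper (old form): boolean-index the embeddings with the
    # expanded (B, S, H) mask and count the selected elements.
    expanded = mask_2d.unsqueeze(-1).expand_as(inputs_embeds)
    return inputs_embeds[expanded].numel() == features.numel()

B, S, H = 2, 10, 16
inputs_embeds = torch.randn(B, S, H)
mask_2d = torch.zeros(B, S, dtype=torch.bool)
mask_2d[0, 2] = mask_2d[0, 5] = mask_2d[1, 3] = True

good_features = torch.randn(3, H)  # one row per placeholder token
bad_features = torch.randn(2, H)   # one row short

for feats in (good_features, bad_features):
    # Both forms should agree on matching and mismatched features alike.
    assert placeholder_count_matches(inputs_embeds, mask_2d, feats) == \
        placeholder_count_matches_old(inputs_embeds, mask_2d, feats)
```

The new form never builds a data-dependent tensor: its only intermediate is the 0-d token count.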

How this was developed

The core fix was first implemented and verified on llava/modeling_llava.py, then expanded to all other models following the same pattern using Claude Code. Each file was reviewed to ensure the transformation was appropriate — files with genuinely different expand_as usage (e.g., pe_audio where the mask is later .reshape()-ed) were left unchanged.
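Why a reshape blocks the change can be shown with a generic example. This is not the pe_audio code, only an assumed pattern in which the mask is reshaped alongside a reshaped view of the embeddings. A reshape must preserve the element count, so only the expanded (B, S, H) mask can be reshaped this way.

```python
import torch

B, S, H = 2, 5, 8
inputs_embeds = torch.randn(B, S, H)
mask_2d = torch.zeros(B, S, dtype=torch.bool)
mask_2d[0, 1] = mask_2d[1, 3] = True

# The expanded mask has B * S * H elements, so it can follow a (B * S, H) reshape.
expanded = mask_2d.unsqueeze(-1).expand_as(inputs_embeds)
flat_expanded = expanded.reshape(B * S, H)

# The unexpanded (B, S, 1) mask has only B * S elements, so the same reshape fails.
try:
    mask_2d.unsqueeze(-1).reshape(B * S, H)
    reshape_failed = False
except RuntimeError:
    reshape_failed = True
```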

Test plan

  • Correctness verified: masked_scatter with broadcast (B,S,1) mask produces identical results to expanded (B,S,H) mask
  • pytest tests/models/llava/test_modeling_llava.py -x -v -k "not slow" — 136 passed
  • pytest tests/models/qwen2_vl/test_modeling_qwen2_vl.py -x -v -k "not slow" — 137 passed
  • pytest tests/models/paligemma/test_modeling_paligemma.py -x -v -k "not slow" — 124 passed
  • ruff check — all checks passed
  • check_modular_conversion.py --fix_and_overwrite — all generated files consistent
  • No duplicate PRs found (gh pr list --search "expand_as placeholder mask" / "get_placeholder_mask")
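The test plan above covers forward outputs. Because these models are trained through masked_scatter, one might also check that gradients are unchanged. The sketch below does that; it is an assumption and not a check the PR reports running. It compares gradients for both the embeddings and the scattered features under the expanded and the broadcast mask.

```python
import torch

torch.manual_seed(0)
B, S, H = 2, 10, 16
base = torch.randn(B, S, H)
feats = torch.randn(3, H)
weight = torch.randn(B, S, H)  # arbitrary upstream gradient via a weighted sum
mask_2d = torch.zeros(B, S, dtype=torch.bool)
mask_2d[0, 2] = mask_2d[0, 5] = mask_2d[1, 3] = True

def grads(mask):
    # Fresh leaf tensors for each run so the gradients do not accumulate.
    emb = base.clone().requires_grad_(True)
    src = feats.clone().requires_grad_(True)
    out = emb.masked_scatter(mask, src)
    (out * weight).sum().backward()
    return emb.grad, src.grad

g_emb_old, g_src_old = grads(mask_2d.unsqueeze(-1).expand_as(base))  # (B, S, H) mask
g_emb_new, g_src_new = grads(mask_2d.unsqueeze(-1))                  # (B, S, 1) mask
```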

This PR uses AI assistance (Claude Code). I have reviewed all changes and validated the behavior end-to-end.

Changed files

  • examples/modular-transformers/modeling_new_task_model.py (modified, +4/-6)
  • src/transformers/integrations/tensor_parallel.py (modified, +2/-2)
  • src/transformers/models/aria/modeling_aria.py (modified, +2/-2)
  • src/transformers/models/aya_vision/modeling_aya_vision.py (modified, +2/-2)
  • src/transformers/models/blip_2/modeling_blip_2.py (modified, +3/-3)
  • src/transformers/models/chameleon/modeling_chameleon.py (modified, +2/-2)
  • src/transformers/models/cohere2_vision/modeling_cohere2_vision.py (modified, +2/-2)
  • src/transformers/models/colqwen2/modeling_colqwen2.py (modified, +1/-3)
  • src/transformers/models/colqwen2/modular_colqwen2.py (modified, +1/-3)
  • src/transformers/models/deepseek_vl/modeling_deepseek_vl.py (modified, +2/-2)
  • src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py (modified, +3/-3)
  • src/transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py (modified, +1/-1)
  • src/transformers/models/emu3/modeling_emu3.py (modified, +2/-2)
  • src/transformers/models/emu3/modular_emu3.py (modified, +2/-2)
  • src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py (modified, +4/-4)
  • src/transformers/models/fast_vlm/modeling_fast_vlm.py (modified, +2/-2)
  • src/transformers/models/florence2/modeling_florence2.py (modified, +2/-2)
  • src/transformers/models/fuyu/modeling_fuyu.py (modified, +2/-2)
  • src/transformers/models/gemma3/modeling_gemma3.py (modified, +2/-2)
  • src/transformers/models/gemma3n/modeling_gemma3n.py (modified, +6/-6)
  • src/transformers/models/gemma3n/modular_gemma3n.py (modified, +6/-6)
  • src/transformers/models/glm46v/modeling_glm46v.py (modified, +4/-4)
  • src/transformers/models/glm4v/modeling_glm4v.py (modified, +4/-4)
  • src/transformers/models/glm4v/modular_glm4v.py (modified, +4/-4)
  • src/transformers/models/glm4v_moe/modeling_glm4v_moe.py (modified, +4/-4)
  • src/transformers/models/glm_ocr/modeling_glm_ocr.py (modified, +4/-4)
  • src/transformers/models/got_ocr2/modeling_got_ocr2.py (modified, +2/-2)
  • src/transformers/models/higgs_audio_v2/modeling_higgs_audio_v2.py (modified, +1/-1)
  • src/transformers/models/higgs_audio_v2/modular_higgs_audio_v2.py (modified, +1/-1)
  • src/transformers/models/idefics2/modeling_idefics2.py (modified, +1/-1)
  • src/transformers/models/idefics3/modeling_idefics3.py (modified, +1/-1)
  • src/transformers/models/instructblip/modeling_instructblip.py (modified, +2/-2)
  • src/transformers/models/instructblipvideo/modeling_instructblipvideo.py (modified, +3/-3)
  • src/transformers/models/instructblipvideo/modular_instructblipvideo.py (modified, +2/-2)
  • src/transformers/models/internvl/modeling_internvl.py (modified, +2/-2)
  • src/transformers/models/janus/modeling_janus.py (modified, +2/-2)
  • src/transformers/models/janus/modular_janus.py (modified, +2/-2)
  • src/transformers/models/lfm2_vl/modeling_lfm2_vl.py (modified, +2/-2)
  • src/transformers/models/lfm2_vl/modular_lfm2_vl.py (modified, +2/-2)
  • src/transformers/models/lighton_ocr/modeling_lighton_ocr.py (modified, +2/-2)
  • src/transformers/models/llama4/modeling_llama4.py (modified, +2/-2)
  • src/transformers/models/llava/modeling_llava.py (modified, +2/-2)
  • src/transformers/models/llava_next/modeling_llava_next.py (modified, +2/-2)
  • src/transformers/models/llava_next_video/modeling_llava_next_video.py (modified, +4/-4)
  • src/transformers/models/llava_next_video/modular_llava_next_video.py (modified, +4/-4)
  • src/transformers/models/llava_onevision/modeling_llava_onevision.py (modified, +4/-4)
  • src/transformers/models/mistral3/modeling_mistral3.py (modified, +2/-2)
  • src/transformers/models/ovis2/modeling_ovis2.py (modified, +3/-7)
  • src/transformers/models/ovis2/modular_ovis2.py (modified, +1/-5)
  • src/transformers/models/paddleocr_vl/modeling_paddleocr_vl.py (modified, +2/-2)
  • src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py (modified, +2/-2)
  • src/transformers/models/paligemma/modeling_paligemma.py (modified, +4/-6)
  • src/transformers/models/perception_lm/modeling_perception_lm.py (modified, +4/-4)
  • src/transformers/models/perception_lm/modular_perception_lm.py (modified, +4/-4)
  • src/transformers/models/pi0/modeling_pi0.py (modified, +1/-4)
  • src/transformers/models/pi0/modular_pi0.py (modified, +1/-4)
  • src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py (modified, +8/-8)
  • src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py (modified, +8/-8)
  • src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py (modified, +4/-4)
  • src/transformers/models/qwen2_audio/modeling_qwen2_audio.py (modified, +1/-1)
  • src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (modified, +4/-4)
  • src/transformers/models/qwen3_5/modeling_qwen3_5.py (modified, +4/-4)
  • src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py (modified, +4/-4)
  • src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py (modified, +5/-5)
  • src/transformers/models/qwen3_vl/modeling_qwen3_vl.py (modified, +4/-4)
  • src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py (modified, +4/-4)
  • src/transformers/models/t5gemma2/modeling_t5gemma2.py (modified, +2/-2)
  • src/transformers/models/t5gemma2/modular_t5gemma2.py (modified, +2/-2)
  • src/transformers/models/video_llama_3/modeling_video_llama_3.py (modified, +4/-4)
  • src/transformers/models/video_llava/modeling_video_llava.py (modified, +4/-4)
  • src/transformers/models/vipllava/modeling_vipllava.py (modified, +2/-2)

Original issue

Feature request

Problem

The get_placeholder_mask function (and equivalent inline patterns) across ~70 multimodal model files expands a boolean placeholder mask from shape (B, S, 1) to (B, S, H) via .expand_as(inputs_embeds) before passing it to masked_scatter. This expansion is unnecessary because masked_scatter natively supports broadcasting.

For example, in modeling_llava.py:

special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)

The validation check also uses data-dependent boolean indexing:

inputs_embeds[special_image_mask].numel() == image_features.numel()
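"Data-dependent" here means that the result's shape depends on the mask's values, not only on its shape. In the sketch below (the toy sizes are assumptions), the mask shape stays fixed while the number of selected elements changes with the number of placeholder tokens.

```python
import torch

B, S, H = 1, 8, 4
inputs_embeds = torch.randn(B, S, H)

def selected_shape(num_placeholders):
    # Same mask shape every call; only the number of True entries changes.
    mask_2d = torch.zeros(B, S, dtype=torch.bool)
    mask_2d[0, :num_placeholders] = True
    expanded = mask_2d.unsqueeze(-1).expand_as(inputs_embeds)
    # Boolean indexing returns a 1-D tensor with one entry per True element.
    return tuple(inputs_embeds[expanded].shape)

shape_two = selected_shape(2)
shape_five = selected_shape(5)
```

A compiler tracing this code therefore cannot know the output size ahead of time, which is the reason the PR gives for preferring the arithmetic check.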

Motivation


Mainly, memory. While expand_as itself returns a stride-0 view (no copy), the subsequent .to(device) call materializes the full (B, S, H) boolean tensor whenever a device transfer is needed. In practice, I expect an actual device transfer to be rare, but I believe the unexpanded mask is still the safer implementation.
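The memory point can be sketched on CPU. Here, .contiguous() stands in for the dense copy that a cross-device .to() would allocate. That substitution is an assumption, since the issue only describes the .to(device) case. The sketch compares the storage behind the stride-0 view, its dense copy, and the proposed (B, S, 1) mask (Tensor.untyped_storage requires PyTorch 2.x).

```python
import torch

# Illustrative sizes; the effect scales with B * S * H.
B, S, H = 2, 128, 256
inputs_embeds = torch.randn(B, S, H)
mask_2d = torch.zeros(B, S, dtype=torch.bool)

# expand_as returns a view: the hidden dimension gets stride 0, so no new memory.
expanded_view = mask_2d.unsqueeze(-1).expand_as(inputs_embeds)

# A dense copy (standing in for a cross-device .to()) allocates one bool
# per element of (B, S, H).
materialized = expanded_view.contiguous()

# The proposed (B, S, 1) mask stays at B * S bools even when copied.
unexpanded_copy = mask_2d.unsqueeze(-1).contiguous()

view_bytes = expanded_view.untyped_storage().nbytes()
dense_bytes = materialized.untyped_storage().nbytes()
unexpanded_bytes = unexpanded_copy.untyped_storage().nbytes()
```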

Your contribution

Proposed fix

  1. Remove .expand_as(inputs_embeds), keeping the mask as (B, S, 1):
special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
  2. Replace the validation with equivalent arithmetic:
n_image_tokens * inputs_embeds.shape[-1] == image_features.numel()
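Both proposed changes fit together in a simplified stand-in for a get_placeholder_mask-style function. The function name, its signature, the image token id, and the error message are all hypothetical. This is not the transformers implementation, only an illustration that assumes n_image_tokens is counted on the (B, S) mask before the unsqueeze.

```python
import torch

def get_placeholder_mask_sketch(input_ids, inputs_embeds, image_features, image_token_id):
    # Hypothetical, simplified stand-in; real models differ in signature and details.
    special_image_mask = input_ids == image_token_id  # (B, S)
    n_image_tokens = special_image_mask.sum()
    # Change 1: keep the mask as (B, S, 1); masked_scatter broadcasts it over H.
    special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)
    # Change 2: arithmetic size check instead of inputs_embeds[mask].numel().
    if n_image_tokens * inputs_embeds.shape[-1] != image_features.numel():
        raise ValueError(
            f"Image features and image tokens do not match: tokens {int(n_image_tokens)}, "
            f"features {image_features.shape[0]}"
        )
    return special_image_mask

IMAGE_TOKEN_ID = 99  # hypothetical placeholder token id
input_ids = torch.tensor([[1, 99, 99, 5], [99, 2, 3, 4]])
inputs_embeds = torch.randn(2, 4, 8)
image_features = torch.randn(3, 8)

mask = get_placeholder_mask_sketch(input_ids, inputs_embeds, image_features, IMAGE_TOKEN_ID)
merged = inputs_embeds.masked_scatter(mask, image_features)
```

The returned (B, S, 1) mask goes straight into masked_scatter, with no expanded copy anywhere on the path.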

Correctness

masked_scatter, torch.where, and element-wise * all support broadcasting a (B, S, 1) mask to (B, S, H). Verified with:

import torch
B, S, H = 2, 10, 16
inputs_embeds = torch.randn(B, S, H)
features = torch.randn(3, H)
mask_2d = torch.zeros(B, S, dtype=torch.bool)
mask_2d[0, 2], mask_2d[0, 5], mask_2d[1, 3] = True, True, True

mask_old = mask_2d.unsqueeze(-1).expand_as(inputs_embeds)
mask_new = mask_2d.unsqueeze(-1)

result_old = inputs_embeds.clone().masked_scatter(mask_old, features)
result_new = inputs_embeds.clone().masked_scatter(mask_new, features)
assert torch.equal(result_old, result_new)  # ✓ identical
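The check above covers masked_scatter only. The same comparison can be extended to the other two operations named in the Correctness paragraph, torch.where and element-wise multiplication. The replacement tensor used for torch.where is an arbitrary choice made for this sketch.

```python
import torch

B, S, H = 2, 10, 16
inputs_embeds = torch.randn(B, S, H)
replacement = torch.randn(B, S, H)  # arbitrary values to select at masked positions
mask_2d = torch.zeros(B, S, dtype=torch.bool)
mask_2d[0, 2] = mask_2d[0, 5] = mask_2d[1, 3] = True

mask_old = mask_2d.unsqueeze(-1).expand_as(inputs_embeds)  # (B, S, H)
mask_new = mask_2d.unsqueeze(-1)                           # (B, S, 1)

# torch.where broadcasts the condition against both value tensors.
where_old = torch.where(mask_old, replacement, inputs_embeds)
where_new = torch.where(mask_new, replacement, inputs_embeds)

# Multiplying by a bool mask promotes it to the float dtype and broadcasts it.
mul_old = inputs_embeds * mask_old
mul_new = inputs_embeds * mask_new
```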

Affects get_placeholder_mask and equivalent inline patterns in ~70 files across all VLM models (llava, qwen2_vl, paligemma, gemma3n, chameleon, video_llava, etc.), plus tensor_parallel.py and ovis2.

I have a PR ready: https://github.com/syncdoth/transformers/tree/remove-expand-as-placeholder-mask

Extent analysis

Fix Plan

To address the issue, follow these steps:

  • Remove the unnecessary .expand_as(inputs_embeds) call when creating the special_image_mask.
  • Replace the validation check with an equivalent arithmetic operation.

Code Changes

Update the special_image_mask creation:

special_image_mask = special_image_mask.unsqueeze(-1).to(inputs_embeds.device)

Replace the validation check:

n_image_tokens * inputs_embeds.shape[-1] == image_features.numel()

Apply these changes to all occurrences of get_placeholder_mask and equivalent inline patterns in the affected files (~70 files).
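Locating the occurrences can be scripted. The helper name and the regular expression below are assumptions and not part of the PR. The helper only catches the single-line .unsqueeze(-1).expand_as(inputs_embeds) spelling, so patterns split across lines would be missed. Each hit still needs the manual review the PR describes, because some files (e.g. pe_audio) legitimately keep the expansion.

```python
import re
from pathlib import Path

# Hypothetical pattern for the single-line form of the expansion.
PATTERN = re.compile(r"\.unsqueeze\(-1\)\.expand_as\(inputs_embeds\)")

def find_expand_as_masks(root):
    """Return (path, line_number, line) for every line under root matching PATTERN."""
    hits = []
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if PATTERN.search(line):
                hits.append((path, lineno, line.strip()))
    return hits
```

For example, iterate over find_expand_as_masks("src/transformers/models") and print each hit as path:lineno: line to get a review checklist.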

Verification

Verify the fix by re-running the masked_scatter comparison from the issue's Correctness section and ensure that its assertion passes, which confirms that the broadcast (B, S, 1) mask produces results identical to the expanded (B, S, H) mask.
