transformers - ✅(Solved) Fix [Gemma 4] Support per-layer FlashAttention: FA2 for sliding layers, SDPA for global layers [2 pull requests, 1 comment, 2 participants]

GitHub stats
huggingface/transformers#45201, fetched 2026-04-08 02:33:15
Comments: 1 · Participants: 2 · Timeline events: 7 · Reactions: 2
Timeline (top): subscribed ×3, mentioned ×2, commented ×1, cross-referenced ×1

Error Message

RuntimeError: FlashAttention only supports head dimensions up to 256 
(got head_dim=512 in layer 3 [global attention layer])

The error occurs at the first global attention layer during the forward pass.

Root Cause

Setting attn_implementation="flash_attention_2" causes a RuntimeError on the 4 global attention layers because FA2 does not support head_dim > 256.

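Fix / Workaround

Current workaround: fall back to attn_implementation="sdpa" for all 30 layers. This works but gives up the FA2 speedup on the sliding layers.

Proposed solution: per-layer attention dispatch. Transformers already supports dict-based attn_implementation for some models. Gemma 4 should use this to dispatch attention per layer: flash_attention_2 for sliding attention layers and sdpa (or eager) for global attention layers.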
PR fix notes

PR #1682: feat: add mock VLM dataset and Gemma4 pretokenize support

Description (problem / solution / changelog)

Summary

  • Add build_mock_vlm_dataset for VLM benchmarking/testing without real data downloads — generates random PIL images + dummy text in conversation format
  • Add truncate mode to PreTokenizedDatasetWrapper (labels built before truncation) for fixed-length sequence training
  • Add Gemma4 tensor support (image_position_ids, mm_token_type_ids) to PreTokenizedDatasetWrapper and pad_collate_fn
  • Auto-enable pretokenize/truncate when max_length is set on dataset config
  • Gemma4's tokenizer defaults to padding_side='left'. Changed to padding_side='right' for training.
  • Changed the attention backend to eager as a workaround (WAR), based on the following:
  1. Issue on the Gemma4 attention backend: https://github.com/huggingface/transformers/issues/45201

  2. Pull request on the Gemma4 attention backend: https://github.com/huggingface/transformers/pull/45202
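
The padding-side change in the summary above can be illustrated with a small, library-free sketch. The helper name pad_batch is hypothetical and is not the PR's pad_collate_fn; it only shows what left and right padding do to a batch of token ids and the matching attention mask.

```python
def pad_batch(sequences, pad_id, padding_side="right"):
    """Pad token-id lists to a common length and build attention masks.

    Hypothetical helper for illustration only; not taken from the PR.
    """
    max_len = max(len(seq) for seq in sequences)
    padded, masks = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        pad_tokens = [pad_id] * n_pad
        pad_mask = [0] * n_pad
        real_mask = [1] * len(seq)
        if padding_side == "right":
            # Real tokens first, padding at the end (the training setup in the PR).
            padded.append(list(seq) + pad_tokens)
            masks.append(real_mask + pad_mask)
        else:
            # Padding first, real tokens at the end (the tokenizer default noted above).
            padded.append(pad_tokens + list(seq))
            masks.append(pad_mask + real_mask)
    return padded, masks


right_ids, right_mask = pad_batch([[1, 2, 3], [4]], pad_id=0, padding_side="right")
left_ids, left_mask = pad_batch([[1, 2, 3], [4]], pad_id=0, padding_side="left")
```

With right padding, the real tokens of every sample start at position 0, which is the layout the PR chose for training.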

Test plan

  • Unit tests: pytest tests/unit_tests/datasets/vlm/test_mock.py (11 tests pass)
  • Gemma4 4B mock training: 26 steps, loss 17.08→3.97, no NaN
  • Gemma4 26B MoE mock training: 26 steps on 8 GPUs, loss 11.79→3.96, no NaN
  • Verified per-sample loss identical between default_collate_fn and pad_collate_fn paths

🤖 Generated with Claude Code

Changed files

  • examples/vlm_finetune/gemma4/gemma4_26b_a4b_moe.yaml (modified, +4/-1)
  • examples/vlm_finetune/gemma4/gemma4_26b_a4b_moe_mock.yaml (added, +98/-0)
  • examples/vlm_finetune/gemma4/gemma4_2b.yaml (modified, +4/-1)
  • examples/vlm_finetune/gemma4/gemma4_31b.yaml (modified, +4/-1)
  • examples/vlm_finetune/gemma4/gemma4_4b.yaml (modified, +4/-1)
  • examples/vlm_finetune/gemma4/gemma4_4b_mock.yaml (added, +88/-0)
  • nemo_automodel/components/datasets/vlm/__init__.py (modified, +2/-0)
  • nemo_automodel/components/datasets/vlm/collate_fns.py (modified, +15/-1)
  • nemo_automodel/components/datasets/vlm/datasets.py (modified, +34/-12)
  • nemo_automodel/components/datasets/vlm/mock.py (added, +132/-0)
  • nemo_automodel/recipes/vlm/finetune.py (modified, +8/-3)
  • tests/unit_tests/datasets/vlm/test_mock_vlm.py (added, +307/-0)

Problem

Gemma 4 (26B-A4B) has a hybrid attention architecture where different layers use different head dimensions:

  • 26 out of 30 layers use sliding window attention with head_dim=256 — fully compatible with FlashAttention 2
  • 4 out of 30 layers use global attention with global_head_dim=512 — exceeds FA2's maximum supported head dimension of 256

Setting attn_implementation="flash_attention_2" causes a RuntimeError on the 4 global attention layers because FA2 does not support head_dim > 256.

Error

RuntimeError: FlashAttention only supports head dimensions up to 256 
(got head_dim=512 in layer 3 [global attention layer])

The error occurs at the first global attention layer during the forward pass.

Config details

From config.json:

{
  "head_dim": 256,
  "global_head_dim": 512,
  "num_attention_heads": 8,
  "num_key_value_heads": 4,
  "sliding_window_pattern": 6,
  "num_hidden_layers": 30
}

With sliding_window_pattern=6, global attention layers are at indices [3, 9, 15, 21] (every 6th layer, 0-indexed, offset by 3). The remaining 26 layers use sliding window attention with head_dim=256.
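
The index rule can be checked with a few lines of Python. This is a minimal sketch that assumes the rule from the issue's pseudocode (a layer is global when its index modulo sliding_window_pattern equals sliding_window_pattern // 2). The helper name global_layer_indices is hypothetical and not part of Transformers. Applied to all 30 layers, this rule also selects index 27, one more than the four indices listed in the issue, so the 4/26 split is worth confirming against the actual checkpoint.

```python
def global_layer_indices(num_hidden_layers: int, sliding_window_pattern: int) -> list[int]:
    """Return the indices of global-attention layers under the assumed rule."""
    offset = sliding_window_pattern // 2
    return [
        i for i in range(num_hidden_layers)
        if i % sliding_window_pattern == offset
    ]


# Values copied from the config.json excerpt in this issue.
config = {"num_hidden_layers": 30, "sliding_window_pattern": 6}

indices = global_layer_indices(
    config["num_hidden_layers"], config["sliding_window_pattern"]
)
sliding_count = config["num_hidden_layers"] - len(indices)

print(indices)        # [3, 9, 15, 21, 27]
print(sliding_count)  # 25
```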

Current workaround

Fall back to attn_implementation="sdpa" for all 30 layers. This works but sacrifices the FA2 speedup on the 26 sliding layers (~87% of layers) that are fully FA2-compatible.
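
A minimal sketch of this all-layers fallback: choose one backend for the whole model from the largest head dimension in the config. The helper pick_global_backend and the constant FA2_MAX_HEAD_DIM are illustrative names, not Transformers APIs; the limit of 256 is the one reported in the error message above.

```python
FA2_MAX_HEAD_DIM = 256  # limit reported by the RuntimeError in this issue


def pick_global_backend(config: dict, preferred: str = "flash_attention_2") -> str:
    """Return a single attention backend that is safe for every layer."""
    largest = max(config.get("head_dim", 0), config.get("global_head_dim", 0))
    if preferred == "flash_attention_2" and largest > FA2_MAX_HEAD_DIM:
        # At least one layer exceeds the FA2 limit, so use SDPA everywhere.
        return "sdpa"
    return preferred


gemma4_config = {"head_dim": 256, "global_head_dim": 512}
backend = pick_global_backend(gemma4_config)
print(backend)  # sdpa

# The result would then be passed when loading the model, for example:
# AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=backend)
# where model_id is a placeholder for the checkpoint name.
```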

Proposed solution: per-layer attention dispatch

Transformers already supports dict-based attn_implementation for some models. Gemma 4 should use this to dispatch attention per-layer:

  • Sliding attention layers (26/30): use flash_attention_2
  • Global attention layers (4/30): use sdpa (or eager)

This could be implemented in Gemma4DecoderLayer or at model init time by reading the layer index and sliding_window_pattern from the config:

# Pseudocode for per-layer dispatch in Gemma4Model.__init__
requested = config._attn_implementation
pattern = config.sliding_window_pattern
for i, layer in enumerate(self.layers):
    # Global layers sit at offset pattern // 2 within each block of `pattern` layers.
    is_global = (i % pattern) == (pattern // 2)
    if is_global and requested == "flash_attention_2":
        layer._attn_implementation = "sdpa"  # FA2 incompatible (head_dim=512)
    else:
        layer._attn_implementation = requested  # keep the user's choice (FA2 works at head_dim=256)

Alternatively, the model could automatically detect the FA2 head_dim limit and fall back per-layer, similar to how some models handle mixed attention patterns.
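
The automatic-detection alternative could look roughly like the sketch below, which decides per layer from the layer's head dimension rather than from its position. All names here (LayerSpec, build_layer_specs, backend_for) are hypothetical and do not exist in Transformers. The layer layout assumes the same index rule as the pseudocode above, which yields five global layers for 30 layers rather than the four listed in the issue.

```python
from dataclasses import dataclass

FA2_MAX_HEAD_DIM = 256  # limit reported by the RuntimeError in this issue


@dataclass
class LayerSpec:
    """Hypothetical description of one decoder layer."""
    index: int
    layer_type: str  # "sliding" or "global"
    head_dim: int


def build_layer_specs(cfg: dict) -> list[LayerSpec]:
    """Derive per-layer head dimensions from the config (assumed index rule)."""
    pattern = cfg["sliding_window_pattern"]
    specs = []
    for i in range(cfg["num_hidden_layers"]):
        is_global = i % pattern == pattern // 2
        specs.append(
            LayerSpec(
                index=i,
                layer_type="global" if is_global else "sliding",
                head_dim=cfg["global_head_dim"] if is_global else cfg["head_dim"],
            )
        )
    return specs


def backend_for(spec: LayerSpec, requested: str, fallback: str = "sdpa") -> str:
    """Fall back only when FA2 is requested and the layer exceeds its limit."""
    if requested == "flash_attention_2" and spec.head_dim > FA2_MAX_HEAD_DIM:
        return fallback
    return requested


cfg = {
    "head_dim": 256,
    "global_head_dim": 512,
    "sliding_window_pattern": 6,
    "num_hidden_layers": 30,
}
plan = {s.index: backend_for(s, "flash_attention_2") for s in build_layer_specs(cfg)}
fallback_layers = [i for i, backend in plan.items() if backend == "sdpa"]
print(fallback_layers)
```

Keying the decision on head_dim means the fallback stays correct even if a variant places its global layers differently.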

Impact

  • Performance: Users lose the FA2 speedup on the sliding layers (~87% of layers) by falling back to SDPA on all layers
  • Usability: Setting attn_implementation="flash_attention_2" crashes instead of gracefully degrading
  • Scope: Affects all Gemma 4 variants with hybrid sliding/global attention (26B-A4B confirmed)

Environment

  • transformers >= 4.52.0 (Gemma 4 support)
  • flash-attn 2.x (head_dim limit = 256)
  • PyTorch 2.x, CUDA

Related

  • FlashAttention head_dim limit: Dao-AILab/flash-attention#801 (closed — Tri Dao noted head_dim > 256 is nontrivial)
  • FA4 (Hopper) head_dim support: Dao-AILab/flash-attention#2318 (head_dim 256 now works on Hopper fwd, but 512 still unsupported)
  • Transformers already has dict-based attn_implementation support for some model architectures — extending this to Gemma 4 would be the cleanest path

Extent analysis

TL;DR

Implement per-layer attention dispatch to use flash_attention_2 for sliding attention layers and sdpa for global attention layers.

Guidance

  • Identify the layer type (sliding or global attention) based on the sliding_window_pattern and layer index.
  • Set attn_implementation to flash_attention_2 for sliding attention layers and sdpa for global attention layers.
  • Modify the Gemma4DecoderLayer or model initialization to support per-layer attention dispatch.
  • Verify the fix by running a forward pass with attn_implementation="flash_attention_2", confirming that no layer raises the head_dim RuntimeError, and comparing throughput against the all-SDPA fallback.
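
Part of that verification can be done without a GPU by inspecting the backend assigned to each layer. The sketch below uses stand-in layer objects built with SimpleNamespace; the attribute names head_dim and _attn_implementation mirror the issue's pseudocode and are assumptions about the real layer objects, and find_violations is a hypothetical helper.

```python
from types import SimpleNamespace

FA2_MAX_HEAD_DIM = 256  # limit reported by the RuntimeError in this issue


def find_violations(layers) -> list[int]:
    """Return indices of layers that would hit the FA2 head_dim limit."""
    return [
        i
        for i, layer in enumerate(layers)
        if layer._attn_implementation == "flash_attention_2"
        and layer.head_dim > FA2_MAX_HEAD_DIM
    ]


# Stand-in layers: every layer starts on FA2, as if the user requested it globally.
layers = [
    SimpleNamespace(
        head_dim=512 if i % 6 == 3 else 256,
        _attn_implementation="flash_attention_2",
    )
    for i in range(30)
]

before = find_violations(layers)

# Apply a per-layer fallback, then check again.
for layer in layers:
    if layer.head_dim > FA2_MAX_HEAD_DIM:
        layer._attn_implementation = "sdpa"

after = find_violations(layers)
print(before, after)
```

An empty list after dispatch means no layer would request FA2 with an unsupported head dimension; a real forward pass is still needed to confirm stability and speed.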

Example

# Pseudocode for per-layer dispatch in Gemma4Model.__init__
requested = config._attn_implementation
pattern = config.sliding_window_pattern
for i, layer in enumerate(self.layers):
    is_global = (i % pattern) == (pattern // 2)
    if is_global and requested == "flash_attention_2":
        layer._attn_implementation = "sdpa"  # FA2 incompatible (head_dim=512)
    else:
        layer._attn_implementation = requested  # respect the user's choice on all other layers

Notes

The proposed solution requires modifying the Gemma4DecoderLayer or model initialization to support per-layer attention dispatch. This change should be compatible with the existing transformers library and flash-attn version.

Recommendation

Apply per-layer attention dispatch, using flash_attention_2 for sliding attention layers and sdpa for global attention layers. This keeps the FA2 speedup where it is supported and avoids the head_dim crash on the global layers.
