transformers - ✅(Solved) Fix T5Gemma2: decoder self-attention fixed 4097-element mask at batch=1, fails on inputs >4094 tokens [1 pull request, 2 comments, 2 participants]

GitHub stats
huggingface/transformers#45521 (fetched 2026-04-20 11:58:46)
Comments: 2 · Participants: 2 · Timeline events: 10 · Reactions: 0
Timeline (top): subscribed ×4, commented ×2, mentioned ×2, cross-referenced ×1


Fix Action


PR fix notes

PR #45540: Fix cross-attention cache layer type for T5Gemma2 long inputs

Description (problem / solution / changelog)

Fixes #45521. Cross-attention in T5Gemma2ForConditionalGeneration is supposed to attend to all encoder tokens, but for inputs whose encoder length is >= sliding_window (default 4096) generation crashes with:

RuntimeError: The size of tensor a (4097) must match the size of tensor b (5018) at non-singleton dimension 3

The root cause was in T5Gemma2ForConditionalGeneration._prepare_cache_for_generation, where the cross-attention config was stripped of its sliding-window settings via del:

cross_attn_config = copy.deepcopy(self.config.get_text_config(decoder=True))
del cross_attn_config.sliding_window
del cross_attn_config.layer_types

T5Gemma2DecoderConfig declares the class-level defaults sliding_window: int | None = 4096 and layer_types: list[str] | None = None. Deleting the instance attributes therefore makes attribute lookup fall back to those class defaults, so cross_attn_config once again has sliding_window=4096.
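The fallback can be reproduced in plain Python. The sketch below uses a hypothetical ToyDecoderConfig (not the real T5Gemma2DecoderConfig) whose class-level defaults mirror the ones quoted above; the instance value 1024 is taken from the checkpoint config shown in the issue, and the 6-layer pattern is illustrative only.

```python
import copy


class ToyDecoderConfig:
    # Hypothetical stand-in: class-level defaults mirroring the annotations quoted above.
    sliding_window = 4096
    layer_types = None

    def __init__(self, sliding_window, layer_types):
        # Instance attributes shadow the class-level defaults.
        self.sliding_window = sliding_window
        self.layer_types = layer_types


decoder_cfg = ToyDecoderConfig(
    sliding_window=1024,
    layer_types=["sliding_attention"] * 5 + ["full_attention"],  # illustrative pattern
)

# The deepcopy + del pattern described above.
cross_attn_config = copy.deepcopy(decoder_cfg)
del cross_attn_config.sliding_window
del cross_attn_config.layer_types

# Lookup now falls through to the class attributes instead of meaning "no sliding window".
print(cross_attn_config.sliding_window)  # 4096
print(cross_attn_config.layer_types)     # None
print(decoder_cfg.sliding_window)        # 1024 (the original is untouched)
```

This also shows why the resulting window is 4096 rather than the checkpoint's 1024: after `del`, the value comes from the class default, not from the loaded config.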

When DynamicCache.__init__ sees sliding_window=4096 with layer_types=None, it auto-derives layer_types = ["sliding_attention"] * num_hidden_layers and instantiates a DynamicSlidingWindowLayer for every cross-attention layer. On update, those layers truncate the encoder K/V states to the last sliding_window - 1 tokens:

self.keys   = full_key_states[:,   :, -self.sliding_window + 1 :, :]
self.values = full_value_states[:, :, -self.sliding_window + 1 :, :]

So when enc_len >= 4096, the cached cross-attention keys end up with shape [..., 4095, head_dim]. After concatenation with the decoder self-attention keys in T5Gemma2MergedAttention.forward, this yields an attn_weights last dimension of 4097 regardless of the encoder length, while the attention mask still spans the full input. Hence the mismatch.

Fix

Explicitly set sliding_window to None and layer_types to full attention for every layer, instead of deleting the instance attributes.
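A sketch of the fixed pattern, again on a hypothetical toy config rather than the real classes: overwriting the attributes on the copy, instead of deleting them, means the class-level defaults are never consulted. The layer-type string is assumed here to be "full_attention"; the exact lines changed in PR #45540 are not reproduced.

```python
import copy


class ToyDecoderConfig:
    # Hypothetical stand-in with the same class-level defaults as in the description.
    sliding_window = 4096
    layer_types = None

    def __init__(self, sliding_window, layer_types, num_hidden_layers):
        self.sliding_window = sliding_window
        self.layer_types = layer_types
        self.num_hidden_layers = num_hidden_layers


def make_cross_attn_config(decoder_cfg):
    cross = copy.deepcopy(decoder_cfg)
    # Assign explicit values instead of `del`, so no class default can leak back in.
    cross.sliding_window = None
    cross.layer_types = ["full_attention"] * cross.num_hidden_layers
    return cross


decoder_cfg = ToyDecoderConfig(1024, ["sliding_attention"] * 5 + ["full_attention"], 6)
cross = make_cross_attn_config(decoder_cfg)
print(cross.sliding_window, set(cross.layer_types))  # None {'full_attention'}
```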

Tests

  • Added test T5Gemma2ModelTest::test_cross_attention_cache_is_not_sliding, which asserts that after generate() every layer of output.past_key_values.cross_attention_cache is a DynamicLayer. Confirmed that the test fails on main and passes on this branch.
  • tests/models/t5gemma2/test_modeling_t5gemma2.py passes.
  • Verified that the provided end-to-end reproducer passes after the fix:
python /tmp/transformers_bug_repro.py 
Loading weights: 100%|███████████████████████████████████████████████████████████████████████████████████| 1327/1327 [00:04<00:00, 323.79it/s]

--- target=2500 ---
[transformers] The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
OK  (input=2500, output=17)

--- target=3500 ---
OK  (input=3500, output=17)

--- target=4000 ---
OK  (input=4000, output=17)

--- target=4090 ---
OK  (input=4090, output=17)

--- target=4100 ---
OK  (input=4100, output=17)

--- target=4500 ---
OK  (input=4500, output=17)

--- target=5000 ---
OK  (input=5000, output=17)

--- target=8000 ---
OK  (input=8000, output=17)

Changed files

  • src/transformers/models/t5gemma2/modeling_t5gemma2.py (modified, +2/-2)
  • src/transformers/models/t5gemma2/modular_t5gemma2.py (modified, +2/-2)
  • tests/models/t5gemma2/test_modeling_t5gemma2.py (modified, +41/-0)


System Info

  • transformers 5.0.0 (T5Gemma 2 support shipped in #41834)
  • torch 2.8.0, CUDA 12.8.1, Python 3.12
  • Hardware: 1× NVIDIA H100 NVL 94 GB (same bug also reproduced on an A100 80 GB SXM)
  • Model: google/t5gemma-2-4b-4b (gated)
  • Base image: runpod/pytorch:1.0.3-cu1281-torch280-ubuntu2404

Who can help?

@ArthurZucker @gante — attention / generation internals for T5Gemma 2.

Reproduction

AutoModelForSeq2SeqLM.from_pretrained(..., attn_implementation="eager") followed by model.generate() at batch_size=1 raises a shape mismatch in decoder self-attention as soon as the input exceeds ~4094 tokens. Batching is not required to trigger this; batch=1 is enough.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL = "google/t5gemma-2-4b-4b"
tok = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL, dtype=torch.bfloat16, attn_implementation="eager", device_map="auto",
).eval()

filler = "| a | b | c |\n|---|---|---|\n| 1 | 2 | 3 |\n"
for n in (3500, 4090, 4100, 5000, 8000):
    prompt = "Answer.\n\nTable:\n" + filler * (n // 20) + "\n\nQ: sum?"
    ids = tok(prompt, return_tensors="pt", truncation=True, max_length=n).input_ids.to(model.device)
    try:
        with torch.no_grad():
            out = model.generate(ids, max_new_tokens=16, do_sample=False)
        print(f"OK  len={ids.shape[-1]}")
    except RuntimeError as e:
        print(f"FAIL len={ids.shape[-1]}: {e}")

Expected behavior

T5Gemma 2's decoder config advertises max_position_embeddings: 131072, so generate() should handle inputs well beyond 4K tokens at batch=1. (Qwen 2.5 7B Instruct via vLLM on the same TReB prompts handles up to the dataset max of 28,117 tokens with no issue.)

Actual behavior

On real TReB English samples (JT-LM/JIUTIAN-TReB), measured 2026-04-20:

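| input tokens | result | RuntimeError "b" value |
|---|---|---|
| 2497 | OK | n/a |
| 3525 | OK | n/a |
| 5015 | FAIL | 5018 |
| 6499 | FAIL | 6502 |
| 7493 | FAIL | 7496 |
| 9997 | FAIL | 10000 |
| 14967 | FAIL | 14970 |
| 19808 | FAIL | 19811 |
| 25135 | FAIL | 25138 |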

Error (abridged):

File ".../transformers/models/t5gemma2/modeling_t5gemma2.py", line 544, in forward
    hidden_states, _, _ = self.self_attn(
File ".../transformers/models/t5gemma2/modeling_t5gemma2.py", line 431, in forward
    attn_output, attn_weights = attention_interface(
File ".../transformers/models/t5gemma2/modeling_t5gemma2.py", line 244, in eager_attention_forward
    attn_weights = attn_weights + attention_mask
RuntimeError: The size of tensor a (4097) must match the size of tensor b (5018) at non-singleton dimension 3

Key fingerprint: tensor a (4097) is constant across every failure, regardless of actual input length; tensor b (N) equals input_length + 3 exactly (three special tokens added by the tokenizer).
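The "+3" fingerprint can be checked directly against the failing rows of the results table; the pairs below are copied from it.

```python
# (input tokens, RuntimeError "b" value) for every FAIL row in the results table.
failures = [
    (5015, 5018),
    (6499, 6502),
    (7493, 7496),
    (9997, 10000),
    (14967, 14970),
    (19808, 19811),
    (25135, 25138),
]

offsets = {b - n for n, b in failures}
print(offsets)  # {3}
```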

Hypothesis

Decoder config has:

sliding_window: 1024
_sliding_window_pattern: 6
max_position_embeddings: 131072

The constant 4097 = 4 × 1024 + 1 looks like a pre-allocated 4-window attention buffer + 1 query token. It isn't being resized when the input exceeds 4096 tokens — attn_weights keeps the fixed 4097 shape while attention_mask tracks the actual sequence length.

Looking at T5Gemma2Decoder.forward (~line 1070-1100 in modeling_t5gemma2.py), decoder self-attention masks come from create_causal_mask / create_sliding_window_causal_mask in transformers/masking_utils.py, called with self.config + inputs_embeds + position_ids. The mismatch between the (4097-wide) attention buffer and the mask that comes back from one of these helpers is probably where this originates. T5Gemma 2's decoder is also documented as using merged self+cross attention with the comment "we always need a mask during decoding for merged attention" (around line 1072) — this merging may be interacting with the sliding-window mask constructor in a way the other Gemma variants don't hit.

Workarounds tried

  • attn_implementation="sdpa" → same class of shape mismatch at longer lengths
  • attn_implementation="flash_attention_2" → ValueError: T5Gemma2ForConditionalGeneration does not support Flash Attention 2 yet (from _flash_attn_can_dispatch)
  • max_input_tokens ≤ 4094 avoids the crash but caps an advertised-128K-context model at 4K, which is much less than the sliding_window config would suggest

Full test harness + data (public)

The reproducer above is the minimal version. If it helps, the full testing we did is in a public repo — real TReB English samples, specific IDs, eval driver and monitor: https://github.com/junos-ai-org/jiutian-treb/blob/experiment-setup/experiments/t5gemma_vs_qwen_treb/insights/transformers_bug_repro.py

The samples in the table above come from the public HF dataset JT-LM/JIUTIAN-TReB (English split), so anyone can load them and reproduce the failures end-to-end without additional setup. Raw prediction outputs and error logs from our diagnostic run are in experiments/t5gemma_vs_qwen_treb/results/t5gemma_base_threshold/ (gitignored, available on request).

Related (all closed as completed)

Same Gemma-family class of bug, each independently filed and fixed:

  • huggingface/transformers#37219 — RecurrentGemma crashes past SWA width
  • huggingface/transformers#35290 — Custom 4D tensor shape mismatch (dim 3, same flavor)
  • huggingface/transformers#31931 — Gemma 2 BF16 inference
  • huggingface/transformers#41875 — Flash Attention in Seq2SeqLM.generate
  • vllm-project/vllm#14881 — Gemma 3 batch + SWA

Thanks!

Extent analysis

TL;DR

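The crash is addressed by correcting the sliding-window configuration of the cross-attention cache, so that encoder inputs of 4096 tokens or more are no longer truncated (PR #45540). Switching the attention implementation alone does not help: sdpa hits the same class of mismatch and Flash Attention 2 is not supported for this model.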

Guidance

  1. Investigate the create_causal_mask and create_sliding_window_causal_mask functions in transformers/masking_utils.py to understand how the attention masks are generated and how they interact with the sliding_window configuration.
  2. Experiment with different attention implementations, such as attn_implementation="sdpa" or other available options, to see if they can handle longer input lengths without shape mismatches.
  3. Consider modifying the sliding_window configuration to increase the window size or adjust the _sliding_window_pattern to better accommodate longer input sequences.
  4. Review the related issues listed in the problem description, such as huggingface/transformers#37219 and huggingface/transformers#35290, to see if the fixes or workarounds applied to those issues can be adapted to this case.

Example

No specific code example is provided, as the issue requires a deeper understanding of the T5Gemma 2 model's attention mechanism and the transformers library's implementation details.

Notes

The issue appears to be specific to the T5Gemma 2 model and its attention implementation, and may require modifications to the model's configuration or the transformers library's code to resolve.

Recommendation

Apply the fix from PR #45540 (or use a transformers build that includes it). Until then, the only workaround reported in the issue is keeping inputs at or below 4094 tokens.
