transformers - ✅(Solved) Fix [Gemma3] NaN embeddings on GPU when batching sequences of mixed length (sliding window attention + all-padding windows) [1 pull requests, 1 participants]

huggingface/transformers#45491 (fetched 2026-04-18 05:51:48)
Error Message

The model's sliding window size is 512. With 728 − 142 = 586 padding tokens appended, the first attention windows of Chunk 0 are 100% padding, causing softmax([-inf, -inf, …]) = 0/0 = NaN in the GPU kernel — a silent numerical error (no exception raised, no warning). Actual behavior: Chunk 0 (shortest, ~142 tokens) returns a vector of 768 NaN values when batched with longer sequences. No exception or warning is raised.
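The failure mode described above can be reproduced in isolation. The following is a minimal NumPy sketch, not the GPU kernel itself: `plain_softmax` is a hypothetical helper written for illustration, and it shows that a row in which every score is -inf yields 0/0 = NaN without raising an exception.

```python
import numpy as np

def plain_softmax(scores):
    """Row-wise softmax written out literally: exp(x) / sum(exp(x))."""
    exp = np.exp(scores)  # exp(-inf) == 0.0
    return exp / exp.sum(axis=-1, keepdims=True)

scores = np.array([
    [0.5, -np.inf, -np.inf],      # one visible key, two masked keys
    [-np.inf, -np.inf, -np.inf],  # every key masked (all-padding window)
])

# NumPy only emits a RuntimeWarning for 0/0; it is silenced here to mirror
# the silent behaviour reported for the GPU kernel.
with np.errstate(invalid="ignore"):
    probs = plain_softmax(scores)

print(probs)  # first row is a valid distribution, second row is all NaN
```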

Fix Action

Fix / Workaround

Workaround: detect NaN after encoding and re-encode the affected inputs with batch_size=1.
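One way to package this workaround is a small wrapper around the encode call. This is a sketch under assumptions: `encode_with_nan_fallback` is a hypothetical helper (not part of sentence-transformers), and it assumes the encode function returns a 2-D NumPy-compatible array, one row per input.

```python
import numpy as np

def encode_with_nan_fallback(encode_fn, texts, batch_size=64):
    """Encode in batches, then re-encode any NaN rows one input at a time.

    encode_fn(list_of_texts, batch_size=...) is assumed to return a 2-D
    array-like with one embedding per input text.
    """
    embs = np.array(encode_fn(texts, batch_size=batch_size))
    # Indices of rows that contain at least one NaN value.
    bad = np.flatnonzero(np.isnan(embs).any(axis=1))
    for i in bad:
        embs[i] = np.asarray(encode_fn([texts[i]], batch_size=1))[0]
    return embs, bad.tolist()
```

With the model from this issue, the call would look like `embs, retried = encode_with_nan_fallback(model.encode_query, chunks)`, where `retried` lists the inputs that had to be encoded again.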

PR fix notes

PR #45511: Fix NaN in Gemma3/EmbeddingGemma when batching mixed-length sequences…

Description (problem / solution / changelog)

This fixes the NaN issue when batching mixed-length sequences with sliding window attention (e.g., with EmbeddingGemma / Gemma3 models).

The patch ensures that when a query position's entire sliding window falls within the padding region (all keys masked with -inf), the attention bias/mask is adjusted to allow all keys (bias=0). This produces a valid uniform softmax distribution instead of softmax([-inf, ...]) = NaN.
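The idea can be illustrated with a small NumPy sketch. This is not the diff applied to `masking_utils.py`; `unmask_fully_masked_rows` is a hypothetical name, and the additive-bias layout (0 for visible keys, -inf for masked keys) is an assumption made for the example.

```python
import numpy as np

def unmask_fully_masked_rows(bias):
    """Set the bias to 0 for every query row whose keys are all masked (-inf)."""
    fully_masked = np.isneginf(bias).all(axis=-1, keepdims=True)
    return np.where(fully_masked, 0.0, bias)

def softmax(scores):
    """Numerically stable row-wise softmax."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

bias = np.array([
    [0.0, 0.0, -np.inf, -np.inf],          # two visible keys
    [-np.inf, -np.inf, -np.inf, -np.inf],  # window entirely in padding
])

patched = unmask_fully_masked_rows(bias)
attn = softmax(patched)
print(attn)  # second row becomes uniform instead of NaN
```

Rows that already have at least one visible key pass through unchanged, so only the degenerate all-masked rows are affected.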

This matches the intended behavior for padding positions, which are always excluded from downstream pooling/embedding computation and thus do not affect the final output. The change prevents silent numerical instability on certain GPU SDPA backends and in eager-mode attention, while keeping CPU behavior and single-sequence (batch_size=1) results unchanged.

Fixes #45491.

What does this PR do?

When a short sequence is padded to match a much longer one in a batch (common in batched inference with variable-length inputs), some query positions in the short sequence may have their entire sliding window (e.g., size 512 for Gemma3 local attention layers) fall within the padding tokens added to reach the length of the longest sequence.

This results in an attention score row of all -inf, causing softmax to produce NaN (0/0) on affected GPU kernels and in some eager implementations. The NaNs then propagate to the final embeddings (e.g., all-NaN vectors for the shortest items in a batch when using SentenceTransformer.encode_query() or similar).

The fix detects these all-masked rows in the sliding window mask construction and sets the bias to 0 for those positions, yielding a uniform attention distribution. Since these positions correspond exclusively to padding tokens, they are masked out or ignored in subsequent mean-pooling / embedding extraction steps, so the output remains correct and identical to per-example encoding.
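The pooling argument can be checked with a short NumPy sketch. It assumes a multiply-by-mask mean pooling (`masked_mean_pool` is a hypothetical helper, not the library's pooling module). Finite values at padding positions do not change the pooled vector, whereas NaN values do, because NaN multiplied by 0 is still NaN.

```python
import numpy as np

def masked_mean_pool(token_states, attention_mask):
    """Average token vectors over positions where attention_mask == 1."""
    mask = attention_mask[..., None].astype(token_states.dtype)
    summed = (token_states * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1.0, None)  # avoid division by zero
    return summed / counts

rng = np.random.default_rng(0)
states = rng.normal(size=(1, 5, 4))  # (batch, seq_len, hidden)
mask = np.array([[1, 1, 0, 0, 0]])   # last three positions are padding

baseline = masked_mean_pool(states, mask)

# Finite garbage at padding positions is zeroed out by the mask.
finite_pad = states.copy()
finite_pad[:, 2:, :] = 100.0
pooled_finite = masked_mean_pool(finite_pad, mask)

# NaN at padding positions survives the mask, since NaN * 0 is NaN.
nan_pad = states.copy()
nan_pad[:, 2:, :] = np.nan
pooled_nan = masked_mean_pool(nan_pad, mask)

print(np.allclose(baseline, pooled_finite), np.isnan(pooled_nan).all())
```

Under this pooling scheme, a uniform (finite) attention output at padding positions is harmless, while a NaN output there contaminates the whole embedding.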

This resolves the reported bug for google/EmbeddingGemma-300M (and other Gemma3 variants using sliding window attention) on GPU with mixed-length batches, without changing behavior for valid (non-padding) positions or non-sliding-window models.

Related issues / context

  • Root cause: Sliding window attention + dynamic padding in batched processing → all-padding windows → numerical instability in softmax.
  • Affects: Models with sliding_window in their config (Gemma3 local layers, window size typically 512 or similar).
  • Previously worked around by forcing batch_size=1 or re-encoding NaN results individually.

No new dependencies. No breaking changes.

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by code agents. ...

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
    → Discussed in #45491
  • Did you make sure to update the documentation with your changes? ...
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

Recommended tags (attention + Gemma-related maintainers — keep under 3 where possible):

  • @ArthurZucker (text models, attention, Gemma expertise)
  • @Cyrilvallez (text models, model loading, attention)
  • @vasqu (attention)

Changed files

  • src/transformers/masking_utils.py (modified, +8/-0)

Code Example

from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("google/EmbeddingGemma-300M")  # GPU required

chunks = [
    "ipso literam...",  # ~142 tokens
    "ipso literam...",                      # ~452 tokens
    "ipso literam...",           # ~555 tokens
    "ipso literam...",                # ~728 tokens
]

# Batch: Chunk 0 produces all-NaN on GPU
embs = model.encode_query(chunks, batch_size=64)
for i, emb in enumerate(embs):
    print(f"Chunk {i}: NaN={np.any(np.isnan(emb))}")
# Output:
# Chunk 0: NaN=True   (768/768 NaN values)
# Chunk 1: NaN=False
# Chunk 2: NaN=False
# Chunk 3: NaN=False

# Individual — all correct
for chunk in chunks:
    emb = model.encode_query([chunk], batch_size=1)
    print(f"NaN={np.any(np.isnan(emb))}")  # False for all

System Info

  • transformers: 4.45.1
  • sentence-transformers: 5.1.2
  • tokenizers: 0.20.0
  • safetensors: 0.4.5
  • PyTorch: ≥ 2.6.0
  • Serving runtime: pytorch/torchserve-kfs:0.12.0 (Python 3.9, Linux x86_64)
  • GPU: NVIDIA (CUDA, via KServe / TorchServe on Kubernetes)
  • CPU inference: not affected

Description

When encoding a batch of sentences of mixed lengths with google/EmbeddingGemma (300M) via SentenceTransformer.encode_query(), short sentences receive an all-NaN embedding vector on GPU. The same inputs encoded individually (batch_size=1) return correct, fully-normalized embeddings.

Trigger condition: the shortest sequence in the batch must be short enough that, when padded to the length of the longest sequence, one or more sliding-window attention windows fall entirely on padding tokens.

In the reported case:

  • Chunk 0: 142 tokens → padded to 728 tokens → NaN=768/768 in the batch
  • Chunks 1–3: 452 / 555 / 728 tokens → valid embeddings

The model's sliding window size is 512. With 728 − 142 = 586 padding tokens appended, the first attention windows of Chunk 0 are 100% padding, causing softmax([-inf, -inf, …]) = 0/0 = NaN in the GPU kernel — a silent numerical error (no exception raised, no warning).
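The arithmetic can be made concrete with a short sketch. It rests on two assumptions that the report does not confirm: padding is appended on the right (real tokens occupy positions 0 to real_len − 1), and each query q attends to the causal window of keys q − window + 1 through q. `fully_padded_query_positions` is a hypothetical helper.

```python
def fully_padded_query_positions(real_len, padded_len, window):
    """Query positions whose whole window [q - window + 1, q] lies in padding.

    Assumes right padding and a causal sliding window (both are assumptions
    made for this illustration).
    """
    # The window of q contains no real token once q - window + 1 >= real_len.
    first = real_len + window - 1
    return range(first, padded_len) if first < padded_len else range(0)

for real_len in (142, 452, 555, 728):
    rows = fully_padded_query_positions(real_len, padded_len=728, window=512)
    print(f"{real_len} tokens: {len(rows)} fully padded query rows")
# 142 tokens: 75 fully padded query rows
# 452 tokens: 0 fully padded query rows
# 555 tokens: 0 fully padded query rows
# 728 tokens: 0 fully padded query rows
```

Under these assumptions only the 142-token chunk has query rows with no visible key, which is consistent with Chunk 0 being the only NaN result in the batch.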

Actual behavior: Chunk 0 (shortest, ~142 tokens) returns a vector of 768 NaN values when batched with longer sequences. No exception or warning is raised.

Workaround: detect NaN after encoding and re-encode the affected inputs with batch_size=1.

Who can help?

@ArthurZucker

@Cyrilvallez


Reproduction

Steps to reproduce

from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("google/EmbeddingGemma-300M")  # GPU required

chunks = [
    "ipso literam...",  # ~142 tokens
    "ipso literam...",                      # ~452 tokens
    "ipso literam...",           # ~555 tokens
    "ipso literam...",                # ~728 tokens
]

# Batch — Chunk 0 produces all-NaN on GPU
embs = model.encode_query(chunks, batch_size=64)
for i, emb in enumerate(embs):
    print(f"Chunk {i}: NaN={np.any(np.isnan(emb))}")
# Output:
# Chunk 0: NaN=True   ← 768/768 NaN values
# Chunk 1: NaN=False
# Chunk 2: NaN=False
# Chunk 3: NaN=False

# Individual — all correct
for chunk in chunks:
    emb = model.encode_query([chunk], batch_size=1)
    print(f"NaN={np.any(np.isnan(emb))}")  # False for all

Expected behavior

Expected behavior: all embeddings in the batch are valid (no NaN).

extent analysis

TL;DR

Detect NaN values after encoding and re-encode the affected inputs with batch_size=1 to work around the issue.

Guidance

  • Identify the shortest sequence in the batch and check if its length is short enough to cause the sliding-window attention windows to fall entirely on padding tokens.
  • Verify if the softmax function is producing NaN values due to silent numerical errors when computing attention weights.
  • Implement a check for NaN values in the embedding vectors after encoding and re-encode the affected inputs individually.
  • Consider modifying the batching logic to avoid mixing sequences of significantly different lengths.
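The last point can be sketched as length bucketing. This is a heuristic, not a guaranteed fix: `bucket_by_length` is a hypothetical helper, token lengths are assumed to be known in advance (for example from the tokenizer), and the threshold assumes that keeping the padding per sequence below the window size avoids all-padding windows.

```python
def bucket_by_length(lengths, window):
    """Group input indices so that max_len - min_len < window in each bucket."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    buckets, current = [], []
    for i in order:
        # current[0] is the shortest item of the bucket being filled.
        if current and lengths[i] - lengths[current[0]] >= window:
            buckets.append(current)
            current = []
        current.append(i)
    if current:
        buckets.append(current)
    return buckets

print(bucket_by_length([142, 452, 555, 728], window=512))
# [[0, 1, 2], [3]]
```

Each bucket would then be encoded as its own batch, so the 142-token chunk is never padded out to 728 tokens.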

Example

import numpy as np

# model and chunks are defined as in the reproduction script above

embs = model.encode_query(chunks, batch_size=64)
for i, emb in enumerate(embs):
    if np.any(np.isnan(emb)):
        print(f"Chunk {i} has NaN values, re-encoding individually")
        # Store the repaired vector back so embs no longer holds the NaN row
        embs[i] = model.encode_query([chunks[i]], batch_size=1)[0]
        print(f"Re-encoded embedding: NaN={np.any(np.isnan(embs[i]))}")

Notes

This workaround costs an extra encoding pass for every affected input. According to the PR notes, the root cause is the combination of sliding window attention and dynamic padding: query rows whose window falls entirely on padding receive all -inf scores, and softmax returns NaN for them. The issue is therefore not specific to the TorchServe setup; it affects models with sliding_window in their config when mixed-length batches are run on GPU.

Recommendation

Apply the workaround by detecting NaN values and re-encoding the affected inputs individually, as this is a safe and effective way to mitigate the issue without modifying the underlying model or infrastructure.
