transformers - Gemma-4 training with FSDP2 raises `KeyError` in `Gemma4TextAttention.forward` because `shared_kv_states` is rebuilt per-layer [1 comment, 2 participants]

GitHub stats
huggingface/transformers#45663 · Fetched 2026-04-28 06:24:48
Comments: 1 · Participants: 2 · Timeline: 11 · Reactions: 0
Timeline (top): subscribed ×5, mentioned ×4, commented ×1, labeled ×1

Error Message

cast_forward_inputs=True
File ".../transformers/models/gemma4/modeling_gemma4.py", line 1218, in forward
    key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
KeyError: 0

System Info

  • transformers version: 5.6.2
  • Platform: Linux-6.8.0-1043-nvidia-x86_64-with-glibc2.35
  • Python version: 3.12.13
  • Huggingface_hub version: 1.11.0
  • Safetensors version: 0.7.0
  • Accelerate version: 1.13.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.10.0+cu129 (CUDA)
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@ArthurZucker @Cyrilvallez @3outeille

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

A 2-layer Gemma-4 from scratch (no checkpoint download), single GPU. Layer 1 is a KV-sharing layer that reads shared_kv_states[0] after layer 0 (a KV-source) was supposed to populate it.

"""`torchrun --nproc_per_node=1 repro.py` → KeyError(0)."""

import os
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh
from transformers import Gemma4TextConfig, Gemma4TextModel

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
cast = os.environ.get("CAST_FORWARD_INPUTS", "1") == "1"
print(f"cast_forward_inputs={cast}", flush=True)

# 2-layer Gemma-4: layer 0 writes shared_kv_states[0], layer 1 reads it.
#   * num_kv_shared_layers=1 — default 0 means no sharing layer, bug can't fire.
#   * layer_types must be set (default None breaks model init); both must be
#     full_attention because Gemma-4 force-overrides the last layer to full,
#     so the first must match for the sharing layer's same-type source lookup.
config = Gemma4TextConfig(
    num_hidden_layers=2,
    num_kv_shared_layers=1,
    layer_types=["full_attention", "full_attention"],
)
model = Gemma4TextModel(config).cuda().train()
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": True}
)

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, cast_forward_inputs=cast)
for layer in model.layers:  # per-layer fully_shard triggers the per-call rebuild
    fully_shard(layer, mesh=mesh, mp_policy=mp)
fully_shard(model, mesh=mesh, mp_policy=mp)

out = model(input_ids=torch.randint(0, config.vocab_size, (1, 8), device="cuda"))
print(f"PASS — last_hidden_state.shape={tuple(out.last_hidden_state.shape)}", flush=True)
dist.destroy_process_group()

Output:

cast_forward_inputs=True
File ".../transformers/models/gemma4/modeling_gemma4.py", line 1218, in forward
    key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
KeyError: 0

Expected behavior

Gemma4TextModel.forward creates shared_kv_states = {} once per forward (modeling_gemma4.py#L1668-L1669) and threads it through every decoder-layer call, so that later "sharing" layers can read the KV that earlier "source" layers wrote into it in place.

Under FSDP2 this breaks: each fully_shard-wrapped decoder layer's pre-forward traverses kwargs and rebuilds the inner dict. Layer 22's in-place write lands in an orphan; layer 25's read of shared_kv_states[22] raises KeyError: 22. The error message points at modeling_gemma4.py but the rebuild is entirely in the FSDP path.

Note that Gemma4TextModel already declares _skip_keys_device_placement = ["past_key_values", "shared_kv_states"] (modeling_gemma4.py#L1445), so this kwarg is already treated specially for device placement, but that list is neither consulted by nor passed to FSDP2.
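The intended contract can be illustrated without PyTorch. The sketch below is a toy model, not the Gemma-4 code: SourceLayer, SharingLayer, and toy_forward are hypothetical names, and strings stand in for KV tensors. It shows only the pattern described above, where one dict is created per forward, written in place by a source layer, and read by a later sharing layer.

```python
class SourceLayer:
    """Toy KV-source layer: stores its KV under its own index."""

    def __init__(self, layer_idx):
        self.layer_idx = layer_idx

    def __call__(self, hidden, shared_kv_states):
        # In-place write: visible to later layers only if they get the same dict.
        shared_kv_states[self.layer_idx] = (f"k{self.layer_idx}", f"v{self.layer_idx}")
        return hidden + 1


class SharingLayer:
    """Toy KV-sharing layer: reuses the KV written by an earlier source layer."""

    def __init__(self, kv_shared_layer_index):
        self.kv_shared_layer_index = kv_shared_layer_index

    def __call__(self, hidden, shared_kv_states):
        # Raises KeyError if the source layer's write went to a different dict.
        key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
        return hidden + 1


def toy_forward(layers, hidden=0):
    shared_kv_states = {}  # created once per forward
    for layer in layers:
        hidden = layer(hidden, shared_kv_states=shared_kv_states)
    return hidden, shared_kv_states


hidden, kv = toy_forward([SourceLayer(0), SharingLayer(0)])
```

With no wrapper between the model and its layers, both layers receive the same dict object and the read succeeds. The pattern depends entirely on object identity being preserved from one layer call to the next.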

Here's the stack trace I actually see:

File "transformers/modeling_layers.py", line 92, in __call__
    return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
File "torch/utils/checkpoint.py", line 268, in CheckpointFunction.forward
    outputs = run_function(*args)
File "transformers/models/gemma4/modeling_gemma4.py", line 1219, in Gemma4TextAttention.forward
    key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
KeyError: 22

Per-layer dict-identity dump (one line per layer per rank, same forward) shows every layer sees a different dict_id:

GEMMA4_KV_DEBUG layer_idx=22 store_full_length_kv=True               dict_id=15592302014016 dict_len=0
GEMMA4_KV_DEBUG layer_idx=24 is_kv_shared=True kv_shared_layer_index=22 dict_id=15592302338496 dict_len=0
GEMMA4_KV_DEBUG layer_idx=23 store_full_length_kv=True               dict_id=11041820011008 dict_len=0
GEMMA4_KV_DEBUG layer_idx=24 is_kv_shared=True kv_shared_layer_index=22 dict_id=11041820015872 dict_len=0

Layer 22 is correctly configured (the upstream __init__ math says store_full_length_kv=True), but its mutation is lost because the dict it writes to is not the same dict layer 24/25 reads from.
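A dump like the one above can be produced with a few lines of instrumentation. The helper below is a hypothetical sketch: the kv_debug_line name and its fields are assumptions modeled on the log format shown, not the reporter's actual code. It relies only on Python's built-in id(), which distinguishes objects that are alive at the same time within one process, so the values are only comparable within a single rank.

```python
def kv_debug_line(layer_idx, shared_kv_states, **fields):
    """Format one debug line with the identity and size of the dict a layer received."""
    extras = " ".join(f"{name}={value}" for name, value in fields.items())
    return (
        f"GEMMA4_KV_DEBUG layer_idx={layer_idx} {extras} "
        f"dict_id={id(shared_kv_states)} dict_len={len(shared_kv_states)}"
    )


shared = {}
copy_of_shared = dict(shared)  # an equal but distinct dict, like a rebuilt one

line_a = kv_debug_line(22, shared, store_full_length_kv=True)
line_b = kv_debug_line(24, copy_of_shared, is_kv_shared=True, kv_shared_layer_index=22)
print(line_a)
print(line_b)
```

If every layer on a rank prints the same dict_id, the dict is being shared as intended. Differing values, combined with dict_len=0 at a sharing layer, point to a rebuild somewhere between the model's forward and the layer.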

Root cause

_FSDPState._pre_forward (torch/distributed/fsdp/_fully_shard/_fsdp_state.py#L243-L251) under cast_forward_inputs=True:

args, kwargs = (
    _apply_to_tensors(cast_fn, args),
    _apply_to_tensors(cast_fn, kwargs),
)

_apply_to_tensors (torch/distributed/utils.py#L245-L246) rebuilds every dict it recurses into via {k: apply(v) for k, v in x.items()}, even when no tensor inside actually needs casting. As a result, every per-decoder-layer wrapper hands its layer a freshly allocated shared_kv_states: layer 22 mutates one orphan, and layer 24 reads another.
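The effect of that dict comprehension can be reproduced in isolation. The following is a simplified stand-in, not the PyTorch implementation: rebuild_dicts and wrapped_call are hypothetical names, and the stand-in handles only dicts, whereas the real function also handles tensors and other containers. It shows that an in-place write made behind such a rebuild never reaches the caller's dict.

```python
def rebuild_dicts(fn, x):
    """Simplified stand-in: rebuilds every dict it recurses into."""
    if isinstance(x, dict):
        return {k: rebuild_dicts(fn, v) for k, v in x.items()}
    return fn(x)


def wrapped_call(layer, **kwargs):
    """Mimics a per-layer pre-forward that passes kwargs through the rebuild."""
    kwargs = rebuild_dicts(lambda leaf: leaf, kwargs)
    return layer(**kwargs)


def source_layer(shared_kv_states):
    shared_kv_states[0] = ("k", "v")  # lands in the rebuilt copy
    return id(shared_kv_states)


def sharing_layer(shared_kv_states):
    return shared_kv_states[0]  # reads a different, still-empty copy


shared_kv_states = {}
seen_id = wrapped_call(source_layer, shared_kv_states=shared_kv_states)
write_was_lost = 0 not in shared_kv_states

try:
    wrapped_call(sharing_layer, shared_kv_states=shared_kv_states)
    error = None
except KeyError as exc:
    error = exc
```

Note that the identity function is applied to every leaf, so no value changes at all. The failure comes purely from the new dict object, which matches the observation that the rebuild happens even when nothing needs casting.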

extent analysis

TL;DR

The most likely fix is to modify the _apply_to_tensors function in torch/distributed/utils.py to avoid rebuilding dictionaries when no tensor inside needs casting.

Guidance

  • Identify the _apply_to_tensors function in torch/distributed/utils.py and modify it to check if any tensor inside the dictionary needs casting before rebuilding the dictionary.
  • Consider adding a special case for the shared_kv_states dictionary to avoid rebuilding it, as it is already declared in _skip_keys_device_placement in Gemma4TextModel.
  • Verify that the modified function correctly handles dictionaries with and without tensors that need casting.
  • Test the modified function with the provided reproduction script to ensure that the KeyError is resolved.

Example

import torch

def _apply_to_tensors(cast_fn, x):
    # Simplified sketch: only tensors and dicts are handled here; the real
    # function in torch/distributed/utils.py also covers other containers.
    if isinstance(x, torch.Tensor):
        return cast_fn(x)
    if isinstance(x, dict):
        new = {k: _apply_to_tensors(cast_fn, v) for k, v in x.items()}
        # Reuse the original dict when no value was replaced, so in-place
        # writes made by the callee stay visible to later layers.
        if all(new[k] is x[k] for k in x):
            return x
        return new
    return x

Notes

The provided reproduction script and stack trace suggest that the issue is caused by the rebuilding of the shared_kv_states dictionary in the _apply_to_tensors function. Modifying this function to avoid rebuilding dictionaries when no tensor inside needs casting should resolve the KeyError.

Recommendation

Apply the workaround by modifying _apply_to_tensors as described above: the root cause lies in PyTorch rather than in transformers, and the issue was reported against PyTorch 2.10.0+cu129.
