transformers - 💡(How to fix) Fix Gemma 4: Exploding pre-clip gradient norms during LoRA fine-tuning of `gemma-4-31B-it` [2 comments, 2 participants]

huggingface/transformers#45676 (fetched 2026-04-29 06:11:25)
Comments: 2 · Participants: 2 · Timeline events: 9 · Reactions: 0
Timeline (top): commented ×2, labeled ×2, mentioned ×2, subscribed ×2


System Info

Summary

Fine-tuning google/gemma-4-31B-it with a small LoRA via standard transformers.Trainer + peft on a public chat-style dataset produces pre-clip gradient norms that are 1–3 orders of magnitude larger than expected. With max_grad_norm=1.0 the actual updates are bounded, but the pre-clip values swing between ~0.5 and ~30, and on at least one step in 30 the pre-clip norm spiked to 312 with no obvious correlation to loss.

For comparison, the same training script on Qwen3-32B / Llama-3.3-70B / similar dense ~30B models produces pre-clip grad norms < 1.0 from step 0 and never exceeds ~2.0.

I've also reproduced the same pattern in two completely independent training stacks (Unsloth FastModel and a custom FSDP-2 setup), so this looks like an issue in the transformers Gemma 4 modeling code rather than any one downstream framework.

Reproduction

A single self-contained script (full source at the end of this issue):

# fresh venv, no other dependencies
pip install "torch>=2.6.0" transformers peft accelerate datasets bitsandbytes

# run on one H200/A100 (≥80 GB recommended for the 31B in bf16)
CUDA_VISIBLE_DEVICES=0 python repro_gemma4_grad_spike_hf_peft.py

Settings (deliberately mirroring what a normal user would pick — nothing exotic):

  • attn_implementation="sdpa", bf16, no quantization
  • max_seq_length = 4096, public dataset (mlabonne/FineTome-100k, 2000 rows)
  • LoRA r=16, α=32, applied to q/k/v/o/gate/up/down inside language_model
  • per_device_train_batch_size=1, gradient_accumulation_steps=8 → effective batch 8
  • lr=1e-4 cosine, warmup_steps=5, max_grad_norm=1.0
  • optim="adamw_torch", max_steps=30, logging_steps=1
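Keep one detail in mind when reading step 1 of the traces below. It assumes PEFT's default LoRA initialization (A random, B zero). Under that init, the first step sees exactly the base model, and the gradient with respect to A is zero, so the whole step-1 grad_norm comes from the B matrices.

The sketch below uses a hand-rolled LoRA linear layer (hypothetical, not PEFT's implementation). It assumes the same zero-B init and the r=16, α=32 settings above, which give an adapter scale of α/r = 2.0:

```python
import math
import torch
from torch import nn


class TinyLoRALinear(nn.Module):
    """Hypothetical minimal LoRA layer: y = W x + (alpha / r) * B(A(x)), base W frozen."""

    def __init__(self, in_f: int, out_f: int, r: int = 16, alpha: int = 32):
        super().__init__()
        self.base = nn.Linear(in_f, out_f, bias=False)
        self.base.weight.requires_grad_(False)                        # frozen base weight
        self.lora_A = nn.Linear(in_f, r, bias=False)
        self.lora_B = nn.Linear(r, out_f, bias=False)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))  # random A (assumed init)
        nn.init.zeros_(self.lora_B.weight)                            # zero B (assumed init)
        self.scaling = alpha / r                                      # 32 / 16 = 2.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))


torch.manual_seed(0)
layer = TinyLoRALinear(64, 64)
x = torch.randn(8, 64)
layer(x).pow(2).mean().backward()

# dL/dA is proportional to B^T, so it is exactly zero while B == 0; only B gets a gradient.
grad_A = layer.lora_A.weight.grad.norm().item()
grad_B = layer.lora_B.weight.grad.norm().item()
print(f"scaling={layer.scaling}  |grad A|={grad_A:.3g}  |grad B|={grad_B:.3g}")
```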

Observed behavior (full 30-step trace from one run)

STEP   LOSS         GRAD_NORM    LR
   1   4.0690       29.16        0.00e+00
   2   4.9466       11.40        2.00e-05
   3   3.5648       30.76        4.00e-05      <- early instability
   4   4.3582       23.92        6.00e-05
   5   3.3984       21.58        8.00e-05
   6   3.2813        5.39        1.00e-04
   7   3.2394        5.26        9.96e-05
   8   2.7875      312.30        9.84e-05      <-- 312 PRE-CLIP, ~67× the mean of its neighbors (steps 7 and 9)
   9   1.8185        4.11        9.65e-05
  10   1.6028        2.85        9.38e-05
  11   1.4414        2.86        9.05e-05
  12   1.4170        2.34        8.64e-05
  13   1.3700        2.62        8.19e-05
  14   1.4069        1.86        7.68e-05
  15   1.0819        1.17        7.13e-05
  16   1.1695        1.43        6.55e-05
  17   1.0539        0.95        5.94e-05
  18   1.0098        0.76        5.31e-05
  19   1.1840        0.99        4.69e-05
  20   1.0833        0.67        4.06e-05
  21   1.0692        0.95        3.45e-05
  22   0.9248        0.69        2.87e-05
  23   0.9762        0.79        2.32e-05
  24   0.8180        0.46        1.81e-05
  25   1.0037        0.63        1.36e-05
  26   0.8292        0.58        9.55e-06
  27   0.9903        0.61        6.18e-06
  28   1.1112        1.18        3.51e-06
  29   0.8287        0.52        1.57e-06
  30   1.0775        0.62        3.94e-07
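Reproducing the LR column by hand helps line spikes up against the schedule. The logged values match two phases:

- linear warmup over 5 steps, then
- half-cosine decay to step 30.

The value logged at step s equals the multiplier evaluated at s − 1. That offset is inferred from the logged numbers rather than from Trainer source, so treat this as a sketch:

```python
import math


def cosine_with_warmup(step: int, warmup: int, total: int) -> float:
    """LR multiplier: linear warmup, then half-cosine decay (the shape of HF's 'cosine' schedule)."""
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


BASE_LR, WARMUP, TOTAL = 1e-4, 5, 30

# Logged step s (1-based) appears to show the multiplier at s - 1.
lrs = {s: BASE_LR * cosine_with_warmup(s - 1, WARMUP, TOTAL) for s in range(1, TOTAL + 1)}
for s in (1, 2, 6, 7, 8, 30):
    print(f"step {s:>2}: {lrs[s]:.2e}")
```

With this mapping, the step-8 spike lands two steps after the LR peak (step 6), at about 98% of the peak rate.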

Highlights:

  • Steps 1–5: pre-clip grad norms 11–30, well above what's normal for cold-start LoRA.
  • Step 8: pre-clip norm = 312 while loss=2.79 (lower than steps 1–5). No obvious data-side cause; the surrounding steps (6, 7, 9, 10) all have grad_norm < 6.
  • Steps 11–30: settles to 0.5–3, but with intermittent bumps (e.g. step 28: 1.18 — out of place vs neighbors 0.5–0.9).

The clipping at 1.0 keeps actual optimizer updates bounded, but a "raw" pre-clip norm that swings by 70× in a single step suggests something in the backward pass is producing very loud gradients on certain inputs.
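One way to quantify "swings by 70×" consistently across runs is to divide each step's pre-clip norm by the mean of its immediate neighbors. The sketch below works on the list of dicts that the repro script saves to trainer_log_history.json (entries with "step" and "grad_norm"). The neighbor-mean metric is a choice made for this sketch, not something the Trainer computes:

```python
import json
from pathlib import Path


def neighbor_spike_ratios(history: list[dict]) -> dict[int, float]:
    """Map step -> grad_norm / mean(grad_norm of the previous and next logged steps)."""
    points = [(e["step"], e["grad_norm"]) for e in history if "grad_norm" in e]
    ratios = {}
    for i, (step, g) in enumerate(points):
        neighbors = [points[j][1] for j in (i - 1, i + 1) if 0 <= j < len(points)]
        if neighbors:
            ratios[step] = g / (sum(neighbors) / len(neighbors))
    return ratios


def load_history(path: str) -> list[dict]:
    """Read the log history JSON written at the end of the repro script."""
    return json.loads(Path(path).read_text())


# Steps 5-11 of the trace above, in the same shape as trainer.state.log_history.
history = [
    {"step": 5, "grad_norm": 21.58}, {"step": 6, "grad_norm": 5.39},
    {"step": 7, "grad_norm": 5.26}, {"step": 8, "grad_norm": 312.30},
    {"step": 9, "grad_norm": 4.11}, {"step": 10, "grad_norm": 2.85},
    {"step": 11, "grad_norm": 2.86},
]
ratios = neighbor_spike_ratios(history)
worst = max(ratios, key=ratios.get)
print(worst, round(ratios[worst], 1))
```

On a full run, the same call would be neighbor_spike_ratios(load_history("outputs/hf_peft_grad_spike_repro/trainer_log_history.json")). Entries without a grad_norm, such as the final runtime summary, are skipped.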

Expected behavior

Pre-clip grad norms should be in the same ballpark as other modern dense ~30B instruction-tuned models on the same data — single-digit to low-double-digit at most, not >100. Spikes of 70× in a single step shouldn't appear without an obvious data-side trigger.

Cross-framework: same pattern in Unsloth and a custom FSDP-2 setup

Independent reproductions of the same bug:

Unsloth FastModel.get_peft_model (single GPU, bf16, same dataset, LoRA r=16/α=32):

{'loss': 1.46,  'grad_norm':    14.98}
{'loss': 1.28,  'grad_norm':     9.78}
{'loss': 1.60,  'grad_norm':     7.64}
{'loss': 1.14,  'grad_norm':    92.10}   <-- spike #1 at step 4
{'loss': 1.54,  'grad_norm':    29.00}
{'loss': 1.61,  'grad_norm':     5.21}
{'loss': 1.36,  'grad_norm':     7.83}
{'loss': 1.36,  'grad_norm':     7.79}
{'loss': 1.51,  'grad_norm': 11140.00}   <-- spike #2 at step 9 — 1000× neighbors
{'loss': 1.27,  'grad_norm':    15.06}

Note step 9: the loss is unremarkable (1.51), but the pre-clip grad norm is 11,140 — three orders of magnitude above its neighbors.

Custom FSDP-2 (TorchTitan-style, 8×H200) with attn_implementation ∈ {flash_attention_2, flash_attention_3, sdpa}:

attn=fa2  step  0: grad_norm = 326    step 7:  63
attn=fa3  step  0: grad_norm = 332
attn=sdpa step  0: grad_norm = 9.94   step 3: 1040  step 12: 26

The kernel choice shifts where the spikes happen but never eliminates them. This consistency across stacks/kernels is what makes me think the issue is in the Gemma 4 modeling code itself, not any one trainer.

Loudest pre-clip grad norm seen across all stacks (same model, similar LoRA, same scale of data)

Setup                                     Loudest grad_norm   Spike step   Loss at that step
HF Trainer + PEFT (this repro)                          312            8                2.79
Unsloth FastModel.get_peft_model                     11,140            9                1.51
Custom FSDP-2 (TorchTitan-style, FA2)                   326            0                4.62
Custom FSDP-2 (TorchTitan-style, SDPA)                1,040            3                4.66
Custom FSDP-2 (TorchTitan-style, FA3)                   332            0                4.62

Healthy comparable models on this dataset (Qwen3-32B, Llama-3.3-70B) stay sub-1.0 throughout.

What I checked / ruled out

  • Not the final_logit_softcapping Gemma quirk: Gemma4Config correctly exposes it via text_config.final_logit_softcapping=30.0, and the model's forward applies it.
  • Not use_cache=True interfering — explicitly set to False, plus gradient checkpointing.
  • Not QLoRA dequantization noise — repro uses bf16 base, not 4-bit.
  • Not data quality — same dataset (FineTome-100k) and the same LoRA hyperparameters train cleanly on Qwen3-32B and Llama-3.3-70B.
  • Not the chat template — repro uses the model's native template shipped with google/gemma-4-31B-it.
  • Not the kernel — repro uses SDPA. We separately verified FA2 and FA3 produce the same (sometimes worse) volatility.
  • Not single-GPU vs multi-GPU — appears identically on one H200 (this repro), and on 8×H200 FSDP-2.
  • Not framework-specific — appears in HF Trainer + PEFT (this repro), in Unsloth FastModel, and in a custom FSDP-2 trainer.
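The reasoning behind the soft-capping bullet can be checked numerically. This assumes Gemma 4 applies the same final-logit soft-capping form as earlier Gemma models, logits ← cap · tanh(logits / cap), with cap = 30.0. The derivative is then 1 − tanh²(x / cap), which never exceeds 1. Soft-capping on its own can therefore only attenuate the gradient coming back from the loss, never amplify it:

```python
import math

CAP = 30.0  # text_config.final_logit_softcapping per the report


def softcap(x: float, cap: float = CAP) -> float:
    """Assumed Gemma-style soft-capping of a single logit."""
    return cap * math.tanh(x / cap)


def softcap_grad(x: float, cap: float = CAP) -> float:
    # d/dx [cap * tanh(x / cap)] = 1 - tanh(x / cap)^2, which lies in (0, 1]
    return 1.0 - math.tanh(x / cap) ** 2


for x in (0.0, 10.0, 30.0, 100.0):
    print(f"x={x:>6}: capped={softcap(x):7.3f}  d/dx={softcap_grad(x):.4f}")
```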

Where I dug in

Per-layer backward-pass instrumentation on the 31B model (registering register_full_backward_hook on every Gemma4TextDecoderLayer and recording grad_input / grad_output magnitudes) shows:

  • The gradient flowing backward is roughly flat from layer 59 down to layer 33 (sub-10k magnitudes, no growth).

  • A sharp 2.9–4.8× amplification kicks in at three specific sliding-attention layers (L33, L32, L27), all of which use the standard Gemma 4 attention path: head_dim=256, scaling=1.0, Q-RMSNorm + K-RMSNorm, GQA 2:1, sliding window 1024.

  • Drilling further with register_full_backward_hook on submodules inside those layers, the proximate cause is an asymmetry between dK and dV:

    Layer   k_norm.grad_output   v_norm.grad_output   dK / dV
    L27                663,555               26,297       25×
    L32                124,398                1,671       74×
    L33                 21,037                  252       83×

    In standard attention |dK| ≈ |dV|. Seeing dK / dV = 25–83× at multiple layers consistently is structurally unusual.
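The per-layer instrumentation described above can be sketched with plain PyTorch module hooks. This toy version records the L2 norm of grad_output[0] for every submodule of the given types. On the real model, one would instead do one of two things:

- pass the decoder-layer class named in the report (Gemma4TextDecoderLayer), or
- select the k_norm / v_norm submodules by name.

Neither substitution is exercised here, and neither is the interaction with gradient checkpointing:

```python
import torch
from torch import nn


def attach_grad_norm_hooks(model: nn.Module, module_types: tuple):
    """Record ||grad_output[0]|| (L2) for each submodule whose type is in module_types."""
    norms: dict[str, float] = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, module_types):
            def hook(mod, grad_input, grad_output, name=name):  # bind name at definition time
                g = grad_output[0]
                if g is not None:
                    norms[name] = g.detach().float().norm().item()
            handles.append(module.register_full_backward_hook(hook))
    return norms, handles


torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(16, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 1),
)
norms, handles = attach_grad_norm_hooks(model, (nn.Linear,))

x = torch.randn(4, 16, requires_grad=True)
model(x).sum().backward()
for h in handles:
    h.remove()  # detach hooks once the measurement is done

for name in sorted(norms):
    print(f"layer {name}: |grad_output| = {norms[name]:.4f}")
```

To localize an amplification, compare these per-module norms on a spike step against a normal step.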

Possible causes I haven't been able to localize:

  1. The Q-RMSNorm / K-RMSNorm + scaling=1.0 interaction may produce dK gradients that aren't being scaled correctly in backward. Other models (e.g. Qwen3) use scaling=1/√d and pre-scale Q/K in attention; that gives a different backward gradient profile.
  2. GQA (2:1) repeat_kv backward summation may be over-accumulating dK for some inputs.
  3. The layer_scalar per-layer scaling buffer (which is unique to Gemma 4) may interact oddly with backward through residual + attention.
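For reference when weighing hypotheses 1 and 2, below are the standard backward equations of single-head scaled dot-product attention with softmax scale s. Here s = 1 for the Gemma 4 layers described above, and s = 1/√256 = 1/16 for a conventional head_dim=256 layer. They follow from the chain rule and are not specific to any kernel:

```latex
\begin{aligned}
S &= s\,QK^\top, \qquad P = \operatorname{softmax}_{\text{row}}(S), \qquad O = PV,\\
dV &= P^\top dO, \qquad dP = dO\,V^\top,\\
dS_{ij} &= P_{ij}\Big(dP_{ij} - \textstyle\sum_{l} P_{il}\,dP_{il}\Big),\\
dQ &= s\,dS\,K, \qquad dK = s\,dS^\top Q.
\end{aligned}
```

dK carries the factor s and the (RMS-normalized) queries, while dV depends only on P and dO. Equal magnitudes are therefore not guaranteed by the equations themselves, and a large |dK| / |dV| ratio points at the dS·Q path.

Under GQA, each shared K/V head receives the sum of its per-query-head terms. That summation applies to dK and dV alike. For 2:1 GQA, the triangle inequality bounds a two-way sum at twice its largest term, so the summation on its own is unlikely to create a 25–83× dK/dV asymmetry.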

I'm happy to share the per-layer / per-submodule grad traces (JSON) and the offending input samples on request.

Environment

torch                2.11.0+cu130
transformers         5.6.2
peft                 0.19.1
datasets             4.8.5
accelerate           1.13.0
flash-attn           (not required; repro uses SDPA)
GPU                  H200 143 GB (CUDA_VISIBLE_DEVICES=0; also tested on 8×H200)
CUDA                 13.0
Python               3.12.3

Happy to test patches against this repro and report back. The whole training loop runs in ~14 minutes on one H200, so iteration is fast.

Who can help?

@Cyrilvallez @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

"""Minimal HF transformers + PEFT reproducer for the Gemma 4 31B exploding
gradient-norm issue.

Dependencies (single, separate venv):
    pip install "torch>=2.6.0" transformers peft accelerate datasets bitsandbytes

Run on one H200/A100 (≥80 GB recommended):
    CUDA_VISIBLE_DEVICES=0 python repro_gemma4_grad_spike_hf_peft.py
"""

from __future__ import annotations

import os
import json

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForImageTextToText,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
MODEL_ID = "google/gemma-4-31B-it"
DATASET_ID = "mlabonne/FineTome-100k"
DATASET_SLICE = "train[:2000]"
MAX_SEQ_LEN = 4096
OUTPUT_DIR = "outputs/hf_peft_grad_spike_repro"

LORA_R = 16
LORA_ALPHA = 32

PER_DEVICE_BATCH = 1
GRAD_ACCUM = 8
LR = 1e-4
WARMUP_STEPS = 5
MAX_STEPS = 30
WEIGHT_DECAY = 0.0
MAX_GRAD_NORM = 1.0


# Disable math-SDPA (OOMs on Gemma 4's head_dim=512 globals); force mem-efficient.
torch.backends.cuda.enable_math_sdp(False)
import transformers.integrations.sdpa_attention as _sdpa_mod  # noqa: E402
_sdpa_mod.use_gqa_in_sdpa = lambda *a, **kw: False


SHAREGPT_TO_OAI_ROLE = {
    "human": "user", "user": "user", "input": "user",
    "gpt": "assistant", "assistant": "assistant", "output": "assistant",
    "system": "system",
}


def to_oai(conv):
    out = []
    for msg in conv:
        role_raw = msg.get("from") or msg.get("role")
        content = msg.get("value") or msg.get("content")
        role = SHAREGPT_TO_OAI_ROLE.get(role_raw, "user")
        if content is None:
            continue
        out.append({"role": role, "content": str(content)})
    while out and out[0]["role"] not in ("system", "user"):
        out.pop(0)
    return out


def main():
    print(f"Loading tokenizer + model: {MODEL_ID}")
    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        attn_implementation="sdpa",
        trust_remote_code=True,
    )
    model.config.use_cache = False
    model.enable_input_require_grads()

    print(f"\nApplying LoRA: r={LORA_R} alpha={LORA_ALPHA}")
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=0.0,
        target_modules=r".*language_model\..*\.(?:q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)",
        task_type="CAUSAL_LM",
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.gradient_checkpointing_enable()
    model.print_trainable_parameters()

    print(f"\nLoading dataset: {DATASET_ID} ({DATASET_SLICE})")
    ds = load_dataset(DATASET_ID, split=DATASET_SLICE)

    def tokenize_row(example):
        conv = example.get("conversations") or example.get("messages") or []
        messages = to_oai(conv)
        if not messages:
            return None
        try:
            enc = tok.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=False,
                truncation=True, max_length=MAX_SEQ_LEN,
                padding="max_length", return_dict=True,
            )
        except Exception:
            return None
        input_ids = enc["input_ids"]
        attn_mask = enc["attention_mask"]
        labels = [(-100 if m == 0 else x) for x, m in zip(input_ids, attn_mask)]
        return {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}

    print("Tokenizing...")
    ds = ds.map(tokenize_row, remove_columns=ds.column_names, num_proc=4)
    ds = ds.filter(lambda x: x.get("input_ids") is not None)

    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LR,
        lr_scheduler_type="cosine",
        warmup_steps=WARMUP_STEPS,
        max_steps=MAX_STEPS,
        logging_steps=1,
        bf16=True,
        optim="adamw_torch",
        weight_decay=WEIGHT_DECAY,
        max_grad_norm=MAX_GRAD_NORM,
        save_strategy="no",
        report_to="none",
        seed=42,
        gradient_checkpointing=False,
        remove_unused_columns=False,
    )

    trainer = Trainer(model=model, args=args, train_dataset=ds)
    stats = trainer.train()

    history = trainer.state.log_history
    log_file = os.path.join(OUTPUT_DIR, "trainer_log_history.json")
    with open(log_file, "w") as f:
        json.dump(history, f, indent=2)

    print("\nSTEP   LOSS         GRAD_NORM    LR")
    for entry in history:
        if "loss" in entry and "grad_norm" in entry:
            print(f"{entry.get('step', '?'):>4}   {entry['loss']:<10.4f}   "
                  f"{entry['grad_norm']:<11.4f}  {entry.get('learning_rate', 0):.2e}")
    print(f"\nTraining runtime: {stats.metrics.get('train_runtime', 0):.1f}s")


if __name__ == "__main__":
    main()

Expected behavior

Pre-clip gradient norms during LoRA fine-tuning of gemma-4-31B-it should be in the same range as similarly-sized open instruction-tuned models on the same data — i.e. sub-1.0 in the first few steps and below ~5 throughout, with no isolated spikes orders of magnitude above neighbouring steps.

What I see instead:

  • cold-start steps 1–5 with pre-clip grad_norm in the 11–30 range
  • isolated mid-training spike of 312 (HF + PEFT) / 11,140 (Unsloth) / 1,040 (FSDP-2) at a step where the loss itself is unremarkable
  • the spike step's neighbours are all in single digits, so it's not a gradual buildup — it's a single-batch-driven outlier in the backward pass

The clipping at max_grad_norm=1.0 means actual updates stay bounded, but a "raw" pre-clip norm that swings by 70–1000× in a single step on a healthy model is not what I'd expect.

extent analysis

TL;DR

No fix has been confirmed yet. The leading suspects from the report are:

  • the attention scaling (scaling=1.0 combined with Q/K-RMSNorm);
  • the GQA repeat_kv backward;
  • the per-layer layer_scalar buffer.

None of them has been verified as the cause.

Guidance

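  1. Investigate the Q-RMSNorm / K-RMSNorm interaction: the unusual dK profile might come from Q/K-RMSNorm combined with scaling=1.0. Comparing against scaling=1/√d (as used by other models such as Qwen3) is useful only as a diagnostic. The pretrained weights expect the shipped scaling, so changing it alters the model's forward pass and is not a drop-in fix.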
  2. Check the GQA (2:1) repeat_kv backward summation: The GQA mechanism might be over-accumulating dK gradients for certain inputs, leading to the large spikes.
  3. Verify the layer scalar buffer interaction: The unique layer scalar buffer in Gemma 4 might be interacting oddly with the backward pass, causing the gradient spikes.
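  4. Compare attention implementations: the report already shows spikes with FA2, FA3 and SDPA, so the kernel choice alone does not explain them. A remaining comparison is attn_implementation="eager" on shorter sequences, to rule out fused-kernel backward numerics. Shorter sequences are needed because the math-SDPA path OOMs on the head_dim=512 global layers at 4096 tokens, per the report.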

Example

No fix has been verified yet; the reproduction script above can be used to test candidate patches.

Notes

The issue seems to be specific to the Gemma 4 31B model and is not present in other similar models. The large gradient spikes are not caused by obvious data-side issues, and the clipping at max_grad_norm=1.0 keeps actual updates bounded.

Recommendation

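Until the root cause is found, keep max_grad_norm=1.0, which already bounds the actual updates. Some tempting changes are diagnostics rather than fixes:

  • Changing LoraConfig (r, lora_alpha) only rescales the adapter update; it does not touch the base model's attention scaling.
  • Changing the attention scaling or the layer_scalar buffer alters the pretrained forward pass.

If clipped-but-noisy updates are a concern, another option is a custom guard that skips optimizer steps whose pre-clip norm is far above the recent median.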
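The step-skipping guard mentioned above could look like the following framework-agnostic sketch. The class name SpikeGuard is hypothetical, and the window of 5 and factor of 5 are arbitrary choices, not tuned values. Wiring it into HF Trainer would need a custom training loop or a Trainer subclass, which is not shown:

```python
from collections import deque
from statistics import median


class SpikeGuard:
    """Flag a step when its pre-clip grad norm exceeds factor x the median of recent accepted steps."""

    def __init__(self, window: int = 5, factor: float = 5.0, min_history: int = 3):
        self.history: deque = deque(maxlen=window)
        self.factor = factor
        self.min_history = min_history

    def should_skip(self, grad_norm: float) -> bool:
        if len(self.history) >= self.min_history and grad_norm > self.factor * median(self.history):
            return True  # keep the outlier out of history so it cannot inflate the baseline
        self.history.append(grad_norm)
        return False


# Pre-clip norms from the HF + PEFT trace in the report (steps 1-30).
hf_trace = [29.16, 11.40, 30.76, 23.92, 21.58, 5.39, 5.26, 312.30, 4.11, 2.85,
            2.86, 2.34, 2.62, 1.86, 1.17, 1.43, 0.95, 0.76, 0.99, 0.67,
            0.95, 0.69, 0.79, 0.46, 0.63, 0.58, 0.61, 1.18, 0.52, 0.62]
guard = SpikeGuard()
flagged = [step for step, g in enumerate(hf_trace, start=1) if guard.should_skip(g)]
print(flagged)
```

On the published traces, these thresholds flag exactly the annotated spike steps: step 8 in the HF trace, and steps 4 and 9 in the Unsloth trace. They would need checking on more runs before being relied on.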


