transformers - 💡(How to fix) Fix Nemotron-3-Nano-Omni: supports_gradient_checkpointing flag missing on trust_remote_code variant (1-line fix)

System Info

  • transformers version: 5.8.0
  • Platform: Linux-6.17.0-22-generic-x86_64-with-glibc2.39
  • Python version: 3.10.19
  • PyTorch version (GPU?): 2.10.0+cu128 (cuda 12.8)
  • Huggingface_hub version: 0.36.2
  • Safetensors version: 0.6.2
  • Accelerate version: 1.11.0
  • Accelerate config: not found
  • PEFT version: 0.18.1
  • TRL version: 0.29.1
  • bitsandbytes version: 0.49.2
  • GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition (96 GB)
  • Model: nvidia/Nemotron-3-Nano-Omni-30B-A3B-Reasoning-BF16 (also affects -FP8 and -NVFP4)
  • Loaded via: trust_remote_code=True

Who can help?

  • @ArthurZucker @Cyrilvallez (text models / trust_remote_code / model loading)
  • @SunMarc (Trainer; gradient checkpointing affects DPO/KTO/SFT training)
  • @BenjaminBossan @githubnemo (PEFT; affects LoRA fine-tuning users)

cc: NVIDIA team maintaining the Nemotron-3-Nano-Omni model repos

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Loading the model via trust_remote_code=True and trying to enable gradient checkpointing raises a ValueError, even though the block-level machinery (NemotronHBlock(GradientCheckpointingLayer)) is already in place.

Minimal repro:

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "nvidia/Nemotron-3-Nano-Omni-30B-A3B-Reasoning-BF16",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)

model.gradient_checkpointing_enable()
# → ValueError: NemotronHForCausalLM does not support gradient checkpointing.

The same failure occurs if you instead set gradient_checkpointing=True in TrainingArguments, DPOConfig, KTOConfig, or SFTConfig, and on the -FP8 and -NVFP4 variants, which ship the same modeling_nemotron_h.py.

Practical impact: LoRA fine-tuning the 30B-A3B Omni model on a single 96 GB GPU runs out of memory (OOM) at otherwise reasonable settings because activations cannot be checkpointed:

  • DPO at max_length=512, batch_size=1, LoRA rank 32 (attn + MLP) → OOM
  • KTO at max_length=384, batch_size=2 (KTO requires batch ≥ 2) → OOM at step 1 backward

Root cause: the trust_remote_code modeling_nemotron_h.py shipped in NVIDIA's Omni model repos is missing the supports_gradient_checkpointing = True class attribute on NemotronHPreTrainedModel. The canonical transformers/models/nemotron_h/modeling_nemotron_h.py in this repo does set the flag (around line 953); the Omni trust_remote_code copy has diverged from it.
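A toy, torch-free sketch of why one missing class attribute is enough to produce the ValueError in the repro. It is a simplification, not the transformers source: it assumes only that the base class defaults supports_gradient_checkpointing to False and that gradient_checkpointing_enable() checks it before touching any block, which matches the error message. All class names ending in Sketch are hypothetical.

```python
class PreTrainedModelSketch:
    """Toy stand-in for transformers.PreTrainedModel (hypothetical, simplified)."""

    # Models are opted out by default; each architecture must opt in.
    supports_gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        # The gate reads the (possibly inherited) class attribute and fails
        # before any block is touched, with the same message shape as the
        # error in this report.
        if not self.supports_gradient_checkpointing:
            raise ValueError(
                f"{self.__class__.__name__} does not support gradient checkpointing."
            )


class CanonicalPreTrainedSketch(PreTrainedModelSketch):
    # Mirrors the canonical transformers nemotron_h base class: opts in.
    supports_gradient_checkpointing = True


class OmniPreTrainedSketch(PreTrainedModelSketch):
    # Mirrors the diverged trust_remote_code copy: the opt-in line is
    # missing, so the gate sees the inherited False.
    pass


class CanonicalForCausalLMSketch(CanonicalPreTrainedSketch):
    pass


class OmniForCausalLMSketch(OmniPreTrainedSketch):
    pass


def try_enable(model):
    """Return None if enabling succeeds, else the ValueError message."""
    try:
        model.gradient_checkpointing_enable()
    except ValueError as exc:
        return str(exc)
    return None


print(try_enable(CanonicalForCausalLMSketch()))  # None
print(try_enable(OmniForCausalLMSketch()))
# OmniForCausalLMSketch does not support gradient checkpointing.
```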

The infrastructure is already in place; line 985 of the trust_remote_code modeling file reads:

class NemotronHBlock(GradientCheckpointingLayer):
    ...

So once the class flag is set, GradientCheckpointingLayer does the rest: it wraps the block's forward in torch.utils.checkpoint.checkpoint(...) whenever self.gradient_checkpointing is True and the module is in training mode.
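The dispatch described above can be illustrated with a torch-free toy. Assumptions: the flag is a per-instance attribute, the wrapper only applies in training mode, and torch.utils.checkpoint.checkpoint is replaced by a hypothetical recording stand-in so the control flow runs without torch. The class and helper names are illustrative, not the actual GradientCheckpointingLayer implementation.

```python
from functools import partial


class CheckpointingLayerSketch:
    """Torch-free toy of the GradientCheckpointingLayer dispatch (hypothetical).

    In the real layer the checkpoint function is torch.utils.checkpoint.checkpoint;
    here it is injected so the control flow can run without torch.
    """

    gradient_checkpointing = False  # class default, flipped per instance
    training = True                 # stands in for nn.Module.training
    _checkpoint_func = None

    def forward(self, x, scale=1):
        # Placeholder computation for a decoder block.
        return x * scale + 1

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Bind keyword arguments and pass positional inputs through the
            # checkpoint function; the real one keeps only these inputs and
            # recomputes forward during backward instead of storing intermediates.
            return self._checkpoint_func(partial(self.forward, **kwargs), *args)
        return self.forward(*args, **kwargs)


def enable_on_blocks(blocks, checkpoint_func):
    """Toy propagation step: opt every checkpoint-capable block in."""
    for block in blocks:
        if hasattr(block, "gradient_checkpointing"):
            block.gradient_checkpointing = True
            block._checkpoint_func = checkpoint_func


def make_recording_checkpoint(log):
    """Stand-in for torch.utils.checkpoint.checkpoint that logs saved inputs."""

    def checkpoint(fn, *args):
        log.append(args)
        return fn(*args)

    return checkpoint


log = []
blocks = [CheckpointingLayerSketch() for _ in range(3)]
enable_on_blocks(blocks, make_recording_checkpoint(log))
x = 1
for block in blocks:
    x = block(x, scale=2)
print(x, log)  # 15 [(1,), (3,), (7,)]
```

Each block's input is what the checkpoint function receives, which is the only activation a real checkpointed block needs to keep for backward.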

One-line diff (in NVIDIA's Omni repo modeling_nemotron_h.py):

 class NemotronHPreTrainedModel(PreTrainedModel):
     config: NemotronHConfig
     base_model_prefix = "backbone"
+    supports_gradient_checkpointing = True
     _no_split_modules = ["NemotronHBlock"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True

This matches what the canonical transformers version of NemotronHPreTrainedModel already sets in this repo.
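Until the model repos pick up the diff, one possible stopgap (my assumption, not something tested in this report) is to set the same class attribute at runtime on the loaded model's class before enabling checkpointing. Since the check resolves the attribute through ordinary class lookup, this should behave like the one-line source patch for that model. force_enable_gradient_checkpointing is a hypothetical helper name; gradient_checkpointing_kwargs is the existing keyword of gradient_checkpointing_enable().

```python
def force_enable_gradient_checkpointing(model, gradient_checkpointing_kwargs=None):
    """Hypothetical stopgap: apply the one-line fix at runtime, then enable.

    Assumes ``model`` is the base transformers model (not yet wrapped by PEFT)
    whose blocks already inherit from GradientCheckpointingLayer, as
    NemotronHBlock does per this report. Only the model's own class is modified.
    """
    model_cls = type(model)
    if not getattr(model_cls, "supports_gradient_checkpointing", False):
        # Runtime equivalent of adding supports_gradient_checkpointing = True.
        model_cls.supports_gradient_checkpointing = True
    if gradient_checkpointing_kwargs is None:
        # PyTorch recommends the non-reentrant checkpoint implementation.
        gradient_checkpointing_kwargs = {"use_reentrant": False}
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
    )
    return model


# Usage (downloads the checkpoint; not run here):
# model = AutoModelForCausalLM.from_pretrained(
#     "nvidia/Nemotron-3-Nano-Omni-30B-A3B-Reasoning-BF16",
#     trust_remote_code=True,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda:0",
# )
# force_enable_gradient_checkpointing(model)
```

If training goes through Trainer or a TRL config with gradient_checkpointing=True, setting type(model).supports_gradient_checkpointing = True on the base model before constructing the trainer should suffice, because the trainer calls gradient_checkpointing_enable() itself. Apply it before wrapping with PEFT so that type(model) is still the Nemotron class. The durable fix remains the one-line change in the model repos.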

Verified locally by applying the patch to my cached trust_remote_code copy: gradient_checkpointing_enable() then succeeds, propagates self.gradient_checkpointing = True to each NemotronHBlock, and KTO at max_length=384, batch_size=2 no longer OOMs.
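The check that the flag reached every block can be scripted. A minimal sketch, assuming the blocks are discoverable through model.modules() (as on any torch.nn.Module) and can be matched by the class name NemotronHBlock from the report; checkpointing_status is a hypothetical helper.

```python
def checkpointing_status(model, block_class_name="NemotronHBlock"):
    """Return (num_blocks, num_blocks_with_checkpointing_on).

    Works on anything with a modules() method, such as a torch.nn.Module.
    Matching by class name avoids importing the trust_remote_code module.
    """
    blocks = [m for m in model.modules() if type(m).__name__ == block_class_name]
    enabled = sum(1 for m in blocks if getattr(m, "gradient_checkpointing", False))
    return len(blocks), enabled


# Usage after enabling checkpointing on the loaded model:
# total, enabled = checkpointing_status(model)
# assert total > 0 and enabled == total, (total, enabled)
```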

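Affected model repos: the BF16, FP8 and NVFP4 variants of nvidia/Nemotron-3-Nano-Omni-30B-A3B-Reasoning (see System Info).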

Since this is a trust_remote_code modeling file, the fix needs to land in the model repos (not this transformers repo). Filing here because (a) the canonical implementation is here and is correct, so this is a divergence bug, and (b) NVIDIA's team that ships these repos is active in this repo's PR thread.

Expected behavior

model.gradient_checkpointing_enable() should succeed (matching the behavior of the canonical transformers/models/nemotron_h implementation) and propagate self.gradient_checkpointing = True to each NemotronHBlock. The blocks already inherit from GradientCheckpointingLayer, so once the flag is set the standard HF gradient-checkpointing path activates: each block's forward pass is wrapped in torch.utils.checkpoint.checkpoint(...), its activations are recomputed during backward, and peak memory should drop by roughly 30-50%.

Concretely, this unblocks LoRA fine-tuning the 30B-A3B Omni model on a single 96GB GPU at sensible defaults (max_length up to 512, batch_size up to 2 for KTO).
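Where the saving comes from can be written as a rough, standard estimate (a back-of-the-envelope model, not a measurement from this report). Let L be the number of NemotronHBlocks, a_i the activations block i keeps for backward, and b_i the size of its input hidden states. Without checkpointing every a_i is held until backward; with per-block checkpointing only the block inputs are held, plus one block's activations while it is being recomputed:

$$
M_{\text{act}}^{\text{no ckpt}} \approx \sum_{i=1}^{L} a_i,
\qquad
M_{\text{act}}^{\text{ckpt}} \approx \sum_{i=1}^{L} b_i + \max_{i} a_i .
$$

Since a_i is usually much larger than b_i (MLP/expert intermediates and mixer internals), activation memory shrinks substantially, while weights, LoRA optimizer state and gradients are unaffected; the end-to-end drop in peak memory therefore depends on what share of the peak is activations.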
