transformers - 💡(How to fix) Fix FSDP + KD-teacher-wrap + flash-attn-2 + LoRA: working-set memory exceeds 40 GiB per rank on production-shape SFT training (deepseek-coder-6.7b + sahil2801/CodeAlpaca-20k) [1 comments, 2 participants]

GitHub stats: huggingface/transformers#45941 (fetched 2026-05-14 03:28:39) · 1 comment · 2 participants · 0 reactions · 2 timeline events (closed ×1, commented ×1)


Summary

We're running supervised fine-tuning of deepseek-ai/deepseek-coder-6.7b-instruct with LoRA (PEFT) plus knowledge distillation (KD); the teacher has the same architecture as the student and is loaded as a full bf16 replica. Training runs through the HF Trainer's FSDP full_shard integration with attn_implementation="flash_attention_2", bf16, and world_size=8 on the AWS SageMaker HuggingFace DLC. The step-0 forward pass exhausts per-rank GPU memory on every hardware class we evaluated, including A100 40 GiB (ml.p4d.24xlarge, 8 ranks). Across 23 attempts spanning two DLC versions, three instance types, and progressively tuned settings (FSDP wrap of both student AND teacher, flash-attention-2 with sdpa fallback, non-reentrant gradient checkpointing, bf16 + bf16_full_eval, PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, FSDP-owned activation_checkpointing), the production-shape working set on FSDP world_size=8 still exceeds the ~40 GiB available per rank.

We'd like HuggingFace's input on the recommended FSDP integration pattern for KD training in which both student and teacher are sharded coherently, and on whether DeepSpeed ZeRO-3 (via the HF Trainer's deepspeed= integration) is the recommended fallback for this class of recipe.

Reproducer

The recipe is documented across two private repositories:

  • config/training/codex_02_sft.yaml (per-model hyperparameters: batch_size=4, max_length=4096, alpha_kd=0.7, alpha_attention=0.2, alpha_hidden=0.1, fsdp_enabled=true, fsdp_transformer_layer_cls_to_wrap=LlamaDecoderLayer, fsdp_wrap_teacher=true, attn_implementation=flash_attention_2, dataset = sahil2801/CodeAlpaca-20k via HF Hub).
  • d3n_training/entry_points/sagemaker_entry_v2.py::train_causal_lm: builds the TrainingArguments with fsdp="full_shard auto_wrap", fsdp_config={"transformer_layer_cls_to_wrap": "LlamaDecoderLayer"}, bf16=True, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, and eval_strategy="no", then instantiates a custom KDTrainer (a Trainer subclass). KDTrainer loads the teacher via transformers.AutoModelForCausalLM.from_pretrained(..., torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") and wraps it with torch.distributed.fsdp.FullyShardedDataParallel(...) before the first training step.
  • d3n_training/training/distillation.py::load_teacher_model and fsdp_wrap_teacher_model: the teacher load + FSDP wrap path.

The training loop is a standard Trainer.train() call; no custom collator or sampler.
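For comparison against other recipes, the student-side TrainingArguments described above can be written down as a plain keyword-argument dict, which can be diffed without importing transformers. This is a sketch reconstructed from the bullets, not a copy of sagemaker_entry_v2.py; the output_dir value and the mapping of the YAML batch_size onto per_device_train_batch_size are assumptions.

```python
"""Student-side TrainingArguments kwargs, reconstructed from the reproducer bullets."""

# Assumptions: "out/codex-02-sft" is a hypothetical placeholder, and the YAML
# batch_size=4 is assumed to map onto per_device_train_batch_size.
TRAINING_ARGS_KWARGS = {
    "output_dir": "out/codex-02-sft",
    "per_device_train_batch_size": 4,
    "bf16": True,
    # FSDP full sharding, auto-wrapping each LlamaDecoderLayer as its own unit.
    "fsdp": "full_shard auto_wrap",
    "fsdp_config": {"transformer_layer_cls_to_wrap": "LlamaDecoderLayer"},
    # Non-reentrant activation checkpointing.
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
    # Name accepted by the transformers 4.56.2 DLC (evaluation_strategy raised a TypeError there).
    "eval_strategy": "no",
}


def describe(kwargs: dict) -> str:
    """Render kwargs as sorted key=value lines so two recipes can be diffed."""
    return "\n".join(f"{key}={kwargs[key]!r}" for key in sorted(kwargs))


if __name__ == "__main__":
    print(describe(TRAINING_ARGS_KWARGS))
```

With transformers installed, TrainingArguments(**TRAINING_ARGS_KWARGS) should build the student-side configuration; the teacher's FSDP wrap happens inside KDTrainer and is not captured here.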

Hardware tested

| Hardware | SM | Per-rank capacity | World size | Result |
|---|---|---|---|---|
| A10G (ml.g5.12xlarge) | 86 | 22.30 GiB | 4 | OOM at KD-teacher load (Appendix L); resolved by FSDP-wrapping the teacher → OOM at eager softmax (M); resolved by flash-attn-2 → OOM at autocast convert_to_fp32 (N); resolved by autocast tuning → working-set ceiling at ~21 GiB allocated (O) |
| A10G (ml.g5.48xlarge) | 86 | 22.30 GiB | 8 | Non-deterministic SIGSEGV (exitcode -11) at training step 0 with no Python traceback, ~5 s after the "Starting causal LM training..." log line (Appendices R/S); reproduced on different ranks across cycles |
| A10G (ml.g5.48xlarge) | 86 | 22.30 GiB | 8 | Pre-NCCL CUDA OOM on all 8 ranks during AutoModelForCausalLM.from_pretrained of the teacher with fsdp_wrap_teacher=False (Appendix T): the 2× transient in accelerate.set_module_tensor_to_device, the pre-FSDP student replica, the flash-attn workspace, and the lost async coalescing under CUDA_LAUNCH_BLOCKING=1 together exceed 22.30 GiB |
| A100 (ml.p4d.24xlarge) | 80 | 39.49 GiB | 8 | torch.OutOfMemoryError on rank 3 at the step-0 forward, with ~39.17 GiB of 39.49 GiB in use (35.14 GiB allocated by PyTorch + 2.84 GiB reserved but unallocated); ranks 6 and 7 also failed on the same 344 MiB step-0 forward allocation. This implicitly rules out the A10G SIGSEGV class: the same algorithm-side code path ran cleanly on A100 (SM 80) through Downloading → Training before OOM'ing under memory pressure. |

DLC versions tested

  • huggingface-pytorch-training:2.5.1-transformers4.49.0-gpu-py311-cu124-ubuntu22.04-v1.0 (torch 2.5.1, NCCL ~2.20, cu124): the A10G SIGSEGV case (Appendices R/S).
  • huggingface-pytorch-training:2.8.0-transformers4.56.2-gpu-py312-cu129-ubuntu22.04-v1.1 (torch 2.8, NCCL ~2.24+, cu129): the A10G TypeError on evaluation_strategy (Appendix U; resolved by renaming it to eval_strategy) and the A100 working-set OOM (Appendix V).

Both DLCs hit a memory ceiling at the step-0 forward pass when paired with the production-shape recipe; the new DLC clears the SIGSEGV class on A100 (and probably on A10G as well, though that is untested).
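Since the 4.56.2 DLC rejects evaluation_strategy with a TypeError, one defensive option when a config must run on both DLCs is to normalise legacy keys before building TrainingArguments. The helper and its mapping are hypothetical, and the sketch assumes the older 4.49.0 image also accepts eval_strategy, which is worth confirming before relying on it.

```python
"""Rename legacy TrainingArguments keywords before constructing TrainingArguments."""

# Hypothetical mapping; only the rename observed in this issue is listed.
LEGACY_RENAMES = {"evaluation_strategy": "eval_strategy"}


def rename_legacy_kwargs(kwargs: dict) -> dict:
    """Return a copy of kwargs with legacy keys renamed to their current names.

    Raises ValueError if both the legacy and the current key are present with
    different values, since silently picking one would hide a config error.
    """
    out = dict(kwargs)
    for old, new in LEGACY_RENAMES.items():
        if old not in out:
            continue
        value = out.pop(old)
        if new in out and out[new] != value:
            raise ValueError(f"conflicting values for {old!r} and {new!r}")
        out[new] = value
    return out


if __name__ == "__main__":
    print(rename_legacy_kwargs({"evaluation_strategy": "no", "bf16": True}))
    # {'bf16': True, 'eval_strategy': 'no'}
```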

Memory budget analysis (per-rank, Appendix V verbatim)

Per-rank working-set components at step-0 forward on A100 40 GiB × 8:

| Component | Memory |
|---|---|
| Student FSDP shard (6.7B params × 2 bytes bf16 / 8 ranks) | ~1.7 GB (≈1.56 GiB) |
| Teacher FSDP all-gather buffer (transient peak during KD forward) | ~13.4 GB (≈12.5 GiB) |
| Per-layer activations (B=4, S=4096, H=4096, 32 layers; flash-attn-2 + non-reentrant gradient checkpointing) | ~2-5 GiB |
| Autocast / loss-path transients | ~1-4 GiB |
| Optimizer-state slabs (Accelerator.prepare; not yet sharded on the step-0 first pass) | ~3-6 GiB |
| Gradient buffers (in-flight backward) | ~2-4 GiB |
| NCCL/CUDA pinned memory + flash-attn workspace + cuBLAS handles | ~1-3 GiB |
| Observed total | ~35.14 GiB allocated + 2.84 GiB reserved but unallocated; OOM on a 344 MiB step-0 forward intermediate |
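The activation and loss-path rows can be sanity-checked with back-of-the-envelope arithmetic. This is a minimal sketch assuming that non-reentrant gradient checkpointing keeps only each decoder layer's bf16 input hidden state (B × S × H) alive between forward and backward, and assuming a vocabulary size of 32256 for deepseek-coder-6.7b (check the model's config.json). Both figures are lower bounds: the layer being recomputed during backward and its attention/MLP intermediates come on top.

```python
"""Back-of-the-envelope per-rank activation and logits sizes."""

GIB = 2**30
BF16_BYTES = 2
FP32_BYTES = 4
VOCAB_SIZE = 32256  # assumption: deepseek-coder-6.7b vocab size; verify in config.json


def checkpointed_inputs_gib(batch: int, seq: int, hidden: int, layers: int) -> float:
    """GiB held by the saved bf16 input of every checkpointed decoder layer."""
    return batch * seq * hidden * layers * BF16_BYTES / GIB


def logits_gib(batch: int, seq: int, vocab: int, bytes_per_elem: int) -> float:
    """GiB for one full-vocabulary tensor (logits, softmax, ...) of shape B x S x V."""
    return batch * seq * vocab * bytes_per_elem / GIB


if __name__ == "__main__":
    # Production shape from the table: B=4, S=4096, H=4096, 32 layers.
    print(checkpointed_inputs_gib(4, 4096, 4096, 32))   # 4.0
    print(logits_gib(4, 4096, VOCAB_SIZE, BF16_BYTES))  # 0.984375
    print(logits_gib(4, 4096, VOCAB_SIZE, FP32_BYTES))  # 1.96875
    # Reduced shape used in the workaround: B=1, S=2048.
    print(checkpointed_inputs_gib(1, 2048, 4096, 32))   # 0.5
```

Under these assumptions the checkpointed inputs come to about 4 GiB, inside the table's 2-5 GiB band, and each full-vocabulary tensor in the KD loss costs roughly 1 GiB in bf16 or 2 GiB in fp32. A loss that holds student and teacher logits plus fp32 softmax/log-softmax copies could therefore fill the loss-path row on its own. The reduced shape cuts the checkpointed inputs 8×, matching its 8× smaller token count per step.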

Our working hypothesis: the dominant peak is the FSDP-wrapped teacher's pre-forward all-gather stacked on optimizer-state initialisation that has not yet been sharded. FSDP's optim_state_dict is built at the end of step 0, so the first forward+backward pays full-replica overhead on optimizer state before the shards rebalance. This is a documented FSDP first-step memory peak, but it is rarely hit because most recipes don't also hold a KD teacher's full-replica all-gather at the same time.
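Two pieces of arithmetic bear on this hypothesis. First, the ~13.4 figure in the teacher all-gather row equals 6.7B params × 2 bytes = 13.4 × 10^9 bytes (≈12.5 GiB), i.e. the whole bf16 teacher, which is what a single FSDP unit covering the entire model would all-gather; if the FullyShardedDataParallel(...) call on the teacher passes no auto-wrap policy, that would be the expected peak, whereas per-LlamaDecoderLayer units bound it to roughly one layer (two with prefetching). Second, with LoRA the optimizer state covers only the adapter weights (fp32, per the fp32-LoRA × bf16-base mixing mentioned under "What we tried"), and PyTorch's AdamW allocates its exp_avg/exp_avg_sq buffers lazily on the first optimizer.step(), after the step-0 backward. The sketch assumes equal-sized decoder layers (ignoring embeddings and lm_head) and a hypothetical LoRA setup of r=16 on four 4096×4096 projections per layer; the issue does not state the actual LoRA config.

```python
"""Arithmetic behind the teacher all-gather and optimizer-state rows."""

GIB = 2**30
BF16_BYTES = 2
PARAMS = 6.7e9  # model size from the issue
LAYERS = 32     # decoder layers from the issue
HIDDEN = 4096   # hidden size from the issue


def allgather_gib(total_params: float, n_units: int, units_in_flight: int = 1) -> float:
    """Unsharded bf16 parameter GiB alive while `units_in_flight` equal-sized FSDP
    units are all-gathered (each rank materialises the full unit, whatever world_size is)."""
    return total_params / n_units * units_in_flight * BF16_BYTES / GIB


def lora_params(rank: int, dim: int, modules_per_layer: int, layers: int) -> int:
    """Trainable LoRA parameters for square dim x dim projections: A is r x dim, B is dim x r."""
    return rank * (dim + dim) * modules_per_layer * layers


def adam_state_gib(trainable: int, world_size: int = 1) -> float:
    """GiB for AdamW exp_avg + exp_avg_sq in fp32 (8 bytes/param), split across ranks."""
    return trainable * 8 / world_size / GIB


if __name__ == "__main__":
    print(round(allgather_gib(PARAMS, 1), 2))          # whole teacher as one unit: 12.48
    print(round(allgather_gib(PARAMS, LAYERS), 2))     # one decoder layer: 0.39
    print(round(allgather_gib(PARAMS, LAYERS, 2), 2))  # two layers with prefetch: 0.78
    n = lora_params(rank=16, dim=HIDDEN, modules_per_layer=4, layers=LAYERS)  # hypothetical config
    print(n, adam_state_gib(n), adam_state_gib(n, world_size=8))  # 16777216 0.125 0.015625
```

Under these assumptions, per-decoder-layer units shrink the teacher transient from about 12.5 GiB to well under 1 GiB, and the LoRA optimizer state is on the order of 0.1 GiB in total. Note that the all-gather is not divided by world_size (every rank rebuilds the full unit); what bounds it is the size of each FSDP unit.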

Question for HuggingFace

  1. Is there a recommended Trainer + FSDP integration pattern for KD training that shards BOTH student AND teacher coherently, for example a single composite FSDP unit wrapping both models so that the teacher's all-gather is sized down by world_size?
  2. Alternatively, is HF Trainer's deepspeed= integration with ZeRO-3 (with optimizer-state CPU offload) the recommended pattern for this class of recipe? Are there documented gotchas with PEFT + custom data-collators + flash-attn-2 + ZeRO-3 that we should anticipate before doing the migration?
  3. The reduced-shape config (batch_size=1, max_length=2048, KD off, aux-heads off) on a single A10G via DDP + LoRA + non-reentrant gradient checkpointing runs cleanly on the same DLC + tarball. We're confident the recipe and chain are correct under HF's standard patterns; we want HF's input on the FSDP × KD-teacher × LoRA × flash-attn-2 × bf16 working-set ceiling we're hitting.
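For question 2, a minimal ZeRO-3 configuration with optimizer-state CPU offload might look like the following. This is a sketch, not a validated config for this recipe: the "auto" values rely on the HF Trainer's DeepSpeed integration filling them in from TrainingArguments, and it does not address how the KD teacher would be partitioned, which is the open question here.

```python
"""Minimal DeepSpeed ZeRO-3 config (optimizer state offloaded to CPU) as a Python dict."""
import json

DS_ZERO3_CONFIG = {
    "bf16": {"enabled": "auto"},
    "zero_optimization": {
        "stage": 3,
        # Offload optimizer state to CPU; pinned host memory speeds up transfers.
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
        "overlap_comm": True,
        "contiguous_gradients": True,
        # Gather full 16-bit weights when the Trainer saves the model.
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    # "auto" lets the HF Trainer integration copy these from TrainingArguments.
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
}


def dump_config(config: dict) -> str:
    """Serialise the config so it can be written to a JSON file for the launcher."""
    return json.dumps(config, indent=2, sort_keys=True)


if __name__ == "__main__":
    print(dump_config(DS_ZERO3_CONFIG))
```

The dict could be passed as TrainingArguments(..., deepspeed=DS_ZERO3_CONFIG), since that argument accepts either a dict or a JSON file path; the fsdp and fsdp_config arguments would be dropped, as FSDP and DeepSpeed are alternative backends in the Trainer.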

What we tried (Appendices L → V, in order, with verbatim CloudWatch evidence in the runbook)

  1. FSDP plumbing in HF Trainer via fsdp="full_shard auto_wrap" + fsdp_config={"transformer_layer_cls_to_wrap": "LlamaDecoderLayer"} — partially resolved.
  2. Manual FSDP wrap of the KD teacher (path b''). This is necessary because the Trainer's FSDP integration shards only the trainer's student model, leaving the teacher as a full bf16 replica on every rank.
  3. attn_implementation="flash_attention_2" on both student and teacher (with sdpa fallback): resolved the 8 GiB activation OOM from eager softmax in dtype=torch.float32.
  4. PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True: neutralised allocator fragmentation (2.21 → 0.47 GiB tail).
  5. TrainingArguments(bf16_full_eval=True) + AutocastKwargs(enabled=False): broke PEFT's fp32-LoRA × bf16-base dtype mixing; reverted.
  6. FSDP-owned activation_checkpointing (post-PR migration from the Trainer.gradient_checkpointing path) + gradient_checkpointing_kwargs={"use_reentrant": False}.
  7. Hardware pivot A10G → A100 (ml.g5.48xlarge → ml.p4d.24xlarge).
  8. DLC upgrade transformers 4.49.0 → 4.56.2 (which surfaced the evaluation_strategy → eval_strategy rename).
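Item 4 only takes effect if PYTORCH_CUDA_ALLOC_CONF is in the environment before PyTorch's CUDA caching allocator initialises, so it is safest to set it before importing torch (or in the launcher's environment). A small sketch that adds the option without clobbering other comma-separated allocator settings already present; the helper name is hypothetical.

```python
"""Enable expandable_segments in PYTORCH_CUDA_ALLOC_CONF without dropping existing options."""
import os
import sys

ENV_KEY = "PYTORCH_CUDA_ALLOC_CONF"
OPTION = "expandable_segments:True"


def enable_expandable_segments(environ=os.environ) -> str:
    """Add expandable_segments:True to the comma-separated allocator config.

    Call this before `import torch`; changing the variable after the CUDA
    allocator has initialised is not expected to have any effect.
    """
    if "torch" in sys.modules:
        print("warning: torch is already imported; the setting may come too late", file=sys.stderr)
    parts = [p for p in environ.get(ENV_KEY, "").split(",") if p]
    # Drop any existing expandable_segments entry so this one wins.
    parts = [p for p in parts if not p.startswith("expandable_segments:")]
    parts.append(OPTION)
    environ[ENV_KEY] = ",".join(parts)
    return environ[ENV_KEY]


if __name__ == "__main__":
    print(enable_expandable_segments())
```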

All cycles are documented in the Stage-3 runbook (in a private repo); the FailureReason strings quoted above are verbatim, so the corresponding CloudWatch traces can be matched against them.

Workaround in use

We've switched the codex-02 SFT cycle to a reduced-shape config (KD disabled, no aux-heads, batch_size=1, max_length=2048, single-A10G DDP + LoRA). It is a valid HF Trainer cycle that completes cleanly and exercises the full SageMaker training → merge → IC rotation chain end-to-end. It does lose the representation-learning signal the production recipe was designed around (the KD teacher and aux-heads), so we're treating it as a partial production rollout while we work out the working-set ceiling for the full recipe.

Asks

  • Any pointers to a working HF Trainer + FSDP + KD-teacher pattern we can compare against.
  • Confirmation (or rebuttal) of the FSDP first-step optimizer-state memory peak hypothesis.
  • Guidance on whether to migrate to deepspeed= ZeRO-3 for this recipe class.

Thanks!
