transformers - 💡 How to fix: [Bug] Catastrophic gradient explosion (NaN) in RLHF with Qwen3.5 due to 3D position_ids forcing SDPA Math fallback and BF16 collapse [3 comments, 3 participants]

GitHub stats
huggingface/transformers#44928 (fetched 2026-04-08 01:16:58)
Comments: 3 · Participants: 3 · Timeline events: 12 · Reactions: 0
Timeline (top): mentioned ×4, subscribed ×4, commented ×3, labeled ×1

System Info


  • transformers version: 5.3.0
  • Platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
  • Python version: 3.11.15
  • Huggingface_hub version: 1.7.1
  • Safetensors version: 0.7.0
  • Accelerate version: 1.13.0
  • Accelerate config: not found
  • DeepSpeed version: 0.18.8
  • PyTorch version (accelerator?): 2.10.0+cu128 (CUDA)
  • Using distributed or parallel set-up in script?: <fill in>
  • Using GPU in script?: <fill in>
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

We have fully isolated the issue and provided a reproducible repository here: 👉 https://github.com/ouroborosscr/Report-the-gradient-explosion-of-qwen3.5

1. Verify the SDPA degradation under DAPO-style input

CUDA_LAUNCH_BLOCKING=1 CUDA_HOME=/usr/local/cuda-12.8 LD_PRELOAD=/opt/anaconda3/envs/scr_train2/lib/libstdc++.so.6 CUDA_VISIBLE_DEVICES=2 torchrun --nproc_per_node=1 audit_qwen_hf.py

Output:

============================================================
🔍 [Source-level audit] What does Qwen3.5 (SDPA mode) actually pass to the backend?
============================================================
➡️ Simulating DAPO training input, sequence length: 8192

🚨 [Hook succeeded] Qwen3.5 is calling PyTorch SDPA!
   Query shape: torch.Size([1, 8, 8192, 128])
   ⚠️ Mask shape passed in by Qwen: torch.Size([1, 1, 8192, 8192])
   ⚠️ Mask dtype: torch.bool
   💀 This mask matrix alone occupies GPU memory: 0.125000 GB
   is_causal argument: False
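The memory figure in the audit can be checked with plain arithmetic: a dense [batch, 1, seq_len, seq_len] mask costs batch × seq_len² × (bytes per element). The sketch below uses the standard storage sizes of the PyTorch dtypes involved; `dense_mask_bytes` is a hypothetical helper name. By this arithmetic a torch.bool mask of shape [1, 1, 8192, 8192] occupies 64 MiB (0.0625 GiB), so the script's 0.125 GB figure appears to assume 2 bytes per element (the size of a bf16/fp16 additive mask). The audit script's source is not reproduced here, so treat that reading as an inference.

```python
# Bytes per element for the mask dtypes discussed here (standard storage sizes).
BYTES_PER_ELEMENT = {"bool": 1, "bfloat16": 2, "float16": 2, "float32": 4}


def dense_mask_bytes(batch: int, seq_len: int, dtype: str, heads: int = 1) -> int:
    """Memory taken by a dense [batch, heads, seq_len, seq_len] attention mask."""
    return batch * heads * seq_len * seq_len * BYTES_PER_ELEMENT[dtype]


if __name__ == "__main__":
    for dtype in ("bool", "bfloat16", "float32"):
        nbytes = dense_mask_bytes(batch=1, seq_len=8192, dtype=dtype)
        # Report in GiB (2**30 bytes) to avoid GB/GiB ambiguity.
        print(f"{dtype:>9}: {nbytes / 2**30:.4f} GiB")
```

Because the cost grows with seq_len², the same bool mask at 32K tokens would already take 1 GiB per batch element.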

2. Verify the gradient explosion

CUDA_LAUNCH_BLOCKING=1 CUDA_HOME=/usr/local/cuda-12.8 LD_PRELOAD=/opt/anaconda3/envs/scr_train2/lib/libstdc++.so.6 CUDA_VISIBLE_DEVICES=2 torchrun --nproc_per_node=1 train_dapo_3_debug.py  --use_lora --use_4bit

Output:

……
💥 [Blow-up located] The gradient explodes immediately after the backward pass through 【model.layers.30.linear_attn.norm】!
   Max outgoing gradient: 26240.0
⚠️ [Numerical explosion warning] Layer: model.layers.27.self_attn.v_proj | Grad Max: 17280.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.27.self_attn.k_norm | Grad Max: 135168.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.27.self_attn.k_proj | Grad Max: 111616.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.27.self_attn.q_norm | Grad Max: 126976.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.27.self_attn.q_proj | Grad Max: 97280.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.23.self_attn | Grad Max: 67072.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.23.self_attn.o_proj | Grad Max: 67072.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.23.self_attn.v_proj | Grad Max: 13172736.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.23.self_attn.k_norm | Grad Max: 62390272.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.23.self_attn.k_proj | Grad Max: 45613056.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.23.self_attn.q_norm | Grad Max: 297795584.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.23.self_attn.q_proj | Grad Max: 242221056.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.19.self_attn | Grad Max: 132120576.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.19.self_attn.o_proj | Grad Max: 132120576.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.19.self_attn.v_proj | Grad Max: 48103633715200.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.19.self_attn.k_norm | Grad Max: 68444598829056.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.19.self_attn.k_proj | Grad Max: 54150947667968.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.19.self_attn.q_norm | Grad Max: 141836999983104.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.19.self_attn.q_proj | Grad Max: 204509162766336.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.15.self_attn | Grad Max: 131391639519232.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.15.self_attn.o_proj | Grad Max: 131391639519232.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.15.self_attn.v_proj | Grad Max: 1549526502191602335744.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.15.self_attn.k_norm | Grad Max: 1752440687002407403520.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.15.self_attn.k_proj | Grad Max: 2822351843277561397248.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.15.self_attn.q_norm | Grad Max: 2600990914393046777856.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.15.self_attn.q_proj | Grad Max: 5902958103587056517120.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.11.self_attn | Grad Max: 2951479051793528258560.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.11.self_attn.o_proj | Grad Max: 2951479051793528258560.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.11.self_attn.v_proj | Grad Max: 4584246707978673830485819392.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.11.self_attn.k_norm | Grad Max: 7969239002899635519663112192.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.11.self_attn.k_proj | Grad Max: 14932651723879899565970685952.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.11.self_attn.q_norm | Grad Max: 40852021296417549071671099392.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.11.self_attn.q_proj | Grad Max: 41470991316060239209120661504.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.7.self_attn | Grad Max: 52612451669628661683212779520.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.7.self_attn.o_proj | Grad Max: 52612451669628661683212779520.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.7.self_attn.v_proj | Grad Max: 7382797095729208034316799468109824.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.7.self_attn.k_norm | Grad Max: 29206669829258405410484041851863040.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.7.self_attn.k_proj | Grad Max: 50949412924372996104955495230472192.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.7.self_attn.q_norm | Grad Max: 55168154121932543553136523497963520.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.7.self_attn.q_proj | Grad Max: 113581493780449354374104607201689600.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.3.self_attn | Grad Max: 138893940965806639063190776806637568.00 | Dtype: torch.bfloat16
⚠️ [Numerical explosion warning] Layer: model.layers.3.self_attn.o_proj | Grad Max: 138893940965806639063190776806637568.00 | Dtype: torch.bfloat16
……
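The log above shows the o_proj gradient maximum climbing from 67072 at layer 23 to about 1.4e35 at layer 3 as the backward pass moves toward the input. A small parser makes that compounding explicit as a per-block growth factor. This is a sketch: it assumes the warning-line format printed by the reporter's train_dapo_3_debug.py (the Chinese labels 层 and 梯度 Max, with English `Layer:`/`Grad Max:` also accepted), and `module_maxima`/`growth_factors` are hypothetical names.

```python
import re

# Matches warning lines such as
# "⚠️ [数值爆炸预警] 层: model.layers.23.self_attn.o_proj | 梯度 Max: 67072.00 | Dtype: ..."
# The exact format is an assumption based on the reporter's log output.
WARNING_RE = re.compile(r"(?:层|Layer): (\S+) \| (?:梯度|Grad) Max: ([0-9.]+)")
LAYER_RE = re.compile(r"model\.layers\.(\d+)\.")


def module_maxima(log_text: str, suffix: str = "self_attn.o_proj") -> dict[int, float]:
    """Return {layer_index: grad_max} for warning lines whose module name ends with `suffix`."""
    maxima = {}
    for name, value in WARNING_RE.findall(log_text):
        layer = LAYER_RE.search(name)
        if name.endswith(suffix) and layer:
            maxima[int(layer.group(1))] = float(value)
    return maxima


def growth_factors(maxima: dict[int, float]) -> list[tuple[int, int, float]]:
    """Ratio of grad max between consecutive logged layers, walking from deep to shallow."""
    layers = sorted(maxima, reverse=True)
    return [(deep, shallow, maxima[shallow] / maxima[deep]) for deep, shallow in zip(layers, layers[1:])]


if __name__ == "__main__":
    sample = (
        "⚠️ [数值爆炸预警] 层: model.layers.23.self_attn.o_proj | 梯度 Max: 67072.00 | Dtype: torch.bfloat16\n"
        "⚠️ [数值爆炸预警] 层: model.layers.19.self_attn.o_proj | 梯度 Max: 132120576.00 | Dtype: torch.bfloat16\n"
    )
    for deep, shallow, ratio in growth_factors(module_maxima(sample)):
        print(f"layer {deep} -> {shallow}: x{ratio:.3g}")
```

Feeding it the full log gives one factor per logged block; a roughly constant multiplicative factor per block is what compounding through the backward pass looks like.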

3. Fix: enable FlashAttention-2 (--use_flash_attn)

CUDA_LAUNCH_BLOCKING=1 CUDA_HOME=/usr/local/cuda-12.8 LD_PRELOAD=/opt/anaconda3/envs/scr_train2/lib/libstdc++.so.6 CUDA_VISIBLE_DEVICES=2 torchrun --nproc_per_node=1 train_dapo_3_debug.py  --use_lora --use_4bit  --use_flash_attn

Output:

……
[🔍 22:47:55] ✅ Step 1: all gradients normal (max = 0.003845 @ base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight)
……

Expected behavior

  1. Avoid Silent Fallbacks to the Math Backend: When training Qwen3.5 models (or any Qwen2-family architecture that handles 3D position_ids/mRoPE), the transformers implementation explicitly materializes a large 4D dense mask ([Batch, 1, SeqLen, SeqLen]) and sets is_causal=False. This violates the constraints of PyTorch SDPA's fused FlashAttention kernel, which rejects any explicit mask (if (attn_mask.has_value()) { return false; }), and silently forces a downgrade to the Math backend. Expected behavior: the implementation should decouple mRoPE coordinate handling from attention-mask generation, so that it can keep relying on the implicit is_causal=True mechanism, which keeps the optimized FlashAttention kernel engaged.

  2. Prevent BF16 Precision Collapse in Long-Context RLHF: The PyTorch Math backend is fundamentally unsafe for accumulating Softmax denominators over thousands of tokens (e.g., 8K to 100K) in bfloat16. Without the FP32 on-chip (SRAM) accumulators used by fused kernels, the Math backend suffers from severe truncation errors (swamping). Under RLHF losses (such as DPO/GRPO/DAPO), which contain exponential amplifiers (exp(beta * log_probs)), these errors snowball into catastrophic gradients (on the order of $10^{28}$ and beyond) or NaN. Expected behavior: transformers should provide a native varlen (variable-length) or NestedTensor implementation for sdpa that physically drops padded tokens rather than masking them with -3.4e38 in a dense tensor, thereby bypassing the numerically unstable bfloat16 accumulations.

  3. Explicit Warning or Fallback to FA2: Until a native SDPA varlen solution is implemented, transformers should emit an explicit warning when sdpa is initialized alongside padding masks on models that require dense mask materialization. Currently, explicitly setting attn_implementation="flash_attention_2" is the only numerically safe approach the reporter found: Qwen2FlashAttention2 uses cu_seqlens to physically drop padding and accumulates in FP32 registers internally, which stabilized the RLHF gradients in the reproduction above (max gradient 0.003845 at step 1).
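Point 2 rests on swamping: once a running sum is large, adding a small term changes nothing because the result rounds back to the old value. The pure-Python sketch below emulates round-to-nearest-even bfloat16 (8 significant bits) with `struct` and accumulates a softmax denominator of 8192 terms equal to 1.0 (uniform attention after max-subtraction). It illustrates the failure mode of a naive bf16 accumulator only; whether a particular PyTorch kernel accumulates this way is an assumption to verify, since many CUDA reductions and matmuls accumulate bf16 inputs in float32 internally.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (ties to even), returned as a Python float."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # float32 bit pattern
    lsb = (bits >> 16) & 1  # lowest bit that bfloat16 keeps; decides ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round, then drop the low 16 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def running_sum(values, round_fn=lambda v: v):
    """Naive left-to-right sum that rounds the accumulator after every addition."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc


if __name__ == "__main__":
    # Softmax denominator after max-subtraction: 8192 terms of exp(0) = 1.0.
    terms = [1.0] * 8192
    print("exact      :", running_sum(terms))
    print("bf16 accum :", running_sum(terms, to_bf16))
```

This prints 8192.0 for the exact sum and 256.0 for the bf16 accumulator: from 256 upward the spacing between bf16 values is 2, so 256 + 1 rounds back to 256 and the sum stalls, making every attention weight 32× too large in this toy case.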

Extended analysis

Fix Plan

To address the gradient explosion issue, we need to modify the attention implementation to prevent silent fallbacks to the math backend and prevent BF16 precision collapse in long-context RLHF.

Here are the steps:

  • Decouple mRoPE coordinate handling from the attention mask generation to preserve the ability to rely on the implicit is_causal=True mechanism.
  • Implement a native varlen (variable-length) or NestedTensors implementation for sdpa that physically truncates padded tokens rather than masking them with -3.4e38 in a dense tensor.
  • Explicitly warn users when sdpa is initialized alongside padding masks on models requiring dense mask materialization.
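The first two steps amount to a simple decision. For an ordinary batch with one sequence per row and no padding, the dense mask carries nothing beyond causality, so SDPA could be called with is_causal=True. With padding, a varlen kernel instead needs the real tokens packed together plus cumulative sequence lengths (cu_seqlens); `flash_attn_varlen_func` in the flash-attn package consumes exactly this kind of bookkeeping through its `cu_seqlens_q`/`cu_seqlens_k` and `max_seqlen_q`/`max_seqlen_k` arguments. The pure-Python sketch below computes it from a [batch, seq_len] 0/1 padding mask; `has_padding` and `unpad` are hypothetical helper names, not transformers APIs.

```python
from itertools import accumulate


def has_padding(padding_mask: list[list[int]]) -> bool:
    """True if any position in the [batch, seq_len] 0/1 mask is padding (0)."""
    return any(0 in row for row in padding_mask)


def unpad(padding_mask: list[list[int]]) -> tuple[list[int], list[int], int]:
    """Varlen bookkeeping: flat indices of real tokens, cu_seqlens, and max_seqlen."""
    seq_len = len(padding_mask[0])
    # Flat positions (b * seq_len + t) of every real token, in order.
    indices = [b * seq_len + t for b, row in enumerate(padding_mask) for t, keep in enumerate(row) if keep]
    lengths = [sum(row) for row in padding_mask]
    # cu_seqlens[i]:cu_seqlens[i + 1] delimits sequence i in the packed token stream.
    cu_seqlens = [0, *accumulate(lengths)]
    return indices, cu_seqlens, max(lengths)


if __name__ == "__main__":
    mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
    if not has_padding(mask):
        print("no padding: drop the dense mask and call SDPA with is_causal=True")
    else:
        print(unpad(mask))
```

For the example mask this prints `([0, 1, 2, 4, 5], [0, 3, 5], 3)`: five real tokens are kept, and the two sequences occupy packed positions 0 to 3 and 3 to 5. No padded position is ever materialized in an attention matrix.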

Code Changes

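import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModelForCausalLM

MODEL_PATH = "path/to/qwen3.5-checkpoint"  # placeholder: substitute your own checkpoint

# Select FlashAttention-2 when loading the model (requires the flash-attn package).
# The Auto class may differ for your checkpoint (e.g. a multimodal Qwen variant).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)

# Alternatively, implement a custom attention module that decouples mRoPE coordinate handling
# from the attention mask generation and uses a varlen implementation for sdpa
class CustomAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def forward(self, query, key, value, attention_mask=None):
        # No padding: rely on the implicit causal mask so SDPA can keep a fused kernel
        if attention_mask is None:
            return F.scaled_dot_product_attention(query, key, value, is_causal=True)
        # With padding, a varlen path would first drop padded tokens using cu_seqlens
        raise NotImplementedError("varlen (unpadded) attention path is not implemented here")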
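Step 3 of the plan, the explicit warning, could look roughly like the check below. It is a hypothetical sketch: `warn_if_dense_mask_sdpa` and its three boolean inputs are illustrative names, not part of transformers, and a real implementation would derive them from the model config and the incoming attention mask.

```python
import warnings


def warn_if_dense_mask_sdpa(attn_implementation: str, has_padding: bool, requires_dense_mask: bool) -> bool:
    """Warn when sdpa is about to receive a materialized dense mask; return True if warned.

    Hypothetical helper illustrating the proposed check; not a transformers API.
    """
    if attn_implementation == "sdpa" and has_padding and requires_dense_mask:
        warnings.warn(
            "sdpa received a dense 4D mask with is_causal=False; fused SDPA kernels may be skipped. "
            'Consider attn_implementation="flash_attention_2" for long-context bf16 training.',
            UserWarning,
            stacklevel=2,
        )
        return True
    return False
```

Returning a boolean keeps the check easy to test and lets a caller escalate (for example, raise instead of warn) without re-deriving the conditions.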

Verification

To verify that the fix worked, run the training script with the modified attention implementation and check for the absence of gradient explosion warnings.

CUDA_LAUNCH_BLOCKING=1 CUDA_HOME=/usr/local/cuda-12.8 LD_PRELOAD=/opt/anaconda3/envs/scr_train2/lib/libstdc++.so.6 CUDA_VISIBLE_DEVICES=2 torchrun --nproc_per_node=1 train_dapo_3_debug.py  --use_lora --use_4bit  --use_flash_attn

Check the output for the absence of ⚠️ [数值爆炸预警] ("numerical explosion warning") lines and verify that the gradients are stable.
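The verification step can be automated by scanning the captured training log. This sketch assumes the log markers printed by the reporter's debug script, namely the `[数值爆炸预警]` warning tag and the `Step N 梯度全部正常` ("Step N: all gradients normal") line; `check_run` is a hypothetical helper.

```python
import re

WARNING_TAG = "[数值爆炸预警]"  # "numerical explosion warning" tag (assumed log format)
STEP_OK_RE = re.compile(r"Step (\d+) 梯度全部正常")  # "Step N: all gradients normal"


def check_run(log_text: str) -> tuple[bool, list[int]]:
    """Return (passed, ok_steps); passed is True when no explosion warning appears."""
    passed = WARNING_TAG not in log_text
    ok_steps = [int(step) for step in STEP_OK_RE.findall(log_text)]
    return passed, ok_steps
```

For example, `passed, steps = check_run(open("train.log", encoding="utf-8").read())`, where train.log is wherever the torchrun output was redirected.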

Extra Tips

  • Prefer attn_implementation="flash_attention_2" for Qwen2-family models in long-context bf16 RLHF training with padded batches; in the reporter's runs this was the configuration that kept gradients stable.
  • Consider implementing a custom attention class that decouples mRoPE coordinate handling from attention-mask generation and uses a varlen implementation for sdpa to prevent BF16 precision collapse.
  • Monitor gradient magnitudes during training (for example, per-module gradient maxima as in the logs above) and change the attention implementation if they start to explode.
