transformers - ✅(Solved) Fix Double softmax in MoE router load-balancing loss (mixtral, qwen2_moe, qwen3_vl_moe families) [2 pull requests, 1 comments, 2 participants]

GitHub stats
huggingface/transformers#45120 · Fetched 2026-04-08 01:52:44
Comments: 1 · Participants: 2 · Timeline events: 8 · Reactions: 0
Timeline (top): cross-referenced ×2, referenced ×2, closed ×1, commented ×1

Fix Action

Fixed

PR fix notes

PR #45131: Fix MoE routers returning probabilities instead of logits

Description (problem / solution / changelog)

What does this PR do?

Fixes issue #45120: Several MoE routers returned softmaxed probabilities as router_logits, which caused load_balancing_loss_func to compute softmax(softmax(logits)), flattening routing distributions and weakening gradient signals during fine-tuning. This PR fixes it by keeping router_logits as raw logits and computing router_probs separately for top-k routing.
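To make the "weakening gradient signals" claim concrete, the following plain-Python sketch compares how strongly one expert's probability reacts to a small change in its logit under a single softmax and under softmax(softmax(logits)). The logits row, the `sensitivity` helper, and the finite-difference step are illustrative assumptions, not part of the PR. The sketch measures the routing probability only, not the full auxiliary loss gradient.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sensitivity(fn, logits, index=0, eps=1e-6):
    """Central finite-difference estimate of d fn(logits)[index] / d logits[index]."""
    up = list(logits)
    down = list(logits)
    up[index] += eps
    down[index] -= eps
    return (fn(up)[index] - fn(down)[index]) / (2 * eps)


# Hypothetical router logits for one token over four experts.
logits = [1.0, 0.0, -1.0, 0.5]

single = sensitivity(softmax, logits)
double = sensitivity(lambda z: softmax(softmax(z)), logits)

print(f"single softmax sensitivity: {single:.4f}")
print(f"double softmax sensitivity: {double:.4f}")
```

For this one hypothetical row the double-softmax sensitivity comes out clearly smaller, which is consistent with the weaker training signal described above. Other rows will give different numbers.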


Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by code agents. We are currently bottlenecked by our ability to review and respond to them. As a result, we ask that new users do not submit pure code agent PRs at this time.

  • I confirm that this is not a pure code agent PR.

For more information, please read CONTRIBUTING.md.


Before submitting

  • This PR fixes the issue and does not break inference
  • Did you read the contributor guideline?
  • Documentation does not need changes, as behavior/API is unchanged
  • No new tests are required; logic change only affects auxiliary loss

Who can review?

Anyone in the community is free to review the PR once the tests have passed.
Recommended reviewers:

  • Models: @ArthurZucker @Cyrilvallez
  • MoE / large model routing: @SunMarc @Rocketknight1

Other reviewers for reference:

  • text models: @ArthurZucker @Cyrilvallez
  • vision models: @yonigozlan @molbap
  • audio models: @eustlb @ebezzam @vasqu
  • multimodal: @zucchini-nlp
  • graph: @clefourrier
  • library: @gante, @zucchini-nlp
  • trainer: @SunMarc
  • attention: @vasqu, @ArthurZucker, @CyrilVallez
  • distributed: @3outeille, @ArthurZucker
  • CI: @ydshieh
  • devices/backends: @ivarflakstad, @IlyasMoutawwakil
  • documentation: @stevhliu

How to test / verify

  1. Train a model using output_router_logits=True.
  2. Check that router_logits returned by routers are raw logits, not probabilities.
  3. Verify that load_balancing_loss_func receives proper logits and computes meaningful gradients.
  4. Confirm that inference behavior is unchanged.

Changed files

  • src/transformers/models/mixtral/modular_mixtral.py (modified, +2/-2)
  • src/transformers/models/qwen2_moe/modular_qwen2_moe.py (modified, +2/-2)
  • src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py (modified, +2/-2)

PR #45132: Fix: Remove double softmax in MoE router load-balancing loss (Mixtral, Qwen2MoE, Qwen3VLMoE)

Description (problem / solution / changelog)

Summary

This PR fixes GitHub issue #45120: "Double softmax in MoE router load-balancing loss". MoE routers in Mixtral, Qwen2MoE, and Qwen3VLMoE were applying softmax inside forward(), then the load_balancing_loss_func applied softmax AGAIN, resulting in softmax(softmax(logits)) which flattened routing probabilities.

Root Cause

Three routers reassigned router_logits with softmaxed values, then returned them to load_balancing_loss_func which expected raw logits.

Solution

Separated concepts: keep raw logits in router_logits, use router_probs for softmaxed values during routing.

Changes

  • MixtralTopKRouter: Renamed softmax reassignment to router_probs
  • Qwen2MoeTopKRouter: Renamed softmax reassignment to router_probs
  • Qwen3VLMoeTextTopKRouter: Renamed softmax reassignment to router_probs
  • Added comprehensive unit tests verifying router_logits are raw logits

Impact

  • Fine-tuning: Load-balancing loss now receives correct raw logits
  • Inference: No changes
  • Backward compatible: Only internal computation affected

Fixes #45120

Changed files

  • src/transformers/models/mixtral/modeling_mixtral.py (modified, +2/-2)
  • src/transformers/models/mixtral/modular_mixtral.py (modified, +2/-2)
  • src/transformers/models/qwen2_moe/modeling_qwen2_moe.py (modified, +2/-2)
  • src/transformers/models/qwen2_moe/modular_qwen2_moe.py (modified, +2/-2)
  • src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py (modified, +2/-2)
  • src/transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py (modified, +2/-2)
  • tests/models/mixtral/test_modeling_mixtral.py (modified, +84/-22)
  • tests/models/qwen2_moe/test_modeling_qwen2_moe.py (modified, +59/-0)
  • tests/models/qwen3_vl_moe/test_modeling_qwen3_vl_moe.py (modified, +60/-0)


Bug description

Several MoE routers apply softmax to raw logits inside their forward() method, then return the result as the first value (router_logits). This value is captured by OutputRecorder and passed to load_balancing_loss_func, which applies softmax again — computing the auxiliary loss on softmax(softmax(logits)).

This flattens the routing probability distribution toward uniform (1/num_experts), making router_prob_per_expert nearly constant regardless of actual routing decisions. The load-balancing loss effectively provides no useful gradient signal.
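The flattening can be checked with a few lines of plain Python. This is a minimal sketch, assuming a hypothetical row of eight router logits. Because the first softmax maps every entry into [0, 1], the second softmax can never assign any of the 8 experts more than e / (e + 7), roughly 0.28, however peaked the original logits are.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical router logits for one token over eight experts.
logits = [4.0, 1.0, 0.0, -1.0, -2.0, -3.0, 0.5, 2.0]

probs = softmax(logits)         # what the routers returned as "router_logits"
double = softmax(probs)         # what the loss function then computed

spread_once = max(probs) - min(probs)
spread_twice = max(double) - min(double)

# Upper bound on any entry after the second softmax with 8 experts.
upper_bound = math.e / (math.e + len(logits) - 1)

print(f"single softmax: max={max(probs):.3f}, spread={spread_once:.3f}")
print(f"double softmax: max={max(double):.3f}, spread={spread_twice:.3f}")
print(f"bound on double softmax max: {upper_bound:.3f}")
```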

Reproduction

import torch
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeTextConfig
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM

config = Qwen3_5MoeTextConfig(
    vocab_size=1000, hidden_size=128, num_hidden_layers=4,
    num_attention_heads=4, num_key_value_heads=2, head_dim=32,
    moe_intermediate_size=64, shared_expert_intermediate_size=64,
    num_experts=8, num_experts_per_tok=2,
    linear_num_key_heads=4, linear_num_value_heads=4,
    linear_key_head_dim=32, linear_value_head_dim=32,
    linear_conv_kernel_dim=4, output_router_logits=True,
    router_aux_loss_coef=0.001, max_position_embeddings=128,
)

model = Qwen3_5MoeForCausalLM(config)
model.train()

input_ids = torch.randint(0, 1000, (1, 16))
outputs = model(input_ids, labels=input_ids.clone(), output_router_logits=True)

# router_logits are already probabilities (sum to 1.0 per row)
gate_probs = outputs.router_logits[0]
print("Row sums (should NOT be ~1.0 if these were raw logits):")
print(gate_probs.sum(dim=-1)[:5])
# Output: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], ...)

Root cause

Taking Qwen3_5MoeTopKRouter as an example (modeling_qwen3_5_moe.py):

def forward(self, hidden_states):
    router_logits = F.linear(hidden_states, self.weight)          # raw logits
    router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)  # now probabilities!
    router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
    ...
    return router_logits, router_scores, router_indices  # returns probs as "logits"

Then in load_balancing_loss_func:

routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)  # softmax applied AGAIN

Note: SparseMoeBlock.forward() discards the first return value (_), so only the OutputRecorder → load_balancing_loss_func path consumes it. No downstream routing logic is affected.
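The effect on the auxiliary loss itself can be sketched as follows. This is a simplified, Switch-Transformer-style formulation written in plain Python as an assumption for illustration: it is not the library's load_balancing_loss_func, which also handles multiple layers and attention masks. The `aux_loss` helper and the token batch are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def aux_loss(gate_rows, top_k):
    """Simplified load-balancing loss: num_experts * sum_e(f_e * P_e).

    f_e is the fraction of tokens that select expert e in their top-k,
    P_e is the mean routing probability assigned to expert e.
    The softmax below mirrors the one applied inside the loss function.
    """
    num_experts = len(gate_rows[0])
    num_tokens = len(gate_rows)
    probs = [softmax(row) for row in gate_rows]
    counts = [0] * num_experts
    for row in probs:
        ranked = sorted(range(num_experts), key=lambda e: row[e], reverse=True)
        for e in ranked[:top_k]:
            counts[e] += 1
    f = [c / num_tokens for c in counts]
    p = [sum(row[e] for row in probs) / num_tokens for e in range(num_experts)]
    return num_experts * sum(fe * pe for fe, pe in zip(f, p))


# Hypothetical, badly imbalanced batch: every token prefers experts 0 and 1.
raw_logits = [[3.0, 2.0, 0.0, 0.0] for _ in range(6)]

loss_correct = aux_loss(raw_logits, top_k=2)                          # raw logits in
loss_buggy = aux_loss([softmax(r) for r in raw_logits], top_k=2)      # probabilities in

print(f"loss on raw logits:        {loss_correct:.3f}")
print(f"loss on pre-softmaxed in:  {loss_buggy:.3f}")
print("perfectly balanced value:  2.000 (equals top_k in this formulation)")
```

In this toy batch the routing is identical in both cases, yet the loss computed from pre-softmaxed inputs sits much closer to the balanced value, so the imbalance is penalised far less.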

Affected models

Three source routers (in modular files):

  • MixtralTopKRouter in mixtral/modular_mixtral.py
  • Qwen2MoeTopKRouter in qwen2_moe/modular_qwen2_moe.py
  • Qwen3VLMoeTextTopKRouter in qwen3_vl_moe/modular_qwen3_vl_moe.py

Downstream models inheriting these routers: minimax, olmoe, flex_olmo, qwen3_moe, qwen3_next, qwen3_omni_moe, qwen3_5_moe

Suggested fix

Use a separate variable for the softmaxed values, keeping router_logits as raw logits:

router_logits = F.linear(hidden_states, self.weight)
router_probs = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
router_top_value, router_indices = torch.topk(router_probs, self.top_k, dim=-1)
...
return router_logits, router_scores, router_indices  # now returns actual logits
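A plain-Python sketch of the same pattern, assuming a hypothetical `route` helper that stands in for the router's forward method. The elided score computation is not reproduced, so the top-k probabilities are returned in its place. Because softmax preserves ordering, the selected experts are the same whether the ranking uses logits or probabilities, which is consistent with routing being unaffected by the rename.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route(raw_logits, top_k):
    """Return (raw logits, top-k probabilities, top-k expert indices)."""
    router_logits = list(raw_logits)           # kept as raw logits
    router_probs = softmax(router_logits)      # separate variable for probabilities
    router_indices = sorted(
        range(len(router_probs)), key=lambda i: router_probs[i], reverse=True
    )[:top_k]
    router_top_value = [router_probs[i] for i in router_indices]
    return router_logits, router_top_value, router_indices


# Hypothetical router logits for one token over eight experts.
logits = [4.0, 1.0, 0.0, -1.0, -2.0, -3.0, 0.5, 2.0]

returned_logits, top_values, indices = route(logits, top_k=2)

# Ranking by raw logits selects the same experts as ranking by probabilities.
indices_from_logits = sorted(
    range(len(logits)), key=lambda i: logits[i], reverse=True
)[:2]

print("returned logits:", returned_logits)
print("selected experts:", indices)
```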

Impact

This only affects fine-tuning with output_router_logits=True. It does not affect inference or pretrained model quality (those were trained with the original upstream codebases, not HuggingFace transformers).

extent analysis

Fix Plan

To fix the issue, modify the forward method of each affected router so that it returns the raw logits rather than the softmaxed probabilities:

  • In the router's forward method (Qwen3_5MoeTopKRouter is used as the example here), store the softmaxed values in a separate variable:

def forward(self, hidden_states):
    router_logits = F.linear(hidden_states, self.weight)          # raw logits, returned unchanged
    router_probs = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)  # probabilities, used only for routing
    router_top_value, router_indices = torch.topk(router_probs, self.top_k, dim=-1)
    ...
    return router_logits, router_scores, router_indices  # now returns actual logits

  • Apply the same change to MixtralTopKRouter, Qwen2MoeTopKRouter, and Qwen3VLMoeTextTopKRouter in their respective files. These are the source routers that the downstream models inherit from.

Verification

To verify that the fix worked, you can run the following code:

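import torch
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import Qwen3_5MoeTextConfig
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeForCausalLM

config = Qwen3_5MoeTextConfig(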
    vocab_size=1000, hidden_size=128, num_hidden_layers=4,
    num_attention_heads=4, num_key_value_heads=2, head_dim=32,
    moe_intermediate_size=64, shared_expert_intermediate_size=64,
    num_experts=8, num_experts_per_tok=2,
    linear_num_key_heads=4, linear_num_value_heads=4,
    linear_key_head_dim=32, linear_value_head_dim=32,
    linear_conv_kernel_dim=4, output_router_logits=True,
    router_aux_loss_coef=0.001, max_position_embeddings=128,
)

model = Qwen3_5MoeForCausalLM(config)
model.train()

input_ids = torch.randint(0, 1000, (1, 16))
outputs = model(input_ids, labels=input_ids.clone(), output_router_logits=True)

# router_logits are now raw logits (do NOT sum to 1.0 per row)
gate_logits = outputs.router_logits[0]
print("Row sums (should NOT be ~1.0 if these are raw logits):")
print(gate_logits.sum(dim=-1)[:5])

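The row sums should no longer be close to 1.0, which indicates that router_logits now holds raw logits. A row of raw logits can still sum to roughly 1.0 by coincidence, so entries that are negative or greater than 1.0 are a more reliable sign.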
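A slightly stricter check than eyeballing row sums can be sketched in plain Python. The `looks_like_probabilities` helper and its tolerance are hypothetical; it treats rows as probabilities only if every entry lies in [0, 1] and each row sums to about 1.0.

```python
def looks_like_probabilities(rows, tol=1e-4):
    """Return True only if every row is a valid probability vector."""
    for row in rows:
        # Raw logits usually contain negative values or values above 1.
        if any(v < 0.0 or v > 1.0 for v in row):
            return False
        # Probabilities from a softmax sum to 1 within floating-point error.
        if abs(sum(row) - 1.0) > tol:
            return False
    return True


# Hypothetical examples: softmax output versus raw logits.
prob_rows = [[0.7, 0.2, 0.1], [0.25, 0.25, 0.5]]
logit_rows = [[2.3, -0.4, 0.1], [0.6, 0.5, -0.1]]  # second row happens to sum to 1.0

print(looks_like_probabilities(prob_rows))    # probabilities
print(looks_like_probabilities(logit_rows))   # raw logits
```

With the model output from the verification script, one way to feed this helper would be `looks_like_probabilities(outputs.router_logits[0].float().tolist())`, assuming each entry of router_logits is a 2-D tensor of shape (tokens, num_experts). After the fix this should return False.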

Extra Tips

  • Make sure to update all affected models and routers to reflect the change.
  • Test the fix thoroughly to ensure that it does not introduce any new issues.
  • Consider adding a test case to verify that the router_logits are indeed raw logits.
