transformers - ✅(Solved) Fix Load balancing loss not added when output_router_logits=False [2 pull requests, 9 comments, 4 participants]

GitHub stats
huggingface/transformers#44242 (fetched 2026-04-08 00:29:39)
Comments: 9 · Participants: 4 · Timeline: 22 · Reactions: 0
Timeline (top): commented ×9, subscribed ×5, mentioned ×4, cross-referenced ×2

Error Message

    print(f"Auxiliary Load Balancing Loss: {aux_loss.item():.4f}")
                                            ^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'item'

Fix Action

Fixed

PR fix notes

PR #44264: [Moe] Enable aux loss automatically when in training + coef is not 0

Description (problem / solution / changelog)

As per the title. WIP: needs a test.

Changed files

  • src/transformers/models/dbrx/modeling_dbrx.py (modified, +2/-7)
  • src/transformers/models/dbrx/modular_dbrx.py (modified, +2/-7)
  • src/transformers/models/doge/modeling_doge.py (modified, +2/-6)
  • src/transformers/models/doge/modular_doge.py (modified, +2/-6)
  • src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py (modified, +2/-8)
  • src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py (modified, +1/-7)
  • src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py (modified, +1/-7)
  • src/transformers/models/flex_olmo/modeling_flex_olmo.py (modified, +2/-8)
  • src/transformers/models/glm4v_moe/configuration_glm4v_moe.py (modified, +5/-0)
  • src/transformers/models/glm4v_moe/modular_glm4v_moe.py (modified, +5/-0)
  • src/transformers/models/gpt_oss/configuration_gpt_oss.py (modified, +1/-1)
  • src/transformers/models/gpt_oss/modeling_gpt_oss.py (modified, +2/-8)
  • src/transformers/models/granitemoe/modeling_granitemoe.py (modified, +2/-5)
  • src/transformers/models/granitemoe/modular_granitemoe.py (modified, +2/-5)
  • src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py (modified, +2/-5)
  • src/transformers/models/granitemoeshared/modeling_granitemoeshared.py (modified, +2/-5)
  • src/transformers/models/jamba/modeling_jamba.py (modified, +2/-7)
  • src/transformers/models/minimax/modeling_minimax.py (modified, +2/-8)
  • src/transformers/models/minimax_m2/modeling_minimax_m2.py (modified, +2/-8)
  • src/transformers/models/mixtral/modeling_mixtral.py (modified, +2/-8)
  • src/transformers/models/mixtral/modular_mixtral.py (modified, +2/-8)
  • src/transformers/models/nllb_moe/modeling_nllb_moe.py (modified, +3/-7)
  • src/transformers/models/olmoe/modeling_olmoe.py (modified, +2/-8)
  • src/transformers/models/phimoe/modeling_phimoe.py (modified, +2/-8)
  • src/transformers/models/qwen2_moe/modeling_qwen2_moe.py (modified, +2/-8)
  • src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py (modified, +2/-7)
  • src/transformers/models/qwen3_moe/modeling_qwen3_moe.py (modified, +2/-8)
  • src/transformers/models/qwen3_moe/modular_qwen3_moe.py (modified, +2/-8)
  • src/transformers/models/qwen3_next/modeling_qwen3_next.py (modified, +2/-7)
  • src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py (modified, +2/-7)
  • src/transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py (modified, +2/-7)
  • src/transformers/models/switch_transformers/configuration_switch_transformers.py (modified, +5/-0)
  • src/transformers/models/switch_transformers/modeling_switch_transformers.py (modified, +2/-5)
  • src/transformers/models/switch_transformers/modular_switch_transformers.py (modified, +2/-5)
  • src/transformers/utils/generic.py (modified, +13/-0)
  • utils/check_config_attributes.py (modified, +2/-1)

PR #44586: Fix Mixtral aux_loss not computed when output_router_logits=False

Description (problem / solution / changelog)

What does this PR do?

Decouples router logits collection from output visibility in Mixtral's ForCausalLM. Previously, output_router_logits=False (the default) prevented aux_loss from being computed, meaning load balancing was silently disabled during training even when router_aux_loss_coef > 0.

The fix:

  • Always collect router logits internally when router_aux_loss_coef > 0
  • Always compute aux_loss when router logits are available
  • Only include router_logits in the model output when the user explicitly sets output_router_logits=True
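The three rules above separate two decisions that were previously tied together: whether router logits are collected internally, and whether they are returned to the caller. A minimal sketch of that gating in plain Python, assuming (as in MixtralForCausalLM.forward) that a per-call output_router_logits overrides the config value when it is not None. RouterFlags, resolve_router_flags and old_gating are hypothetical names used only for illustration, not transformers APIs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class RouterFlags:
    collect: bool  # run the routers with logits collection (needed for aux_loss)
    expose: bool   # put router_logits into the returned model output


def resolve_router_flags(
    call_flag: Optional[bool],
    config_flag: bool,
    aux_loss_coef: float,
) -> RouterFlags:
    """Hypothetical helper mirroring the gating described in PR #44586."""
    # A per-call value takes precedence over the config default.
    expose = call_flag if call_flag is not None else config_flag
    # Collect logits whenever the caller wants them OR the aux loss is active.
    collect = expose or aux_loss_coef > 0
    return RouterFlags(collect=collect, expose=expose)


def old_gating(call_flag: Optional[bool], config_flag: bool) -> bool:
    """Pre-fix behavior reported in #44242: aux_loss depends only on the output flag."""
    return call_flag if call_flag is not None else config_flag


if __name__ == "__main__":
    # Default config (output_router_logits=False) with a nonzero coefficient.
    print(old_gating(None, False))                  # False: aux_loss stays None
    print(resolve_router_flags(None, False, 0.001))  # collect=True, expose=False
```

With the default config and router_aux_loss_coef=0.001, the old gating skips the aux loss entirely, while the new gating collects logits for the loss but still leaves router_logits out of the output.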

This affects all MoE models inheriting from Mixtral via modular conversion: ernie4_5_moe, flex_olmo, gpt_oss, jamba, minimax, minimax_m2, olmoe, phimoe, qwen2_moe, qwen3_5_moe, qwen3_next.

Fixes #44242

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case. https://github.com/huggingface/transformers/issues/44242
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@SunMarc @ArthurZucker @Cyrilvallez (MoE models, training)

This contribution was developed with AI assistance (Claude Code).

Changed files

  • src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py (modified, +7/-3)
  • src/transformers/models/flex_olmo/modeling_flex_olmo.py (modified, +7/-3)
  • src/transformers/models/gpt_oss/modeling_gpt_oss.py (modified, +7/-3)
  • src/transformers/models/jamba/modeling_jamba.py (modified, +7/-3)
  • src/transformers/models/minimax/modeling_minimax.py (modified, +7/-3)
  • src/transformers/models/minimax_m2/modeling_minimax_m2.py (modified, +7/-3)
  • src/transformers/models/mixtral/modeling_mixtral.py (modified, +7/-3)
  • src/transformers/models/mixtral/modular_mixtral.py (modified, +7/-3)
  • src/transformers/models/olmoe/modeling_olmoe.py (modified, +7/-3)
  • src/transformers/models/phimoe/modeling_phimoe.py (modified, +7/-3)
  • src/transformers/models/qwen2_moe/modeling_qwen2_moe.py (modified, +7/-3)
  • src/transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py (modified, +7/-3)
  • src/transformers/models/qwen3_next/modeling_qwen3_next.py (modified, +7/-3)


System Info

transformers version: 4.57.3

In src/transformers/models/mixtral/modeling_mixtral.py, the aux_loss is neither computed nor added to the overall loss when output_router_logits=False in the MixtralConfig.

This is not intended, since according to the documentation (https://huggingface.co/docs/transformers/en/model_doc/mixtral) the auxiliary loss should be added as long as router_aux_loss_coef != 0. Because output_router_logits defaults to False, the model does not perform load balancing by default, even when router_aux_loss_coef is set to a nonzero value.
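For context on what is being skipped: the auxiliary loss is a Switch-Transformers-style load-balancing term. The pure-Python sketch below reimplements the computation as we understand load_balancing_loss_func in modeling_mixtral.py to perform it, assuming a single flattened batch of tokens and no attention mask: softmax the router logits, take each token's top-k experts, multiply the fraction of tokens routed to each expert by that expert's mean router probability, sum, and scale by the number of experts. It is an illustration, not the library code; load_balancing_loss is a hypothetical name.

```python
import math
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def load_balancing_loss(router_logits: Sequence[Sequence[float]], top_k: int) -> float:
    """Switch-style aux loss over T tokens x E experts (no attention mask)."""
    num_tokens = len(router_logits)
    num_experts = len(router_logits[0])
    probs = [softmax(row) for row in router_logits]

    # Fraction of tokens that pick each expert among their top-k choices.
    routed = [0.0] * num_experts
    for p in probs:
        # sorted() is stable, so ties keep the lower expert index first.
        top = sorted(range(num_experts), key=lambda e: -p[e])[:top_k]
        for e in top:
            routed[e] += 1.0 / num_tokens

    # Mean router probability assigned to each expert.
    mean_prob = [sum(p[e] for p in probs) / num_tokens for e in range(num_experts)]

    return num_experts * sum(f * q for f, q in zip(routed, mean_prob))


if __name__ == "__main__":
    uniform = [[0.0] * 8 for _ in range(4)]                # evenly spread router
    skewed = [[10.0, 10.0] + [0.0] * 6 for _ in range(4)]  # always picks experts 0 and 1
    print(load_balancing_loss(uniform, top_k=2))
    print(load_balancing_loss(skewed, top_k=2))
```

In this formulation a uniform router gives a loss equal to top_k (2 for top-2 routing), while a router that keeps sending every token to the same experts approaches num_experts; the coefficient router_aux_loss_coef is what pushes training toward the former.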

Who can help?

@SunMarc @ArthurZucker @Cyrilvallez

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import MixtralConfig, MixtralForCausalLM
import torch

# 1. Configure the model with output_router_logits=False (the default) and a nonzero aux loss coefficient
config = MixtralConfig(
    vocab_size=32000,
    hidden_size=2048,
    num_hidden_layers=2, # small for demonstration
    num_local_experts=8,
    output_router_logits=False,  
    router_aux_loss_coef=0.001   # The scaling factor for the load balancing loss
)

# 2. Initialize the model
model = MixtralForCausalLM(config)

# 3. Create dummy inputs and labels for a training step
input_ids = torch.tensor([[1, 254, 99, 32]])
labels = torch.tensor([[1, 254, 99, 32]]) # Next-token prediction labels

# 4. Perform the forward pass
outputs = model(input_ids=input_ids, labels=labels)

# 5. Read the losses
total_loss = outputs.loss
aux_loss = outputs.aux_loss  # None in 4.57.3 because output_router_logits=False
router_logits = outputs.router_logits  # Per-layer router logits, only populated when output_router_logits=True

print(f"Auxiliary Load Balancing Loss: {aux_loss.item():.4f}")
print(f"Total Loss (Cross Entropy + {config.router_aux_loss_coef} * Aux Loss): {total_loss.item():.4f}")

Expected behavior

    print(f"Auxiliary Load Balancing Loss: {aux_loss.item():.4f}")
                                            ^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'item'

The script above fails with this error: aux_loss is only returned (and added to the total loss) when output_router_logits=True.

Extended analysis

Fix Plan

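1. Compute aux_loss in MixtralForCausalLM.forward even when output_router_logits=False

In src/transformers/models/mixtral/modeling_mixtral.py (and modular_mixtral.py, from which it is generated), update the forward method so that router logits are collected, and aux_loss is computed, whenever router_aux_loss_coef > 0, regardless of the value of output_router_logits.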

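class MixtralForCausalLM(MixtralPreTrainedModel):
    # ...

    def forward(self, input_ids=None, attention_mask=None, labels=None, output_router_logits=None, **kwargs):
        # A per-call flag overrides the config default; it only controls what is returned
        return_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        # Collect router logits internally whenever the load-balancing loss is active
        collect_router_logits = return_router_logits or self.router_aux_loss_coef > 0
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_router_logits=collect_router_logits,
            **kwargs,
        )
        # ... compute logits from outputs.last_hidden_state as before ...
        aux_loss = None
        if collect_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits, self.num_experts, self.num_experts_per_tok, attention_mask
            )
        # ...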

2. Add the scaled aux_loss to the overall loss, even when output_router_logits=False

In the same forward method, add router_aux_loss_coef * aux_loss to the language-modeling loss whenever aux_loss was computed, and include router_logits in the returned output only when the caller asked for them.

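class MixtralForCausalLM(MixtralPreTrainedModel):
    # ...

    def forward(self, input_ids=None, attention_mask=None, labels=None, output_router_logits=None, **kwargs):
        # ... (return_router_logits, outputs, logits and aux_loss as in step 1)
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
            if aux_loss is not None:
                # Scale the load-balancing loss and add it to the cross-entropy loss
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # Only expose router_logits when explicitly requested
            router_logits=outputs.router_logits if return_router_logits else None,
        )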
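The arithmetic behind step 2 is small enough to check by hand. A minimal sketch, assuming the cross-entropy value has already been computed; combine_losses is a hypothetical helper, not a transformers function, and the cross-entropy value in the demo is made up. With the reproduction's router_aux_loss_coef=0.001 and an aux loss near 2 (the balanced value for top-2 routing), the added term is only about 0.002, which may help explain why the missing term went unnoticed.

```python
from typing import Optional


def combine_losses(ce_loss: float, aux_loss: Optional[float], aux_loss_coef: float) -> float:
    """Total loss = cross-entropy + coef * aux_loss; aux_loss=None mimics the pre-fix path."""
    if aux_loss is None:
        # Pre-fix path with output_router_logits=False: no load-balancing term at all.
        return ce_loss
    return ce_loss + aux_loss_coef * aux_loss


if __name__ == "__main__":
    ce = 10.4  # hypothetical cross-entropy value
    print(combine_losses(ce, None, 0.001))  # load balancing silently skipped
    print(combine_losses(ce, 2.0, 0.001))   # adds 0.001 * 2.0 = 0.002
```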

3. Update MixtralConfig documentation to reflect the change

Update the documentation to reflect that aux_loss is computed and included in the overall loss whenever router_aux_loss_coef > 0, independently of output_router_logits.
