pytorch - [RFC] AdamTR: Adam variant for deterministic Token-Routed architectures

GitHub stats
pytorch/pytorch#179143 (fetched 2026-04-08 02:33:04)
Comments: 0 · Participants: 1 · Timeline events: 14 · Reactions: 0
Timeline (top): mentioned ×5, subscribed ×5, labeled ×3, cross-referenced ×1


🚀 The feature, motivation and pitch

Motivation

Standard AdamW treats all parameters identically, which is suboptimal for Mixture-of-Experts (MoE) architectures, where routed expert parameters, shared expert parameters, and dense (attention/embedding) parameters have fundamentally different training dynamics.

In deterministic Token-Routed MoE (under review at TMLR), we observe a loss gap relative to dense baselines that narrows from +0.31 to +0.09 with AdamW and then plateaus. We hypothesize that an MoE-aware optimizer can close this gap further.

Proposal: AdamTR (Adam Token-Routed)

AdamTR extends AdamW with four features:

  • Per-expert learning rate scaling: experts that see fewer tokens (due to Zipf routing) get a higher LR for faster specialization.
  • Expert-aware weight decay: lighter decay for routed experts to preserve learned specialization; standard decay for shared/dense params.
  • Gradient normalization by token count: prevents high-frequency experts from dominating updates by normalizing gradients by relative token load.
  • Separate momentum per expert group: natural with [E, H, I] parameter tensors; prevents cross-expert momentum contamination.

Implementation

Draft implementation: https://github.com/Complexity-ML/complexity-framework/blob/main/complexity/training/adam_tr.py

The core algorithm is identical to AdamW; the four features are applied as pre-processing before the standard Adam update step. The draft is ~230 lines and adds no new dependencies.
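The issue does not spell out the formulas behind the two token-count features, so the following is only a sketch under stated assumptions. Per-expert LR scaling is modeled as an inverse-load rule `(mean_load / load_e) ** alpha`, capped at `max_scale`. Gradient normalization is modeled as dividing each expert's gradient by its relative load `load_e / mean_load`. The function names, `alpha`, and `max_scale` are hypothetical and are not taken from adam_tr.py.

```python
from typing import List


def per_expert_lr_scale(
    token_counts: List[int], alpha: float = 0.5, max_scale: float = 4.0
) -> List[float]:
    """Hypothetical LR multiplier per expert: (mean_load / load_e) ** alpha, capped at max_scale.

    Experts that see fewer tokens than average get a multiplier above 1.
    An expert with zero tokens is treated as having one token to avoid division by zero.
    """
    mean_load = sum(token_counts) / len(token_counts)
    return [min((mean_load / max(c, 1)) ** alpha, max_scale) for c in token_counts]


def grad_load_normalizer(token_counts: List[int]) -> List[float]:
    """Hypothetical gradient factor per expert: mean_load / load_e.

    Multiplying an expert's gradient by this factor divides it by its relative load.
    """
    mean_load = sum(token_counts) / len(token_counts)
    return [mean_load / max(c, 1) for c in token_counts]


# Zipf-like load over 4 experts (illustrative numbers, not from the issue).
counts = [400, 200, 100, 100]
print("LR multipliers:", per_expert_lr_scale(counts))
print("grad factors:  ", grad_load_normalizer(counts))
```

One point worth checking when benchmarking: Adam's per-element update m_hat / (sqrt(v_hat) + eps) does not change if an element's gradient is multiplied by the same constant at every step (up to the effect of eps). A gradient factor that stays fixed across steps would therefore largely cancel out. Normalization by token count can only change the update when an expert's relative load varies during training. The LR multiplier, by contrast, scales the update directly.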

Questions for the community

  • Would this be appropriate as a PR to torch.optim?
  • What benchmarks would be expected? (We plan to run 384M MoE vs dense, 8B tokens.)
  • Should this target PyTorch core or a separate package like torch-optimizer?

References

  • Architecture paper: https://openreview.net/forum?id=jZq6EVboC6 (TMLR, under review)
  • Framework: https://github.com/Complexity-ML/complexity-framework

Alternatives

  • Using separate parameter groups in AdamW with different LR/weight_decay per group. This is verbose, error-prone, and doesn't support gradient scaling by token distribution.
  • No existing optimizer in torch.optim addresses the asymmetry between expert and dense parameters in MoE architectures.
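For comparison, the first alternative (per-group hyperparameters in stock AdamW) looks roughly like the sketch below. The toy module, the name-prefix split, and the LR and weight-decay values are hypothetical. Passing `torch.optim.AdamW` a list of parameter-group dicts, each with its own `lr` and `weight_decay`, is standard PyTorch.

```python
import torch
from torch import nn


class TinyMoEBlock(nn.Module):
    """Hypothetical toy module with dense, shared-expert, and routed-expert parameters
    (parameters only; forward omitted)."""

    def __init__(self, num_experts: int = 4, hidden: int = 8, inter: int = 16):
        super().__init__()
        self.embed = nn.Embedding(100, hidden)          # dense
        self.attn_proj = nn.Linear(hidden, hidden)      # dense
        self.shared_expert = nn.Linear(hidden, hidden)  # shared expert
        self.routed_experts = nn.Parameter(             # routed experts, [E, H, I]
            torch.randn(num_experts, hidden, inter) * 0.02
        )


def build_param_groups(model: nn.Module, base_lr: float = 3e-4) -> list:
    """Split parameters by name prefix into dense / shared / routed groups."""
    dense, shared, routed = [], [], []
    for name, p in model.named_parameters():
        if name.startswith("routed_experts"):
            routed.append(p)
        elif name.startswith("shared_expert"):
            shared.append(p)
        else:
            dense.append(p)
    return [
        {"params": dense, "lr": base_lr, "weight_decay": 0.1},
        {"params": shared, "lr": base_lr, "weight_decay": 0.1},
        # Hypothetical: higher LR and lighter decay for routed experts.
        {"params": routed, "lr": base_lr * 2.0, "weight_decay": 0.01},
    ]


model = TinyMoEBlock()
opt = torch.optim.AdamW(build_param_groups(model), lr=3e-4)
for group in opt.param_groups:
    print(len(group["params"]), group["lr"], group["weight_decay"])
```

A parameter group is the smallest unit that can carry its own `lr` and `weight_decay`. As a result, a single [E, H, I] expert tensor gives all E experts the same values. Per-expert rates would require splitting the tensor into separate parameters or writing custom optimizer code. There is also no hook for scaling gradients by token counts, which matches the limitation the issue points out.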

Additional context

Training results at 384M scale (8B tokens, 4 experts, iso-params comparison):

  • Loss gap with AdamW: +0.31 (step 400) → +0.09 (step 5000+), plateaus
  • MoE achieves 8,078 tok/s inference (3× faster than dense) but still trails dense in loss
  • Architecture: deterministic routing (no learned router), CUDA graph compatible
  • PyTorch 2.11, FSDP full_shard, bf16, 2× RTX PRO 6000

Analysis

TL;DR

Implementing AdamTR, an MoE-aware optimizer, may help close the loss gap between Mixture-of-Experts architectures and dense baselines by adapting to the different training dynamics of expert and dense parameters.

Guidance

  • Review the draft implementation of AdamTR in adam_tr.py to understand how the four features (per-expert learning rate scaling, expert-aware weight decay, gradient normalization by token count, and separate momentum per expert group) are applied as pre-processing before the standard Adam update step.
  • Consider benchmarking AdamTR against AdamW using the planned 384M MoE vs dense setup with 8B tokens to evaluate its effectiveness in closing the loss gap.
  • Evaluate whether targeting PyTorch core or a separate package like torch-optimizer is more suitable for the AdamTR implementation, considering factors such as maintainability and community adoption.
  • Investigate whether keeping separate momentum per expert group actually prevents cross-expert momentum contamination, and how it affects training dynamics.
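As a starting point for the last bullet, the check below uses toy sizes and a hypothetical routing scenario. It shows that `torch.optim.AdamW` keeps `exp_avg` and `exp_avg_sq` with the same shape as the parameter and updates them elementwise. For an [E, H, I] tensor, the moment estimates of different experts therefore do not mix. What the experts do share is one step counter (used for bias correction) and the group's `lr` and `weight_decay`. Depending on what "cross-expert momentum contamination" means in the draft, that shared state may be the more relevant place to look.

```python
import torch

torch.manual_seed(0)
E, H, I = 4, 3, 5  # experts, hidden, intermediate (toy sizes)
w = torch.nn.Parameter(torch.randn(E, H, I))
opt = torch.optim.AdamW([w], lr=1e-3, weight_decay=0.0)

# Simulate a step where only expert 1 received tokens, so only its slice has gradient.
grad = torch.zeros(E, H, I)
grad[1] = torch.randn(H, I)
w.grad = grad
opt.step()

state = opt.state[w]
# The moment buffers have the parameter's full [E, H, I] shape and are updated
# elementwise, so expert 0's slice of exp_avg stays exactly zero here.
print("exp_avg shape:", tuple(state["exp_avg"].shape))
print("expert 0 max |exp_avg|:", state["exp_avg"][0].abs().max().item())
print("expert 1 max |exp_avg|:", state["exp_avg"][1].abs().max().item())
# Shared across all experts in this tensor: a single step counter for bias correction.
print("step:", state["step"])
```

With weight_decay > 0, AdamW also applies decoupled decay to the whole tensor, including expert slices that received no gradient in that step. That is one interaction the proposed expert-aware weight decay would touch.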

Example

The issue does not include a code example; the draft implementation of AdamTR is in the referenced GitHub repository (complexity/training/adam_tr.py).
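The issue describes pre-processing followed by an unchanged AdamW update. To illustrate that structure, here is a minimal functional sketch for one routed-expert tensor of shape [E, H, I]. It is not the draft adam_tr.py. `adamtr_style_step` and its per-expert vectors `lr_per_expert`, `wd_per_expert`, and `grad_scale_per_expert` are hypothetical. With uniform per-expert values and unit gradient scales, it should reproduce the update of `torch.optim.AdamW`.

```python
import torch


@torch.no_grad()
def adamtr_style_step(
    param: torch.Tensor,                  # [E, H, I] routed-expert weights, updated in place
    grad: torch.Tensor,                   # gradient, same shape as param
    exp_avg: torch.Tensor,                # Adam first moment, same shape as param
    exp_avg_sq: torch.Tensor,             # Adam second moment, same shape as param
    step: int,                            # 1-based step count for bias correction
    lr_per_expert: torch.Tensor,          # [E] learning rates (hypothetical per-expert scaling)
    wd_per_expert: torch.Tensor,          # [E] decoupled weight decay coefficients
    grad_scale_per_expert: torch.Tensor,  # [E] token-load normalization factors
    betas: tuple = (0.9, 0.999),
    eps: float = 1e-8,
) -> None:
    beta1, beta2 = betas
    # View the [E] vectors as [E, 1, 1, ...] so they broadcast over each expert's slice.
    shape = (-1,) + (1,) * (param.dim() - 1)
    lr = lr_per_expert.view(shape)
    wd = wd_per_expert.view(shape)

    # Pre-processing: rescale each expert's gradient by its load factor.
    g = grad * grad_scale_per_expert.view(shape)

    # Decoupled (AdamW-style) weight decay with a per-expert coefficient.
    param.mul_(1 - lr * wd)

    # Standard Adam moment updates; they are elementwise, so expert slices never mix.
    exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)

    # Bias-corrected update, scaled by each expert's learning rate.
    bias_c1 = 1 - beta1 ** step
    bias_c2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_c2).sqrt().add_(eps)
    param.sub_(lr * (exp_avg / bias_c1) / denom)


# Usage with illustrative numbers (not from the issue).
E, H, I = 4, 8, 16
w = torch.randn(E, H, I)
m, v = torch.zeros_like(w), torch.zeros_like(w)
counts = torch.tensor([400.0, 200.0, 100.0, 100.0])  # per-expert token loads
lr = 3e-4 * (counts.mean() / counts).sqrt()           # hypothetical inverse-load LR rule
wd = torch.full((E,), 0.01)                           # hypothetical lighter decay for routed experts
adamtr_style_step(w, torch.randn_like(w), m, v, 1, lr, wd, counts.mean() / counts)
```

A functional form keeps the sketch short. A real optimizer would subclass `torch.optim.Optimizer` and keep `exp_avg`, `exp_avg_sq`, and the step count in `self.state`.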

Notes

The effectiveness of AdamTR in closing the loss gap between MoE architectures and dense baselines is still under investigation, and further experimentation and benchmarking are needed to confirm its benefits.

Recommendation

Consider trying AdamTR to adapt the optimizer to the specific training dynamics of Mixture-of-Experts architectures. It has the potential to improve training results and close the loss gap with dense baselines, but that benefit has not yet been confirmed by benchmarks.
