
Add MoE load balancing loss to distillation #3679

Open
JamesDeng42 wants to merge 1 commit into main from yujiedeng/load-balance-loss

Conversation

@JamesDeng42 JamesDeng42 (Collaborator) commented Apr 16, 2026

Description

This PR introduces support for Mixture of Experts (MoE) load balancing loss during the distillation workflow.

Key Changes

  1. NNX Intermediate Extraction (maxtext_utils.py & qwen3.py):
    • Replaced legacy Linen self.sow(...) calls with native nnx.Intermediate variables that record load_balance_loss inside the MoE layers.
  2. Distillation Strategy Updates (distillation_utils.py & train_distill.py):
    • Extended DistillationForwardOutput to carry the collected moe_lb_loss.
    • Updated CombinedDistillationStrategy to add moe_lb_loss to total_loss so the optimizer minimizes it.
    • Surfaced "distill/moe_lb_loss" in the metrics dictionary for TensorBoard logging and visibility.
  3. Model Mutability (models.py):
    • Automatically appended "intermediates" to the mutable_collections list during the Transformer's forward pass
      whenever load_balance_loss_weight > 0.0, so NNX variables can successfully write to the state.
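For context, the auxiliary load-balancing loss that MoE layers typically sow is the Switch Transformer-style term: num_experts times the sum over experts of (fraction of tokens routed to the expert) times (mean router probability for the expert). This is a minimal illustrative sketch in plain Python, not the actual MaxText kernel; the function and argument names are hypothetical.

```python
def load_balance_loss(router_probs, expert_assignments, num_experts):
    """Switch-style auxiliary loss: num_experts * sum_e(frac_e * mean_prob_e).

    router_probs: list of per-token probability vectors over experts.
    expert_assignments: the expert index each token was routed to (top-1).
    A perfectly uniform router yields a loss of 1.0, the minimum.
    """
    n_tokens = len(expert_assignments)
    loss = 0.0
    for e in range(num_experts):
        # Fraction of tokens dispatched to expert e.
        frac = sum(1 for a in expert_assignments if a == e) / n_tokens
        # Mean router probability assigned to expert e.
        mean_p = sum(p[e] for p in router_probs) / n_tokens
        loss += frac * mean_p
    return num_experts * loss
```

Minimizing this term pushes the router toward a uniform token distribution across experts, which is why the PR wires it into the loss the optimizer sees.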
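The strategy-update step (change 2) can be sketched as follows. This is a hypothetical simplification of the described flow, not the actual distillation_utils.py code: the dataclass fields, the `load_balance_loss_weight` knob, and the metric names other than "distill/moe_lb_loss" are assumptions for illustration.

```python
from dataclasses import dataclass

@dataclass
class DistillationForwardOutput:
    """Carries the per-step distillation loss plus the collected MoE aux loss."""
    distill_loss: float
    moe_lb_loss: float = 0.0  # 0.0 for dense models / when the collection is disabled

def combine_losses(out: DistillationForwardOutput, load_balance_loss_weight: float):
    """Fold the weighted load-balancing loss into the total so the optimizer minimizes it."""
    total_loss = out.distill_loss
    metrics = {"distill/loss": out.distill_loss}
    if load_balance_loss_weight > 0.0:
        total_loss += load_balance_loss_weight * out.moe_lb_loss
        # Surface the raw (unweighted) aux loss for TensorBoard visibility.
        metrics["distill/moe_lb_loss"] = out.moe_lb_loss
    metrics["distill/total_loss"] = total_loss
    return total_loss, metrics
```

Note the gating on `load_balance_loss_weight > 0.0` mirrors the model-mutability change: the intermediates collection is only made writable, and the metric only emitted, when the aux loss is actually in use.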

Tests

Added "distill/moe_lb_loss" to the expected metric keys in train_distill_test.py to prevent regressions.
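The shape of such a regression check might look like the following. This is an illustrative sketch only, not the actual train_distill_test.py code; the helper name and key set are hypothetical.

```python
# Expected metric keys for a distillation step; "distill/moe_lb_loss" is the
# new key added by this PR (the other key is a stand-in for the real set).
EXPECTED_METRIC_KEYS = {"distill/total_loss", "distill/moe_lb_loss"}

def check_metric_keys(metrics: dict) -> None:
    """Fail if any expected metric key is missing from a step's metrics dict."""
    missing = EXPECTED_METRIC_KEYS - metrics.keys()
    assert not missing, f"missing metric keys: {missing}"
```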

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • [ ] I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • [ ] I have necessary comments in my code, particularly in hard-to-understand areas.
  • [ ] I have run end-to-end tests and provided workload links above if applicable.
  • [ ] I have made or will make corresponding changes to the documentation if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov codecov Bot commented Apr 16, 2026

JamesDeng42 force-pushed the yujiedeng/load-balance-loss branch 2 times, most recently from 69341c6 to 414e2f0 on April 16, 2026 at 00:24
Comment thread src/maxtext/utils/maxtext_utils.py Outdated
Comment thread src/maxtext/models/models.py
Comment thread src/maxtext/trainers/post_train/distillation/distillation_utils.py
Comment thread src/maxtext/utils/maxtext_utils.py Outdated
Comment thread src/maxtext/trainers/post_train/distillation/distillation_utils.py
Comment thread src/maxtext/trainers/post_train/distillation/distillation_utils.py
JamesDeng42 force-pushed the yujiedeng/load-balance-loss branch from 414e2f0 to 9628440 on April 21, 2026 at 18:22
vlad-karp (Collaborator) left a comment:

please address the comment

JamesDeng42 force-pushed the yujiedeng/load-balance-loss branch 3 times, most recently from ae5c4cf to 74012ee on April 21, 2026 at 22:35
Comment thread src/maxtext/trainers/post_train/distillation/distillation_utils.py Outdated
JamesDeng42 force-pushed the yujiedeng/load-balance-loss branch from 74012ee to 4d546e3 on April 21, 2026 at 23:56
Comment thread src/maxtext/models/qwen3.py
Comment thread src/maxtext/trainers/post_train/distillation/train_distill.py Outdated
JamesDeng42 force-pushed the yujiedeng/load-balance-loss branch 3 times, most recently from ccd17dc to 2a27030 on April 22, 2026 at 18:47
JamesDeng42 force-pushed the yujiedeng/load-balance-loss branch from 2a27030 to 2833dab on April 22, 2026 at 23:43
gagika (Collaborator) left a comment:

thanks

4 participants