
distillation: resume + xpk launcher + metrics refactor#3701

Merged
copybara-service[bot] merged 1 commit into main from agagik-checkpoint
Apr 22, 2026

Conversation

Collaborator

@gagika gagika commented Apr 20, 2026

Description

Distillation trainer: token-weighted metrics, use Tunix checkpointing, XPK launcher.

  • Metrics: token-weighted (sum, count) contract. The old path emitted
    scalars aggregated via np.mean, which biased logs under uneven masks /
    multi-host shards. New weighted_mean aggregator computes
    sum(sums) / sum(counts) — unbiased. Adds student_perplexity,
    teacher_perplexity, kl_div_T1. Renames distill/kl_div →
    distill/kl_div_at_T. Clamps CE before exp() so divergences don't
    poison PPL averages.
  • Checkpointing: delegate to upstream Tunix. maybe_restore now just
    unwraps ModelBundle → student_model; the local _shard_optimizer
    override is gone (upstream handles replicated scalars via
    with_sharding_constraint).
  • New: scripts/run_distill_xpk.sh — reference XPK launcher with
    prep_image (layers Tunix branch + repins jax/libtpu), submit,
    monitor, resume_until_done.
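
The (sum, count) contract above can be sketched as a small standalone aggregator. This is an illustrative reimplementation with hypothetical names, not the PR's exact code; it shows why token-weighted aggregation is unbiased where a mean-of-means is not:

```python
import numpy as np

def weighted_mean(sum_count_pairs):
    """Aggregate (sum, count) pairs into sum(sums) / sum(counts).

    Returns 0.0 for an empty input or a non-positive total count.
    """
    arr = np.asarray(sum_count_pairs, dtype=np.float64)
    if arr.size == 0:
        return 0.0
    total_sum, total_count = arr[:, 0].sum(), arr[:, 1].sum()
    return float(total_sum / total_count) if total_count > 0 else 0.0

# Two shards with uneven valid-token counts. A np.mean over per-shard means
# weights both shards equally (biased); the weighted mean does not.
pairs = [(10.0, 10.0), (90.0, 30.0)]   # per-shard (loss_sum, token_count)
print(weighted_mean(pairs))            # 2.5, vs np.mean([1.0, 3.0]) == 2.0
```

Simple-mean metrics fit the same contract when emitted as `(value, 1.0)` pairs.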

Tests

  • New unit tests (tests/post_training/unit/distillation_metrics_test.py,
    cpu_only, 14 tests): KL direction + self-equivalence, CE/PPL numerics,
    pad masking, T² scaling, T=1 KL invariance, aggregator single-host +
    4-device sharded, feature-mapping (β>0) layer-indexing, eval contract.
  • Existing tests (train_distill_test.py): updated for the renamed
    metric and the (sum, count) tuple. 22/22 green.
  • E2E on tpu7x-4x4x4: submit to step 150 (checkpoints 50/100/150,
    loss 1.53 → 1.20), then resume to step 250 — restored at 150, no loss
    regression.

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Apr 20, 2026

Codecov Report

❌ Patch coverage is 98.36066% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...ners/post_train/distillation/distillation_utils.py 98.30% 1 Missing ⚠️


@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This PR significantly improves the robustness and observability of the distillation trainer. The transition to a token-weighted metrics contract (sum, count) ensures unbiased logging in multi-host environments, and the delegation of checkpointing to Tunix (while preserving MaxText's Grain handling) simplifies the codebase.

🔍 General Feedback

  • Numerical Stability: The addition of _PPL_CE_CAP is a great best practice for preventing inf in logged perplexity metrics.
  • Testing: The new unit tests in distillation_metrics_test.py are comprehensive, covering KL direction, self-equivalence, and sharded aggregation.
  • Launcher: The new XPK script is well-documented and provides a clear path for users to run distillation at scale.


@github-actions github-actions Bot left a comment


🔍 Detailed Feedback

  • Numerical stability cap for PPL calculation.
  • Masking logic verification.
  • XPK launcher script suggestions.
  • State synchronization after restore.

# Distillation Strategy
# -----------------------------------------------------------------------------

# Clamp CE before exp() so a divergence spike doesn't poison PPL averages

🟢 This cap is a sensible addition for numerical stability. 20 nats (~4.8e8 PPL) is a safe upper bound that prevents a single divergent token from poisoning the entire window's average with inf while remaining well below the fp32 exp overflow limit (~88).

# Per-token validity mask, derived from the one-hot labels so we don't need
# a separate mask input. A padded (fully-zero) row yields `any != 0 == False`.
mask = jnp.any(labels != 0, axis=-1).astype(jnp.float32) # [B, T]
valid_count = jnp.sum(mask)

🟡 Using jnp.any(labels != 0, axis=-1) works correctly for one-hot labels to identify valid tokens. However, since the DistillationStrategy class already tracks self.pad_id, it might be more robust to explicitly use that if the label format ever changes from one-hot in the future. Given the current design, this is correct.
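
The masking behavior discussed in this thread can be checked in isolation. The sketch below uses numpy, where `np.any` behaves the same as `jnp.any` for this purpose; it is a standalone illustration, not the PR's code:

```python
import numpy as np

# One-hot labels with shape [B=1, T=3, V=4]; the last position is padding
# (an all-zero row), so it must not count as a valid token.
labels = np.array([[[0, 1, 0, 0],
                    [0, 0, 0, 1],
                    [0, 0, 0, 0]]], dtype=np.float32)

mask = np.any(labels != 0, axis=-1).astype(np.float32)  # [B, T]
valid_count = mask.sum()
print(mask)         # [[1. 1. 0.]]; the padded row is masked out
print(valid_count)  # 2.0
```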

Comment thread src/maxtext/trainers/post_train/distillation/scripts/run_distill_xpk.sh Outdated

@github-actions github-actions Bot left a comment


🔍 Detailed Feedback (Part 2)

  • Corrected line numbers for previous comments.
  • Verification of state synchronization and numerical stability measures.

# Clamp CE before exp() so a divergence spike doesn't poison PPL averages
# with inf. 20 nats is well above plausible CE (Llama random-init ~11.76)
# and far below fp32 exp overflow (~88).
_PPL_CE_CAP = 20.0

🟢 This cap is a sensible addition for numerical stability. 20 nats (~4.8e8 PPL) is a safe upper bound that prevents a single divergent token from poisoning the entire window's average with inf while remaining well below the fp32 exp overflow limit (~88).
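
The clamp-then-exponentiate pattern can be sketched standalone (numpy for illustration; `safe_perplexity` is a hypothetical name, not the PR's function):

```python
import numpy as np

_PPL_CE_CAP = 20.0  # nats; exp(20) ~ 4.85e8, far below fp32 overflow (~88 nats)

def safe_perplexity(ce_per_token):
    """Clamp per-token cross-entropy before exp() so one divergent token
    cannot turn the windowed perplexity average into inf."""
    ce = np.minimum(np.asarray(ce_per_token, dtype=np.float32),
                    np.float32(_PPL_CE_CAP))
    return np.exp(ce)

# An unclamped 500-nat token would overflow fp32 exp() to inf and poison
# any average it enters; clamped, it contributes exp(20) instead.
ppl = safe_perplexity([2.0, 500.0])
print(np.isfinite(ppl).all())  # True
```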

feature_loss = jnp.array(0.0)
# Per-token validity mask, derived from the one-hot labels so we don't need
# a separate mask input. A padded (fully-zero) row yields `any != 0 == False`.
mask = jnp.any(labels != 0, axis=-1).astype(jnp.float32) # [B, T]

🟡 Using jnp.any(labels != 0, axis=-1) works correctly for one-hot labels to identify valid tokens. However, since the DistillationStrategy class already tracks self.pad_id, it might be more robust to explicitly use that if the label format ever changes from one-hot in the future. Given the current design, this is correct.

RUN pip install --no-cache-dir --force-reinstall --no-deps \\
"jax==$JAX_PIN" "jaxlib==$JAXLIB_PIN" "libtpu==$LIBTPU_PIN"
EOF
sudo docker build -t "$XPK_BASE_IMAGE" -f "$tmp/Dockerfile" "$tmp"

🟢 This provides a great starting point for users. Overwriting the local $XPK_BASE_IMAGE tag is efficient for local development, but adding a note that it modifies the local image state (which you've done in the comments) is good practice.


@gagika gagika force-pushed the agagik-checkpoint branch from d5da820 to 2020232 Compare April 20, 2026 06:19
@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.

Collaborator

@JamesDeng42 JamesDeng42 left a comment


LGTM

@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.

@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request introduces significant improvements to the MaxText distillation trainer, primarily focusing on metric aggregation and checkpointing. The shift to a token-weighted (sum, count) aggregation contract ensures unbiased metrics across multi-host setups, and the delegation to upstream Tunix for checkpointing simplifies the maintenance of restoration logic.

🔍 General Feedback

  • Robust Metrics: The refactor to (sum, count) pairs and the introduction of weighted_mean is a great architectural choice for distributed training stability.
  • Improved Monitoring: The addition of kl_div_T1 and token-weighted perplexity metrics significantly improves the comparability of runs across different annealing schedules.
  • XPK Integration: The new run_distill_xpk.sh script provides a clear, documented template for running distillation on GKE, including helpful auto-resume logic.
  • Comprehensive Testing: The addition of distillation_metrics_test.py provides excellent coverage for the new aggregation logic and numerical stability fixes.


Returns 0.0 for an empty input or when total count is non-positive.
"""
arr = np.asarray(sum_count_pairs, dtype=np.float32)

🟡 Using `np.float64` for the internal array in `weighted_mean` is safer for aggregating metrics over a large number of steps or hosts to prevent precision loss in the accumulated sums.
Suggested change
arr = np.asarray(sum_count_pairs, dtype=np.float32)
arr = np.asarray(sum_count_pairs, dtype=np.float64)
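
A standalone demonstration of why float64 accumulation matters for long-running metric sums (illustrative only, not the PR's code):

```python
import numpy as np

# float32 carries 24 significand bits, so once an accumulator reaches 2**24
# adding 1.0 is rounded away entirely; float64 still resolves the increment.
acc32 = np.float32(2**24)
acc64 = np.float64(2**24)
print(acc32 + np.float32(1.0) == acc32)  # True: the count increment is lost
print(acc64 + np.float64(1.0) == acc64)  # False: float64 keeps it
```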

Comment on lines +308 to +310
terminal=$(kubectl get jobset "$XPK_WORKLOAD" \
-o jsonpath='{.status.terminalState}' 2>/dev/null || echo "")
if [ -n "$terminal" ]; then

🟡 Consider performing an explicit `xpk workload delete` before the first `submit_workload` attempt in `resume_until_done`. This ensures a clean state if the script is re-run after a previous attempt was interrupted before reaching the cleanup phase.
Suggested change
terminal=$(kubectl get jobset "$XPK_WORKLOAD" \
-o jsonpath='{.status.terminalState}' 2>/dev/null || echo "")
if [ -n "$terminal" ]; then
while [ "$retry" -lt "$MAX_RETRIES" ]; do
echo "=== resume attempt $((retry + 1)) / $MAX_RETRIES (target steps: $target) ==="
# Ensure no stale workload exists before starting the first attempt
if [ "$retry" -eq 0 ]; then
xpk workload delete --workload="$XPK_WORKLOAD" \
--cluster="$XPK_CLUSTER" --project="$XPK_PROJECT" --zone="$XPK_ZONE" --force \
>/dev/null 2>&1 || true
fi
submit_workload

Comment on lines +228 to +231
Used as the aggregation function for metrics emitted by `compute_loss` and
`compute_eval_loss`. Robust to per-host imbalance and to varying valid-token
counts across logging steps:
final_value = sum(sums) / sum(counts)

🟢 It might be helpful to update the docstring to explicitly mention that `weighted_mean` also correctly handles "simple mean" metrics when they are emitted as `(value, 1.0)` pairs.
Suggested change
Used as the aggregation function for metrics emitted by `compute_loss` and
`compute_eval_loss`. Robust to per-host imbalance and to varying valid-token
counts across logging steps:
final_value = sum(sums) / sum(counts)
Used as the aggregation function for metrics emitted by `compute_loss` and
`compute_eval_loss`. Robust to per-host imbalance and to varying valid-token
counts across logging steps. Also handles simple-mean metrics if emitted
as (value, 1.0) pairs:

@gagika gagika force-pushed the agagik-checkpoint branch from 231a16a to 190a12d Compare April 21, 2026 23:36
@copybara-service copybara-service Bot merged commit 25f7ba0 into main Apr 22, 2026
31 checks passed
@copybara-service copybara-service Bot deleted the agagik-checkpoint branch April 22, 2026 01:26

3 participants