distillation: resume + xpk launcher + metrics refactor #3701
copybara-service[bot] merged 1 commit into main
Conversation
This PR significantly improves the robustness and observability of the distillation trainer. The transition to a token-weighted metrics contract (sum, count) ensures unbiased logging in multi-host environments, and the delegation of checkpointing to Tunix (while preserving MaxText's Grain handling) simplifies the codebase.
🔍 General Feedback
- Numerical Stability: The addition of `_PPL_CE_CAP` is a great best practice for preventing `inf` in logged perplexity metrics.
- Testing: The new unit tests in `distillation_metrics_test.py` are comprehensive, covering KL direction, self-equivalence, and sharded aggregation.
- Launcher: The new XPK script is well-documented and provides a clear path for users to run distillation at scale.
```python
# Clamp CE before exp() so a divergence spike doesn't poison PPL averages
# with inf. 20 nats is well above plausible CE (Llama random-init ~11.76)
# and far below fp32 exp overflow (~88).
_PPL_CE_CAP = 20.0
```

🟢 This cap is a sensible addition for numerical stability. 20 nats (~4.8e8 PPL) is a safe upper bound that prevents a single divergent token from poisoning the entire window's average with `inf` while remaining well below the fp32 `exp` overflow limit (~88).
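As a minimal NumPy-only sketch of the clamping idea (the real code operates on JAX arrays; `perplexity` here is a hypothetical helper, not the PR's function):

```python
import numpy as np

_PPL_CE_CAP = 20.0  # nats; far below the fp32 exp overflow limit (~88)

def perplexity(ce_per_token):
    """Per-token perplexity with CE clamped so one divergent token
    cannot drive the logged average to inf."""
    ce = np.minimum(np.asarray(ce_per_token, dtype=np.float32), _PPL_CE_CAP)
    return np.exp(ce)

# A divergence spike (CE = 1000 nats) would overflow fp32 exp();
# with the cap it contributes at most exp(20), roughly 4.85e8.
ppl = perplexity([2.0, 1000.0])
assert np.isfinite(ppl).all()  # the spike was capped, so everything is finite
```

Without the cap, a single `inf` would propagate through any windowed average and erase the rest of the signal for that logging interval.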
```python
feature_loss = jnp.array(0.0)
# Per-token validity mask, derived from the one-hot labels so we don't need
# a separate mask input. A padded (fully-zero) row yields `any != 0 == False`.
mask = jnp.any(labels != 0, axis=-1).astype(jnp.float32)  # [B, T]
```

🟡 Using `jnp.any(labels != 0, axis=-1)` works correctly for one-hot labels to identify valid tokens. However, since the `DistillationStrategy` class already tracks `self.pad_id`, it might be more robust to use that explicitly if the label format ever changes from one-hot in the future. Given the current design, this is correct.
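To make the equivalence concrete, here is a NumPy sketch comparing the one-hot-derived mask with the `pad_id`-based alternative the review suggests (`PAD_ID` and both helper names are hypothetical, chosen for this illustration only):

```python
import numpy as np

PAD_ID = 0  # hypothetical pad token id for this sketch

def mask_from_one_hot(labels_one_hot):
    # Padded rows are all-zero, so `any(!= 0)` along the vocab axis is False.
    return np.any(labels_one_hot != 0, axis=-1).astype(np.float32)

def mask_from_token_ids(token_ids, pad_id=PAD_ID):
    # The pad_id-based alternative: compare integer ids directly.
    return (np.asarray(token_ids) != pad_id).astype(np.float32)

vocab = 5
token_ids = np.array([[1, 3, 0, 0]])            # two valid tokens, two pads
one_hot = np.eye(vocab, dtype=np.float32)[token_ids]
one_hot[token_ids == PAD_ID] = 0.0              # pad positions become all-zero rows

# Both constructions yield the same [B, T] validity mask.
assert (mask_from_one_hot(one_hot) == mask_from_token_ids(token_ids)).all()
```

The two agree as long as padded positions are encoded as fully-zero rows; the `pad_id` version keeps working even if labels later become integer ids or soft distributions.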
```shell
RUN pip install --no-cache-dir --force-reinstall --no-deps \\
    "jax==$JAX_PIN" "jaxlib==$JAXLIB_PIN" "libtpu==$LIBTPU_PIN"
EOF
sudo docker build -t "$XPK_BASE_IMAGE" -f "$tmp/Dockerfile" "$tmp"
```

🟢 This provides a great starting point for users. Overwriting the local `$XPK_BASE_IMAGE` tag is efficient for local development, but adding a note that it modifies the local image state (which you've done in the comments) is good practice.
This Pull Request introduces significant improvements to the MaxText distillation trainer, primarily focusing on metric aggregation and checkpointing. The shift to a token-weighted (sum, count) aggregation contract ensures unbiased metrics across multi-host setups, and the delegation to upstream Tunix for checkpointing simplifies the maintenance of restoration logic.
🔍 General Feedback
- Robust Metrics: The refactor to `(sum, count)` pairs and the introduction of `weighted_mean` is a great architectural choice for distributed training stability.
- Improved Monitoring: The addition of `kl_div_T1` and token-weighted perplexity metrics significantly improves the comparability of runs across different annealing schedules.
- XPK Integration: The new `run_distill_xpk.sh` script provides a clear, documented template for running distillation on GKE, including helpful auto-resume logic.
- Comprehensive Testing: The addition of `distillation_metrics_test.py` provides excellent coverage for the new aggregation logic and numerical stability fixes.
```python
    Returns 0.0 for an empty input or when total count is non-positive.
    """
    arr = np.asarray(sum_count_pairs, dtype=np.float32)
```
Suggested change:
```diff
-    arr = np.asarray(sum_count_pairs, dtype=np.float32)
+    arr = np.asarray(sum_count_pairs, dtype=np.float64)
```
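The `float64` suggestion guards against accumulation error: `float32` has a 24-bit mantissa, so once a running sum or token count passes 2^24 (about 16.8M), small increments are silently dropped. A short demonstration (pure NumPy, not PR code):

```python
import numpy as np

# float32 cannot represent integers above 2**24 exactly, so long-run
# token-count accumulators drift if kept in float32.
big = np.float32(2**24)                  # 16,777,216
assert big + np.float32(1.0) == big      # the +1 is silently lost in fp32

# In float64 the same accumulation stays exact far longer (53-bit mantissa).
big64 = np.float64(2**24)
assert big64 + 1.0 == 2**24 + 1
```

For host-level aggregation over millions of tokens, the cost of the wider dtype is negligible next to the correctness benefit.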
```shell
terminal=$(kubectl get jobset "$XPK_WORKLOAD" \
  -o jsonpath='{.status.terminalState}' 2>/dev/null || echo "")
if [ -n "$terminal" ]; then
```

Suggested change:
```diff
- terminal=$(kubectl get jobset "$XPK_WORKLOAD" \
-   -o jsonpath='{.status.terminalState}' 2>/dev/null || echo "")
- if [ -n "$terminal" ]; then
+ while [ "$retry" -lt "$MAX_RETRIES" ]; do
+   echo "=== resume attempt $((retry + 1)) / $MAX_RETRIES (target steps: $target) ==="
+   # Ensure no stale workload exists before starting the first attempt
+   if [ "$retry" -eq 0 ]; then
+     xpk workload delete --workload="$XPK_WORKLOAD" \
+       --cluster="$XPK_CLUSTER" --project="$XPK_PROJECT" --zone="$XPK_ZONE" --force \
+       >/dev/null 2>&1 || true
+   fi
+   submit_workload
```
```python
    Used as the aggregation function for metrics emitted by `compute_loss` and
    `compute_eval_loss`. Robust to per-host imbalance and to varying valid-token
    counts across logging steps:
        final_value = sum(sums) / sum(counts)
```

Suggested change:
```diff
  Used as the aggregation function for metrics emitted by `compute_loss` and
  `compute_eval_loss`. Robust to per-host imbalance and to varying valid-token
- counts across logging steps:
-     final_value = sum(sums) / sum(counts)
+ counts across logging steps. Also handles simple-mean metrics if emitted
+ as (value, 1.0) pairs:
```
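A NumPy sketch of the aggregation contract this docstring describes, assuming the `(sum, count)` pair layout; `weighted_mean` here is a stand-in for the PR's aggregator, not its exact implementation:

```python
import numpy as np

def weighted_mean(sum_count_pairs):
    """sum(sums) / sum(counts); 0.0 for empty input or non-positive total."""
    arr = np.asarray(sum_count_pairs, dtype=np.float64)
    if arr.size == 0:
        return 0.0
    total = arr[:, 1].sum()
    return float(arr[:, 0].sum() / total) if total > 0 else 0.0

# Host A: 10 valid tokens with loss sum 20; host B: 2 valid tokens with loss sum 2.
pairs = [(20.0, 10.0), (2.0, 2.0)]
assert abs(weighted_mean(pairs) - 22.0 / 12.0) < 1e-12   # unbiased token-weighted mean
assert np.mean([20.0 / 10.0, 2.0 / 2.0]) == 1.5          # mean-of-means is biased
# Simple-mean metrics still work when emitted as (value, 1.0) pairs:
assert weighted_mean([(3.0, 1.0), (5.0, 1.0)]) == 4.0
```

The mean-of-means result (1.5 vs the true 1.83) shows exactly the multi-host bias the `(sum, count)` contract removes: the lightly loaded host gets equal weight despite contributing far fewer tokens.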
Description
Distillation trainer: token-weighted metrics, use Tunix checkpointing, XPK launcher.

- Metrics: move to a `(sum, count)` contract. The old path emitted scalars aggregated via `np.mean`, which biased logs under uneven masks / multi-host shards. The new `weighted_mean` aggregator computes `sum(sums) / sum(counts)` — unbiased. Adds `student_perplexity`, `teacher_perplexity`, `kl_div_T1`. Renames `distill/kl_div` → `distill/kl_div_at_T`. Clamps CE before `exp()` so divergences don't poison PPL averages.
- Checkpointing: `maybe_restore` now just unwraps `ModelBundle` → `student_model`; the local `_shard_optimizer` override is gone (upstream handles replicated scalars via `with_sharding_constraint`).
- Launcher: `scripts/run_distill_xpk.sh` — reference XPK launcher with `prep_image` (layers Tunix branch + repins jax/libtpu), `submit`, `monitor`, `resume_until_done`.

Tests

- New (`tests/post_training/unit/distillation_metrics_test.py`, `cpu_only`, 14 tests): KL direction + self-equivalence, CE/PPL numerics, pad masking, T² scaling, T=1 KL invariance, aggregator single-host + 4-device sharded, feature-mapping (β>0) layer-indexing, eval contract.
- Updated (`train_distill_test.py`): updated for the renamed metric and the `(sum, count)` tuple. 22/22 green.
- TPU `tpu7x-4x4x4`: submit to step 150 (checkpoints 50/100/150, loss 1.53 → 1.20), then resume to step 250 — restored at 150, no loss regression.
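To illustrate the KL properties the new unit tests cover (self-equivalence, T² scaling, T=1 invariance), here is a hedged NumPy sketch using standard Hinton-style temperature scaling; the helper names are hypothetical and this is not the PR's implementation:

```python
import numpy as np

def softmax(logits, T=1.0):
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)   # stabilize before exp
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kl_at_temperature(t_logits, s_logits, T=1.0):
    """KL(teacher || student) on temperature-softened distributions,
    scaled by T**2 so gradient magnitudes stay comparable across T."""
    p = softmax(t_logits, T)
    q = softmax(s_logits, T)
    return (T ** 2) * np.sum(p * (np.log(p) - np.log(q)), axis=-1)

t = np.array([2.0, 0.5, -1.0])   # toy teacher logits
s = np.array([1.5, 0.7, -0.5])   # toy student logits

# Self-equivalence: KL of a distribution with itself is 0 at any T.
assert abs(kl_at_temperature(t, t, T=3.0)) < 1e-12
# Non-negativity, and at T=1 the T**2 factor is a no-op.
assert kl_at_temperature(t, s, T=1.0) >= 0.0
```

The T=1 case is why a separate `kl_div_T1` metric is useful for monitoring: it stays comparable across runs even when the training-time temperature schedule differs.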
Before submitting this PR, please make sure (put X in square brackets):
- [ ] `gemini-review` label.