Skip to content

Resolve step timing misattribution and isolate checkpoint overhead in MaxText.#4270

Open
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_936672580
Open

Resolve step timing misattribution and isolate checkpoint overhead in MaxText.#4270
copybara-service[bot] wants to merge 1 commit into
mainfrom
test_936672580

Conversation

@copybara-service

@copybara-service copybara-service Bot commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Resolve step timing misattribution and isolate checkpoint overhead in MaxText.

Previously, MaxText measured training step times using host-side wall-clock intervals between loop iterations (now() - last_step_completion). Because JAX dispatches steps asynchronously and the metric logger flushes metrics buffered from step N-1, this logic was broken in several ways:

  1. Inaccurate Step 0 Time: The reported time for step 0 was artificially low because it only measured the asynchronous dispatch time.
  2. Attribution Shift: Measuring now() - last_step_completion effectively measured the time taken for step N-1 to complete, but attributed it to step N+1.
  3. Missing Overheads: Checkpointing, HLO dumps, eval, and profiling times were not included in the step time calculations.
  4. Broken Checkpoint Logs: Checkpointing acts as a synchronization barrier for step N. Because checkpointing blocked the host to perform D2H tensor transfers before the flush, it broke the logs around checkpoint boundaries, causing subsequent steps to enqueue instantly and artificially stall later.

We fix this by measuring the interval between log flushes, where the loss is actually synchronized:

  • Eager Loss Synchronization before Step Time calculation: We force host-device synchronization on the main thread by converting the loss JAX array to a float inside _flush_one_buffered_entry. The step time is now calculated as the duration between consecutive eager loss synchronizations. Note that this synchronization was already occurring subsequently during final metrics serialization; moving it forward simply establishes a clean boundary to accurately measure step times.
  • Loop Reordering: buffer_and_write_metrics is moved before maybe_save_checkpoint. This ensures that metrics for step N-1 are flushed before step N's checkpointing sync barrier blocks the host thread.
  • Timer Cleanup: Inaccurate host-loop timing variables (last_step_completion and step_time_delta) are removed.

This guarantees that all physical blocking copy overhead and host CPU serialization contention are accurately captured and cleanly attributed to the correct steps, eliminating the anomalies around step 0 and checkpoint boundaries.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 25, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 61.11111% with 7 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/common/metric_logger.py 62.50% 6 Missing ⚠️
src/maxtext/trainers/pre_train/train.py 50.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@copybara-service copybara-service Bot force-pushed the test_936672580 branch 4 times, most recently from e4e574c to cdc765a Compare June 25, 2026 11:43
… MaxText.

Previously, MaxText measured training step times using host-side wall-clock intervals between loop iterations (`now() - last_step_completion`). Because JAX dispatches steps asynchronously and the metric logger flushes metrics buffered from step N-1, this logic was broken in several ways:

1. Inaccurate Step 0 Time: The reported time for step 0 was artificially low because it only measured the asynchronous dispatch time.
2. Attribution Shift: Measuring `now() - last_step_completion` effectively measured the time taken for step N-1 to complete, but attributed it to step N+1.
3. Missing Overheads: Checkpointing, HLO dumps, eval, and profiling times were not included in the step time calculations.
4. Broken Checkpoint Logs: Checkpointing acts as a synchronization barrier for step N. Because checkpointing blocked the host to perform D2H tensor transfers before the flush, it broke the logs around checkpoint boundaries, causing subsequent steps to enqueue instantly and artificially stall later.

We fix this by measuring the interval between log flushes, where the loss is actually synchronized:

- Eager Loss Synchronization before Step Time calculation: We force host-device synchronization on the main thread by converting the loss JAX array to a float inside `_flush_one_buffered_entry`. The step time is now calculated as the duration between consecutive eager loss synchronizations. Note that this synchronization was already occurring subsequently during final metrics serialization; moving it forward simply establishes a clean boundary to accurately measure step times.
- Loop Reordering: `buffer_and_write_metrics` is moved before `maybe_save_checkpoint`. This ensures that metrics for step N-1 are flushed before step N's checkpointing sync barrier blocks the host thread.
- Timer Cleanup: Inaccurate host-loop timing variables (`last_step_completion` and `step_time_delta`) are removed.

This guarantees that all physical blocking copy overhead and host CPU serialization contention are accurately captured and cleanly attributed to the correct steps, eliminating the anomalies around step 0 and checkpoint boundaries.

# Checklist

Before submitting this PR, please make sure (put X in square brackets):
- [x] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label.
- [x] I have necessary comments in my code, particularly in hard-to-understand areas.
- [x] I have run end-to-end tests tests and provided workload links above if applicable.
- [x] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

PiperOrigin-RevId: 936672580
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant