Skip to content

feat: migrate pipeline to nnx#2885

Open
mesakhcienet wants to merge 39 commits into
AI-Hypercomputer:mainfrom
CIeNET-International:test/pipeline-scan-nnx
Open

feat: migrate pipeline to nnx#2885
mesakhcienet wants to merge 39 commits into
AI-Hypercomputer:mainfrom
CIeNET-International:test/pipeline-scan-nnx

Conversation

@mesakhcienet
Copy link
Copy Markdown
Contributor

@mesakhcienet mesakhcienet commented Dec 24, 2025

Description

implement nnx-based pipeline.

This PR extends PR#2831

Main changes:

  1. nnx_decoders.py: implementing the missing pipeline logic in nnx_decoders.py.
  2. pipeline.py : add a new class NNXPipeline, which is a nnx-based pipeline class.

Tests

we run the pipeline process with command below:

MODEL_NAME=llama2-7b
python -m MaxText.train src/maxtext/configs/base.yml \
    run_name=pipeline_test_${MODEL_NAME}_nnx \
    base_output_directory=/dev/shm/pipeline_test_nnx \
    model_name=${MODEL_NAME}\
    dataset_type=synthetic \
    steps=15 \
    debug_sharding=true \
    per_device_batch_size=2 \
    max_target_length=32 \
    ici_pipeline_parallelism=2 \
    num_pipeline_microbatches=4 \
    num_layers_per_pipeline_stage=2 \
    enable_checkpointing=false \
    enable_nnx=true \
    pure_nnx_decoder=true \
    scan_layers_per_stage=false \
    async_checkpointing=false > nnx-porting-log/pipeline/custom_${MODEL_NAME}.log 2>&1

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@mesakhcienet mesakhcienet changed the title core: migrate pipeline to nnx feat: migrate pipeline to nnx Dec 24, 2025
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 8 times, most recently from 6875da8 to f34b1a3 Compare January 15, 2026 23:43
@codecov
Copy link
Copy Markdown

codecov Bot commented Jan 19, 2026

@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 4 times, most recently from 12a3907 to 2c16599 Compare January 28, 2026 08:04
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch 2 times, most recently from 64dc147 to 9e4518e Compare February 2, 2026 01:58
@mesakhcienet mesakhcienet force-pushed the test/pipeline-scan-nnx branch from 631a73e to ac97a1d Compare March 2, 2026 08:48
@mesakhcienet mesakhcienet changed the base branch from main to xibin/nnx_all March 2, 2026 08:48
@ecnal-cienet ecnal-cienet force-pushed the xibin/nnx_all branch 12 times, most recently from 1849f0b to 669dc01 Compare March 3, 2026 19:59
Add matching [PIPELINE-DIAG] and [DECODER-DIAG] tagged logging to NNX
pipeline.py and decoders.py. Logs setup config, nnx.split state
partitions, L1/L2/L3 custom_vjp residuals/leaf counts, outer scan
carry structure (2 elements vs Linen's 3), vmap output structure,
BSW buffer shapes, and closure-captured variables.

fix: address challenger gaps in NNX diagnostic logging

- C-1: Add total carry leaf count + total closure leaf count + GB
- C-3: Add scatter_update flag, checkpoint policy at use site
- C-4: Log to_linen_class wrapper in create_pipeline, stage_factory
  pattern in decoders.py
Move jax.checkpoint from wrapping outer_body (outside scan) to wrapping
execute_pipeline_repeat (inside scan body). This matches Linen's
nn.scan(nn.remat(stage_fn)) pattern which creates a closed_call
boundary per iteration.

Root cause: Linen's pattern lets XLA unroll trip-1 outer loops, inline
the closed_call, and clone inner while-loops 16x — producing 48 small
loops with small carries. NNX's pattern (checkpoint outside scan)
creates trip-16 loops that XLA won't unroll (above threshold), resulting
in 5 monolithic loops with large carries and poor buffer reuse.

HLO evidence:
- Linen: 8 -> 48 while-loops, preallocated-temp 12.15 GiB
- NNX:   7 ->  5 while-loops, preallocated-temp 13.80 GiB

Same fix applied to bubble iterations.
Revert to exact 0506 scan structure (no unroll parameter) but with
dual-buffer BSW fix for numerical correctness:

1. outer_body: bsw_ref[0] = (cur_bsw, nxt_bsw) — 2 all-gathers per
   repeat so trailing stages get correct repeat's weights at boundaries
2. bubble: same dual-buffer pattern
3. Remove jax.lax.scan(unroll=N) — match 0506 behavior (unroll=1)
4. Restore bsw[0] is bsw[1] fast path in get_current_weights_from_bsw

The dual-buffer is required for correctness: with (cur_bsw, cur_bsw),
all stages get the same repeat's weights, but trailing stages at repeat
boundaries need the previous repeat's weights.
Replace dual-buffer (2 all-gathers per repeat, +8 GB) with Linen's
pattern: carry w_curr through outer scan, compute w_next via single
weight_prefetching call per repeat.

- outer carry: (loop_state, layer_mutables, w_curr)
- outer_body: w_next = weight_prefetching(iteration), BSW = (w_curr, w_next)
- w_next becomes next iteration's w_curr via carry
- Bubbles: (final_w_curr, final_w_curr) — all stages on same repeat

This matches Linen's create_pipeline_stage pattern exactly:
1 all-gather per repeat (not 2), w_curr carried (not recomputed).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants