
[NNX] Clean up transient intermediates at setup #4274

Open
xibinliu wants to merge 1 commit into main from xibin/nnx_intermediates

@xibinliu xibinliu commented Jun 25, 2026

Description

Sowed nnx.Intermediate variables in the mesh sharding state caused two errors:

  • llama3_1_70b_131072_fp8_4x4x4:

    • During model creation, maybe_quantize_model calls qwix.quantize_model(...), which runs a live forward pass on a dummy input and sows hidden_states. Because num_vocab_tiling: 4 is set, this dummy pass calls self.sow(nnx.Intermediate, "hidden_states", ...) under decoder, which registers hidden_states as an intermediate variable under model.decoder inside state_mesh_shardings.
    • The pytree mismatch: At runtime, train_step pops and discards nnx.Intermediate variables before returning. JAX expects hidden_states under decoder in out_shardings (it was present during setup), but the returned state lacks it, so compilation fails with a ValueError for the pytree structure mismatch.
  • llama3_1_70b_131072_fp8_4x8x8:

    • This workload executes the same qwix.quantize_model(...) dummy pass at initialization, which registers the sowed hidden_states under decoder. However, because it runs on a 4×8×8 topology (256 chips = 512 JAX devices), the sowed hidden_states activation has a massive shape of bf16[64, 131072, 8192], and JAX's initialization JIT compilation (jax.jit(init_state_partial)) crashes with an XLA RET_CHECK failure at the very beginning of setup.
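
To make the two failure modes above concrete, here is a minimal sketch using only JAX. The dictionaries are hypothetical stand-ins for the real state trees (their names and small shapes are assumptions, not MaxText code). It shows that a tree containing a sowed hidden_states entry has a different pytree structure from one where it was popped, and it computes the global, unsharded size of a bf16[64, 131072, 8192] array.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the state seen at setup: the dummy forward pass
# has sowed `hidden_states` under `decoder`.
setup_state = {
    "decoder": {
        "kernel": jnp.zeros((4, 4)),
        "hidden_states": jnp.zeros((2, 4)),
    }
}

# Hypothetical stand-in for what train_step returns after popping intermediates.
returned_state = {"decoder": {"kernel": jnp.zeros((4, 4))}}

setup_struct = jax.tree_util.tree_structure(setup_state)
returned_struct = jax.tree_util.tree_structure(returned_state)

# Shardings derived from setup_state cannot line up with returned_state,
# because the two trees do not have the same structure.
structures_match = setup_struct == returned_struct

# Global (unsharded) size of bf16[64, 131072, 8192]; bf16 is 2 bytes per element.
num_elements = 64 * 131072 * 8192
size_gib = num_elements * 2 / 2**30

print(f"structures match: {structures_match}")
print(f"hidden_states global size: {size_gib} GiB")
```

The size is for the whole logical array; the per-device footprint depends on how the activation is sharded, which this sketch does not model.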

Changes:

This change cleans up transient variables sowed during model creation/tracing from the persistent sharding state and checkpoint layout.

  1. Clean Quantization Setup:

    • Updated maybe_quantize_model in quantizations.py to pop sowed nnx.Intermediate variables in-place immediately after Qwix tracing.
    • This cleans the initial state returned by the model factory fn, allowing us to revert the setup-time nnx.Not(nnx.Intermediate) filters in maxtext_utils.py and train_compile.py.
    • Optimized create_nnx_abstract_model in model_creation_utils.py to reuse abs_model.mesh and avoid calling eval_shape twice.
  2. Metrics Sowing Uniformity:

    • Changed all self.sow("intermediates", ...) calls to self.sow(nnx.Intermediate, ...) in all pure NNX models (gemma.py, llama2.py, etc.).
    • Sowing string-based "intermediates" was dynamically creating custom Variable types that did not inherit from nnx.Intermediate, leading to parameter/checkpoint bloat.
    • train_step in train.py now strips all intermediates via nnx.Not(nnx.Intermediate) before returning.
  3. Testing:

    • Added MaybeQuantizeModelTest in quantizations_test.py to assert that sowed intermediates are popped from the model state and that abstract model state contains no intermediates.

Tests

Validated that the NNX training step compiles successfully without structure or sharding layout mismatches:
Command: python3 src/maxtext/trainers/pre_train/train_compile.py
Configuration: model_name=deepseek3-tiny + attention=dot_product + pure_nnx=true + use_qwix_quantization=true + use_qk_clip=true

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.
