diff --git a/src/maxtext/layers/multi_token_prediction.py b/src/maxtext/layers/multi_token_prediction.py
index bb2cf90f37..6f9af709fd 100644
--- a/src/maxtext/layers/multi_token_prediction.py
+++ b/src/maxtext/layers/multi_token_prediction.py
@@ -37,11 +37,11 @@
 # Custom Variable types for MTP intermediate outputs
 # These will be automatically converted to Linen mutable collections by ToLinen wrapper
 # The class names become collection names directly (no case conversion)
-class mtp_losses(nnx.Variable):  # pylint: disable=invalid-name
+class mtp_losses(nnx.Intermediate):  # pylint: disable=invalid-name
   """Variable type for storing MTP loss components -> 'mtp_losses' collection."""
 
 
-class mtp_acceptance(nnx.Variable):  # pylint: disable=invalid-name
+class mtp_acceptance(nnx.Intermediate):  # pylint: disable=invalid-name
   """Variable type for storing MTP acceptance predictions -> 'mtp_acceptance' collection."""
 
 
diff --git a/src/maxtext/layers/quantizations.py b/src/maxtext/layers/quantizations.py
index 2c5f638f8c..17956f07fb 100644
--- a/src/maxtext/layers/quantizations.py
+++ b/src/maxtext/layers/quantizations.py
@@ -862,6 +862,10 @@ def maybe_quantize_model(model, config):
             dummy_segment_ids,
             enable_dropout=False,
         )
+        # Qwix quantization runs a forward pass during tracing, which sows transient nnx.Intermediate variables
+        # (e.g. max_logits from QK-Clip, MTP losses) into the model. Popping them here prevents structural mismatches
+        # between the initial setup GraphDef/state_mesh_shardings and the stripped states during train steps.
+        nnx.pop(model, nnx.Intermediate)
       else:
         model = qwix.quantize_model(model, quantization_provider)
   return model
diff --git a/src/maxtext/models/gemma.py b/src/maxtext/models/gemma.py
index 84f4f6817d..47f2a3da1b 100644
--- a/src/maxtext/models/gemma.py
+++ b/src/maxtext/models/gemma.py
@@ -171,10 +171,10 @@ def __call__(
     )
 
     if self.config.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/gemma2.py b/src/maxtext/models/gemma2.py
index a7315763eb..f62709e05f 100644
--- a/src/maxtext/models/gemma2.py
+++ b/src/maxtext/models/gemma2.py
@@ -307,10 +307,10 @@ def __call__(
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
     if self.config.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/gemma3.py b/src/maxtext/models/gemma3.py
index 344d98ed88..2d2b36a431 100644
--- a/src/maxtext/models/gemma3.py
+++ b/src/maxtext/models/gemma3.py
@@ -242,10 +242,10 @@ def __call__(
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
     if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/gemma4.py b/src/maxtext/models/gemma4.py
index 626d2ff54c..8cf927629c 100644
--- a/src/maxtext/models/gemma4.py
+++ b/src/maxtext/models/gemma4.py
@@ -365,7 +365,7 @@ def __call__(
     if getattr(self.config, "num_experts", 1) > 1:
       mlp_lnx, load_balance_loss, _ = self.mlp(attn_output, original_inputs=attention_lnx)
       if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
-        self.sow("intermediates", "moe_lb_loss", load_balance_loss)
+        self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)
     else:
       mlp_lnx = self.mlp(attn_output, deterministic=deterministic)
 
@@ -381,10 +381,10 @@ def __call__(
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
     if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/gpt3.py b/src/maxtext/models/gpt3.py
index a6b08d8b24..7cf8fbd773 100644
--- a/src/maxtext/models/gpt3.py
+++ b/src/maxtext/models/gpt3.py
@@ -510,10 +510,10 @@ def __call__(
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
     if self.config.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py
index e854a75556..de2f364de6 100644
--- a/src/maxtext/models/gpt_oss.py
+++ b/src/maxtext/models/gpt_oss.py
@@ -198,13 +198,13 @@ def __call__(
     )
 
     if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
-      self.sow("intermediates", "moe_lb_loss", load_balance_loss)
+      self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)
 
     if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py
index 0c3e0cca7c..eeb54da934 100644
--- a/src/maxtext/models/llama2.py
+++ b/src/maxtext/models/llama2.py
@@ -203,10 +203,10 @@ def __call__(
     layer_output = self._maybe_shard_with_logical(layer_output, self.activation_axis_names)
 
     if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/llama4.py b/src/maxtext/models/llama4.py
index 66dea4295c..26fd4d322d 100644
--- a/src/maxtext/models/llama4.py
+++ b/src/maxtext/models/llama4.py
@@ -497,13 +497,13 @@ def __call__(
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
     if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
-      self.sow("intermediates", "moe_lb_loss", load_balance_loss)
+      self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)
 
     if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/mistral.py b/src/maxtext/models/mistral.py
index 49a74c95db..fa7f0956d4 100644
--- a/src/maxtext/models/mistral.py
+++ b/src/maxtext/models/mistral.py
@@ -177,10 +177,10 @@ def __call__(
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
     if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/mixtral.py b/src/maxtext/models/mixtral.py
index faf69273c6..621ff5c619 100644
--- a/src/maxtext/models/mixtral.py
+++ b/src/maxtext/models/mixtral.py
@@ -176,13 +176,13 @@ def __call__(
     layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)
 
     if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
-      self.sow("intermediates", "moe_lb_loss", load_balance_loss)
+      self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)
 
     if self.config.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py
index fe8a4e489e..3a4f526159 100644
--- a/src/maxtext/models/olmo3.py
+++ b/src/maxtext/models/olmo3.py
@@ -212,10 +212,10 @@ def __call__(
     )
 
     if cfg.record_internal_nn_metrics:
-      self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
-      self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
+      self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
+      self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
       self.sow(
-          "intermediates",
+          nnx.Intermediate,
           "activation_fraction_zero",
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )
diff --git a/src/maxtext/models/qwen3_5.py b/src/maxtext/models/qwen3_5.py
index de8bd1cf18..331df17f41 100644
--- a/src/maxtext/models/qwen3_5.py
+++ b/src/maxtext/models/qwen3_5.py
@@ -232,7 +232,7 @@ def __call__(
     # We sow the load balancing loss so it can be collected and added to the total loss
     # during training.
     if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
-      self.sow("intermediates", "moe_lb_loss", load_balance_loss)
+      self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)
 
     # Final residual connection (after the MoE block)
     layer_output = residual + mlp_output
diff --git a/src/maxtext/models/qwen3_custom.py b/src/maxtext/models/qwen3_custom.py
index e0ca4bb512..9b6f6c037e 100644
--- a/src/maxtext/models/qwen3_custom.py
+++ b/src/maxtext/models/qwen3_custom.py
@@ -243,7 +243,7 @@ def __call__(
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)
 
     if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
-      self.sow("intermediates", "moe_lb_loss", load_balance_loss)
+      self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)
 
     layer_output = mlp_lnx
     if self.layer_up_projection is not None:
diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py
index 74c41e8e90..b2c06aaeef 100644
--- a/src/maxtext/trainers/pre_train/train.py
+++ b/src/maxtext/trainers/pre_train/train.py
@@ -560,7 +560,7 @@ def move(path, value):
   # Drop Intermediates (e.g. sowed max_logits for QK-Clip) and the MTP sown
   # vars (mtp_losses/mtp_acceptance) before returning. They're absent from
   # state_mesh_shardings and would cause a leaf-count / structure mismatch.
-  return nnx.state(new_state, nnx.Not(nnx.Any(nnx.Intermediate, mtp_losses, mtp_acceptance))), metrics
+  return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics
 
 
 def eval_step(model, config, state, data, dropout_rng=None):
diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py
index fe34d9fac0..f81528daf6 100644
--- a/src/maxtext/utils/maxtext_utils.py
+++ b/src/maxtext/utils/maxtext_utils.py
@@ -1648,18 +1648,14 @@ def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=Tru
     # ourselves via nnx_construct_named_sharding, so auto-assignment is not needed here.
     abs_model = nnx.eval_shape(nnx_init_trainstate_fn)
     _, abs_var_state = nnx.split(abs_model)
-    named_sharding_state = sharding.nnx_construct_named_sharding(
-        abs_var_state, mesh
-    )
+    named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh)
     abstract_state = jax.tree.map(
         lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
         abs_var_state,
         named_sharding_state,
     )
 
-    state_mesh_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(
-        abstract_state
-    )
+    state_mesh_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state)
 
     if is_training and config.shard_optimizer_over_data:
       # Add data to sharding for optimizer state
diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py
index f93fefd046..c0248783eb 100644
--- a/src/maxtext/utils/model_creation_utils.py
+++ b/src/maxtext/utils/model_creation_utils.py
@@ -574,9 +574,6 @@ def create_nnx_abstract_model(
   with nn.logical_axis_rules(config.logical_axis_rules):
     _create_model = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key, quant_mode_str=quant_mode_str)
 
-    if mesh is None:
-      _tmp = nnx.eval_shape(_create_model)
-      mesh = _tmp.mesh
     # Use nnx.eval_shape + our scan-axis-aware sharding helper instead of
     # nnx.get_abstract_model, which uses get_var_pspec internally and ignores
     # param_scan_axis / nnx.PARTITION_NAME metadata set by _create_scanned_layers,
@@ -586,10 +583,10 @@ def create_nnx_abstract_model(
     # AbstractMesh). Sharding is resolved afterwards via the helper, so the
     # wrap is unnecessary here.
     abs_model = nnx.eval_shape(_create_model)
+    if mesh is None:
+      mesh = abs_model.mesh
     graphdef, abs_var_state = nnx.split(abs_model)
-    named_sharding_state = sharding.nnx_construct_named_sharding(
-        abs_var_state, mesh
-    )
+    named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh)
     abstract_state = jax.tree.map(
         lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
         abs_var_state,
diff --git a/tests/unit/quantizations_test.py b/tests/unit/quantizations_test.py
index fff430e646..9335b1433b 100644
--- a/tests/unit/quantizations_test.py
+++ b/tests/unit/quantizations_test.py
@@ -652,5 +652,68 @@ def test_gmm_kernel(group_sizes, k, n, tiling, dtype):
   assert jnp.abs(quant_out - base_out).mean() / jnp.abs(base_out).mean() < 2e-1
 
 
+class MaybeQuantizeModelTest(unittest.TestCase):
+
+  def test_maybe_quantize_model_pops_intermediates(self):
+    config = pyconfig.initialize(
+        [None, get_test_config_path()],
+        enable_checkpointing=False,
+        quantization="int8",
+        use_qwix_quantization=True,
+        use_batch_split_schedule=False,
+        pure_nnx=True,
+        micro_batch_size_to_train_on=1,
+        max_target_length=2,
+    )
+
+    class DummyModel(nnx.Module):
+
+      def __init__(self):
+        self.param = nnx.Param(jnp.ones((2, 2)))
+
+      def __call__(self, tokens, positions, segment_ids, enable_dropout=False):
+        self.sow(nnx.Intermediate, "some_metric", jnp.mean(self.param.get_value()))
+        return tokens.astype(jnp.float32) @ self.param.get_value()
+
+    model = DummyModel()
+
+    # Verify that before quantizing, no nnx.Intermediate variables have been sown yet
+    pre_intermediates = nnx.state(model, nnx.Intermediate)
+    self.assertEqual(jax.tree.leaves(pre_intermediates), [])
+
+    # 1. Run maybe_quantize_model (which runs Qwix tracing and then pops intermediates in-place)
+    quantized_model = quantizations.maybe_quantize_model(model, config)
+
+    # 2. Collect any nnx.Intermediate variables still attached to the quantized model
+    intermediates = nnx.state(quantized_model, nnx.Intermediate)
+    leftover = jax.tree.leaves(intermediates)
+
+    # nnx.State is keyed by attribute path (not by collection name), so select by variable type
+    self.assertEqual(leftover, [])
+
+  def test_nnx_abstract_state_has_no_intermediates(self):
+    # Initialize a configuration with Qwix quantization enabled
+    config = pyconfig.initialize(
+        [None, get_test_config_path()],
+        enable_checkpointing=False,
+        model_name="deepseek3-tiny",
+        attention="dot_product",
+        pure_nnx=True,
+        use_qwix_quantization=True,
+        use_qk_clip=True,  # This sows QK clip intermediates during the forward pass
+    )
+
+    # Create the abstract model
+    mesh = jax.make_mesh((1, 1, 1, 1), ("data", "fsdp", "expert", "context"))
+    _, abstract_model = model_creation_utils.create_nnx_abstract_model(config, mesh)
+
+    # Collect any nnx.Intermediate variables present on the abstract model
+    intermediates = nnx.state(abstract_model, nnx.Intermediate)
+    leftover = jax.tree.leaves(intermediates)
+
+    # nnx.State is keyed by attribute path (not by collection name), so select by variable type
+    self.assertEqual(leftover, [])
+
+
 if __name__ == "__main__":
   unittest.main()