Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxtext/layers/multi_token_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@
# Custom Variable types for MTP intermediate outputs
# These will be automatically converted to Linen mutable collections by ToLinen wrapper
# The class names become collection names directly (no case conversion)
class mtp_losses(nnx.Variable): # pylint: disable=invalid-name
class mtp_losses(nnx.Intermediate): # pylint: disable=invalid-name
"""Variable type for storing MTP loss components -> 'mtp_losses' collection."""


class mtp_acceptance(nnx.Variable): # pylint: disable=invalid-name
class mtp_acceptance(nnx.Intermediate): # pylint: disable=invalid-name
"""Variable type for storing MTP acceptance predictions -> 'mtp_acceptance' collection."""


Expand Down
4 changes: 4 additions & 0 deletions src/maxtext/layers/quantizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,10 @@ def maybe_quantize_model(model, config):
dummy_segment_ids,
enable_dropout=False,
)
# Qwix quantization runs a forward pass during tracing, which sows transient nnx.Intermediate variables
# (e.g. max_logits from QK-Clip, MTP losses) into the model. Popping them here prevents structural mismatches
# between the initial setup GraphDef/state_mesh_shardings and the stripped states during train steps.
nnx.pop(model, nnx.Intermediate)
else:
model = qwix.quantize_model(model, quantization_provider)
return model
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,10 @@ def __call__(
)

if self.config.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,10 @@ def __call__(
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if self.config.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/models/gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,10 @@ def __call__(
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if cfg.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
8 changes: 4 additions & 4 deletions src/maxtext/models/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def __call__(
if getattr(self.config, "num_experts", 1) > 1:
mlp_lnx, load_balance_loss, _ = self.mlp(attn_output, original_inputs=attention_lnx)
if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
self.sow("intermediates", "moe_lb_loss", load_balance_loss)
self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)
else:
mlp_lnx = self.mlp(attn_output, deterministic=deterministic)

Expand All @@ -381,10 +381,10 @@ def __call__(
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if cfg.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/models/gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,10 @@ def __call__(
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if self.config.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
8 changes: 4 additions & 4 deletions src/maxtext/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ def __call__(
)

if cfg.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
self.sow("intermediates", "moe_lb_loss", load_balance_loss)
self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)

if cfg.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/models/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,10 @@ def __call__(
layer_output = self._maybe_shard_with_logical(layer_output, self.activation_axis_names)

if cfg.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
8 changes: 4 additions & 4 deletions src/maxtext/models/llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,13 +497,13 @@ def __call__(
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
self.sow("intermediates", "moe_lb_loss", load_balance_loss)
self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)

if cfg.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,10 @@ def __call__(
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if cfg.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
8 changes: 4 additions & 4 deletions src/maxtext/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,13 @@ def __call__(
layer_output = nn.with_logical_constraint(layer_output, self.activation_axis_names)

if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
self.sow("intermediates", "moe_lb_loss", load_balance_loss)
self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)

if self.config.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
6 changes: 3 additions & 3 deletions src/maxtext/models/olmo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,10 @@ def __call__(
)

if cfg.record_internal_nn_metrics:
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
self.sow(
"intermediates",
nnx.Intermediate,
"activation_fraction_zero",
jnp.sum(layer_output == 0) / jnp.size(layer_output),
)
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def __call__(
# We sow the load balancing loss so it can be collected and added to the total loss
# during training.
if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
self.sow("intermediates", "moe_lb_loss", load_balance_loss)
self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)

# Final residual connection (after the MoE block)
layer_output = residual + mlp_output
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/qwen3_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def __call__(
mlp_lnx = nn.with_logical_constraint(mlp_lnx, self.activation_axis_names)

if self.config.load_balance_loss_weight > 0.0 and load_balance_loss is not None:
self.sow("intermediates", "moe_lb_loss", load_balance_loss)
self.sow(nnx.Intermediate, "moe_lb_loss", load_balance_loss)

layer_output = mlp_lnx
if self.layer_up_projection is not None:
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/trainers/pre_train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def move(path, value):
# Drop Intermediates (e.g. sowed max_logits for QK-Clip) and the MTP sown
# vars (mtp_losses/mtp_acceptance) before returning. They're absent from
# state_mesh_shardings and would cause a leaf-count / structure mismatch.
return nnx.state(new_state, nnx.Not(nnx.Any(nnx.Intermediate, mtp_losses, mtp_acceptance))), metrics
return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics


def eval_step(model, config, state, data, dropout_rng=None):
Expand Down
8 changes: 2 additions & 6 deletions src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,18 +1648,14 @@ def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=Tru
# ourselves via nnx_construct_named_sharding, so auto-assignment is not needed here.
abs_model = nnx.eval_shape(nnx_init_trainstate_fn)
_, abs_var_state = nnx.split(abs_model)
named_sharding_state = sharding.nnx_construct_named_sharding(
abs_var_state, mesh
)
named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh)
abstract_state = jax.tree.map(
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
abs_var_state,
named_sharding_state,
)

state_mesh_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(
abstract_state
)
state_mesh_shardings = maxtext_utils_nnx.nnx_extract_named_sharding(abstract_state)

if is_training and config.shard_optimizer_over_data:
# Add data to sharding for optimizer state
Expand Down
9 changes: 3 additions & 6 deletions src/maxtext/utils/model_creation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,6 @@ def create_nnx_abstract_model(

with nn.logical_axis_rules(config.logical_axis_rules):
_create_model = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key, quant_mode_str=quant_mode_str)
if mesh is None:
_tmp = nnx.eval_shape(_create_model)
mesh = _tmp.mesh
# Use nnx.eval_shape + our scan-axis-aware sharding helper instead of
# nnx.get_abstract_model, which uses get_var_pspec internally and ignores
# param_scan_axis / nnx.PARTITION_NAME metadata set by _create_scanned_layers,
Expand All @@ -586,10 +583,10 @@ def create_nnx_abstract_model(
# AbstractMesh). Sharding is resolved afterwards via the helper, so the
# wrap is unnecessary here.
abs_model = nnx.eval_shape(_create_model)
if mesh is None:
mesh = abs_model.mesh
graphdef, abs_var_state = nnx.split(abs_model)
named_sharding_state = sharding.nnx_construct_named_sharding(
abs_var_state, mesh
)
named_sharding_state = sharding.nnx_construct_named_sharding(abs_var_state, mesh)
abstract_state = jax.tree.map(
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
abs_var_state,
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/quantizations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,5 +652,68 @@ def test_gmm_kernel(group_sizes, k, n, tiling, dtype):
assert jnp.abs(quant_out - base_out).mean() / jnp.abs(base_out).mean() < 2e-1


class MaybeQuantizeModelTest(unittest.TestCase):
  """Checks that sown nnx.Intermediate variables do not leak out of model setup."""

  def _assert_no_intermediates(self, model):
    """Fails if `model` still holds any nnx.Intermediate variable.

    NNX `sow` stores the sown value as a module attribute named after the sown
    key (e.g. `some_metric`); the state is keyed by attribute path, not by a
    Linen-style "intermediates" collection. Checking for an "intermediates" key
    alone would therefore pass even when intermediates are present, so the
    state is filtered by variable type instead.

    Args:
      model: an nnx.Module (concrete or abstract) to inspect.
    """
    leftover = jax.tree.leaves(nnx.state(model, nnx.Intermediate))
    self.assertEqual([], leftover)

    # Original key-based check, kept as a cheap secondary guard.
    _, full_state = nnx.split(model)
    self.assertNotIn("intermediates", full_state.to_pure_dict())

  def test_maybe_quantize_model_pops_intermediates(self):
    """maybe_quantize_model must strip intermediates sown during Qwix tracing."""
    config = pyconfig.initialize(
        [None, get_test_config_path()],
        enable_checkpointing=False,
        quantization="int8",
        use_qwix_quantization=True,
        use_batch_split_schedule=False,
        pure_nnx=True,
        micro_batch_size_to_train_on=1,
        max_target_length=2,
    )

    class DummyModel(nnx.Module):
      """Minimal module that sows one intermediate on every forward pass."""

      def __init__(self):
        self.param = nnx.Param(jnp.ones((2, 2)))

      def __call__(self, tokens, positions, segment_ids, enable_dropout=False):
        self.sow(nnx.Intermediate, "some_metric", jnp.mean(self.param.get_value()))
        return tokens.astype(jnp.float32) @ self.param.get_value()

    # Positive control: a plain forward pass really does leave an Intermediate
    # behind, so the "no intermediates" assertions below are not vacuous.
    probe = DummyModel()
    probe(jnp.ones((1, 2), dtype=jnp.int32), None, None)
    self.assertTrue(jax.tree.leaves(nnx.state(probe, nnx.Intermediate)))

    model = DummyModel()

    # Before quantizing, nothing has been sown yet.
    self._assert_no_intermediates(model)

    # Qwix tracing runs a forward pass (which sows), then the intermediates
    # are expected to be popped in place before the model is returned.
    quantized_model = quantizations.maybe_quantize_model(model, config)

    self._assert_no_intermediates(quantized_model)

  def test_nnx_abstract_state_has_no_intermediates(self):
    """The abstract model used for sharding setup must carry no intermediates."""
    # Configuration with Qwix quantization enabled.
    config = pyconfig.initialize(
        [None, get_test_config_path()],
        enable_checkpointing=False,
        model_name="deepseek3-tiny",
        attention="dot_product",
        pure_nnx=True,
        use_qwix_quantization=True,
        use_qk_clip=True,  # This sows QK clip intermediates during the forward pass
    )

    # Create the abstract model.
    mesh = jax.make_mesh((1, 1, 1, 1), ("data", "fsdp", "expert", "context"))
    _, abstract_model = model_creation_utils.create_nnx_abstract_model(config, mesh)

    # No Intermediate variable may be present anywhere in the abstract state.
    self._assert_no_intermediates(abstract_model)


if __name__ == "__main__":
unittest.main()
Loading