diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 9b5f0cfb21..7b670dd8d7 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -87,11 +87,10 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) + # Output is Linen-format (keystr_map below uses Linen tree paths). Route to + # Linen regardless of pure_nnx. quant = quantizations.configure_quantization(cfg) - if cfg.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) tx = optimizers.get_optimizer(cfg, learning_rate_schedule) @@ -102,11 +101,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - if cfg.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 406ba92523..c14d87cd4b 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) - def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool): + del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0: raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration - raise ValueError("We currently don't support vocab tiling on NNX module.") def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples): diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index a0f436dff3..eb8e9890e2 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2897,8 +2897,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 ): raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if self.num_vocab_tiling > 1 and self.enable_nnx: - raise ValueError("We currently don't support vocab tiling on NNX module.") + # Vocab tiling on NNX is now supported via 
vocab_tiling_nnx_loss in vocabulary_tiling.py. if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: raise ValueError( diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 28eef21cb0..b788ccd13e 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -54,9 +54,12 @@ from jax import random from flax.linen import partitioning as nn_partitioning +from flax import nnx from flax import struct from flax.nnx import TrainState +from maxtext.layers import train_state_nnx + from cloud_tpu_diagnostics import diagnostic from cloud_tpu_diagnostics.configuration import debug_configuration from cloud_tpu_diagnostics.configuration import diagnostic_configuration @@ -85,11 +88,12 @@ from maxtext.experimental.rl import grpo_utils from maxtext.common.metric_logger import MetricLogger from maxtext.common.vertex_tensorboard import VertexTensorboardManager -from maxtext.inference import offline_engine from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging from maxtext.utils import max_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils from maxtext.utils import maxtext_utils from maxtext.utils import sharding from maxtext.utils import train_utils @@ -335,34 +339,190 @@ def grpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_ return loss, aux +def grpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """GRPO loss for the NNX path. + + Signature matches the Linen `grpo_loss_fn` so callers can dispatch on the + same shape. `dropout_rng` and `params` are unused (NNX models carry these + themselves); `reference_model` is the frozen reference `nnx.Module`. The + reference forward is wrapped in `stop_gradient` so grads only flow into the + policy. Returns `(loss, LossAux)`. 
+ """ + del dropout_rng, params # NNX models carry these themselves + + prompt_with_completions = data[f"{config.train_data_columns}_completions"] + prompt_completions_position = data[f"{config.train_data_columns}_completions_position"] + prompt_completions_segmentation = data[f"{config.train_data_columns}_completions_segmentation"] + completions_segmentation = data["ar_completions_segmentation"] + + token_logps_policy, intermediate_outputs = grpo_utils.compute_log_probs_nnx( + policy_model, + prompt_with_completions, + prompt_completions_position, + prompt_completions_segmentation, + completions_segmentation, + config, + is_train=is_train, + ) + + completion_target_segmentation = data["ar_completions_segmentation"][..., 1:] + valid_seq_mask = completion_target_segmentation != 0 + + rewards = grpo_utils.dummy_reward_len(valid_seq_mask) + rewards = jnp.array(rewards) + + G = config.num_generations + rewards_grouped = rewards.reshape(-1, G) + group_mean = jnp.mean(rewards_grouped, axis=1) + group_std = jnp.std(rewards_grouped, axis=1) + repeated_group_mean = jnp.repeat(group_mean, G) + repeated_group_std = jnp.repeat(group_std, G) + advantages = (rewards - repeated_group_mean) / (repeated_group_std + EPS) + advantages_exp = advantages[:, None] + + if data["completions_logprobs"] is None: # off-policy + old_per_token_logps = jax.lax.stop_gradient(token_logps_policy) + else: # on-policy + old_per_token_logps = data["completions_logprobs"] + + policy_diff = token_logps_policy - old_per_token_logps + coef_1 = jnp.exp(policy_diff) + coef_2 = jnp.clip(coef_1, 1 - config.grpo_epsilon, 1 + config.grpo_epsilon) + loss_tokens = -jnp.minimum(coef_1 * advantages_exp, coef_2 * advantages_exp) + + if config.grpo_beta != 0.0: + token_logps_ref, _ = grpo_utils.compute_log_probs_nnx( + reference_model, + prompt_with_completions, + prompt_completions_position, + prompt_completions_segmentation, + completions_segmentation, + config, + is_train=False, + ) + token_logps_ref = jax.lax.stop_gradient(token_logps_ref) + token_diff_logps_ref_policy = token_logps_ref - token_logps_policy + per_token_kl = jnp.exp(token_diff_logps_ref_policy) - token_diff_logps_ref_policy - 1 + per_token_kl = per_token_kl * valid_seq_mask + loss_tokens += config.grpo_beta * per_token_kl + + loss_per_example = jnp.sum(loss_tokens * valid_seq_mask, axis=1) / jnp.clip(jnp.sum(valid_seq_mask, axis=1), min=1) + loss = jnp.mean(loss_per_example) + total_weights = jnp.sum(valid_seq_mask) + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + + if config.grpo_beta != 0.0: + avg_kl = jnp.mean((per_token_kl * valid_seq_mask) / jnp.clip(jnp.sum(valid_seq_mask, axis=1, keepdims=True), min=1)) + else: + avg_kl = None + avg_completion_length = jnp.mean(jnp.sum(data["ar_completions_segmentation"] != 0, axis=1)) + aux = LossAux( + total_loss=loss, + avg_reward=jnp.mean(rewards), + avg_reward_std=jnp.mean(repeated_group_std), + avg_advantage=jnp.mean(advantages), + avg_kl=avg_kl, + completion_length=avg_completion_length, + moe_lb_loss=moe_lb_loss, + total_weights=total_weights, + ) + return loss, aux + + # ----------------------------------------------------------------------------- # Trainer and top level training functions # ----------------------------------------------------------------------------- -def train_step(model, config, state_mesh_shardings, 
params_shardings, state, data, dropout_rng): - """Performs a single training step of the GRPO algorithm. +def _train_step_nnx(model_graphdef, config, state_mesh_shardings, state, data): + """GRPO train_step body for the NNX path. - This function computes the GRPO loss, calculates gradients, and updates the - model's parameters. It handles gradient accumulation and clipping as configured. - The reference model's parameters are held constant during the update. + Reconstructs `TrainStateNNX` from `(model_graphdef, state)`, splits out + the policy params for value_and_grad, applies gradients, and returns the + new state with `nnx.Intermediate` filtered out (transient sown values + must not persist across steps). + """ + del state_mesh_shardings # host-offload paths not yet wired up here - Args: - model: The transformer model to be trained. - config: The training configuration object. - state_mesh_shardings: Pytree of sharding specifications for the training state. - params_shardings: Pytree of sharding specifications for the model parameters. - This argument is not used and is kept to match the signature of other trainers. - state: The current training state, including parameters and optimizer state. - data: A batch of training data, including prompts and generated completions. - dropout_rng: JAX PRNG key for dropout. + if config.gradient_accumulation_steps > 1: + raise NotImplementedError( + "GRPO + pure_nnx + gradient_accumulation_steps>1 not supported yet. " + "Set gradient_accumulation_steps=1 or pure_nnx=False." + ) - Returns: - A tuple containing: - - new_state: The updated training state after applying gradients. - - metrics: A dictionary of metrics for logging, including loss, reward, - and gradient norms. + state = nnx.merge(model_graphdef, state) # reconstruct TrainStateNNX + policy_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(policy_graphdef, param, rest, copy=True) + loss, aux = grpo_loss_fn_nnx(local_model, config, data, None, None, state.reference_model, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) + return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) + + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + else: + grads = raw_grads + state.apply_gradients(grads) + new_state = state + + scalar_metrics = { + "learning/loss": loss, + "learning/avg_reward": aux.avg_reward, + "learning/avg_reward_std": aux.avg_reward_std, + "learning/avg_advantage": aux.avg_advantage, + "learning/avg_kl": aux.avg_kl, + "learning/completion_length": aux.completion_length, + "learning/moe_lb_loss": aux.moe_lb_loss, + "learning/total_weights": aux.total_weights, + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + } + _, new_policy_params, _ = nnx.split(new_state.model, nnx.Param, ...) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_policy_params) + metrics = {"scalar": scalar_metrics, "scalars": {}} + + return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics + + +def _eval_step_nnx(model_graphdef, config, state, data): + """GRPO eval_step body for the NNX path. 
No state update.""" + state = nnx.merge(model_graphdef, state) + loss, aux = grpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + metrics = { + "scalar": { + "evaluation/loss": loss, + "evaluation/total_loss": aux.total_loss, + "evaluation/total_weights": aux.total_weights, + "evaluation/moe_lb_loss": aux.moe_lb_loss, + }, + } + return metrics + + +def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): + """Single GRPO training step. + + Computes the GRPO loss, gradients, and applies them to the policy. The + reference model is held constant. Linen and NNX paths split below; on + NNX, `model` is a GraphDef and `state` is a flat `nnx.State` of a + `TrainStateNNX` (with `model`, `optimizer`, and `reference_model`). + + Returns `(new_state, metrics)`. """ + if config.pure_nnx: + return _train_step_nnx(model, config, state_mesh_shardings, state, data) + state, reference_params = _split_grpo_state(state) state_mesh_shardings, reference_params_sharding = _split_grpo_state(state_mesh_shardings) extra_grpo_args = [reference_params] @@ -473,6 +633,8 @@ def eval_step(model, config, state, data, dropout_rng): Returns: A dictionary of evaluation metrics. """ + if config.pure_nnx: + return _eval_step_nnx(model, config, state, data) reference_params, extra_grpo_args, _loss_fn = [], [], grpo_loss_fn state, reference_params = _split_grpo_state(state) @@ -542,27 +704,48 @@ def setup_train_loop( - eval_data_iterator: The iterator for the evaluation dataset (or None). - state: The initialized training state. """ + if config.pure_nnx != config_inference.pure_nnx: + raise ValueError( + f"config.pure_nnx ({config.pure_nnx}) and config_inference.pure_nnx " f"({config_inference.pure_nnx}) must agree." + ) with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] + init_rng = jax.random.PRNGKey(config.init_weights_seed) + if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") + training_mesh = maxtext_utils.get_mesh_from_config(config, devices=training_devices) + training_rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=init_rng) + model = mt.from_config(config, devices=training_devices, mesh=training_mesh, rngs=training_rngs) else: model = mt.from_config(config, devices=training_devices) mesh = model.mesh + max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] if config_inference.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") + inference_mesh_obj = maxtext_utils.get_mesh_from_config(config_inference, devices=inference_devices) + inference_rngs = maxtext_utils_nnx.create_nnx_rngs(config_inference, rng_key=init_rng) + inference_model = mt.from_config( + config_inference, devices=inference_devices, mesh=inference_mesh_obj, rngs=inference_rngs + ) else: inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh - init_rng = jax.random.PRNGKey(config.init_weights_seed) + learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) + if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh, devices=training_devices) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + # Reference uses the same init seed so it starts identical to the policy. + reference_model = _create_model_partial() + return train_state_nnx.TrainStateNNX(nnx_model, optimizer, reference_model=reference_model) + else: init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) @@ -573,10 +756,15 @@ def setup_train_loop( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) - # create inference_state_mesh_shardings from inference_mesh if config_inference.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + _create_inference_partial, _ = model_creation_utils.create_nnx_abstract_model( + config_inference, inference_mesh, devices=inference_devices + ) + + def init_inference_state_fn(): + inference_nnx_model = _create_inference_partial() + return train_state_nnx.TrainStateNNX(inference_nnx_model, None) + else: init_inference_state_fn = functools.partial( maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng @@ -586,7 +774,11 @@ def setup_train_loop( )[2] if not config.using_pipeline_parallelism: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + if config.pure_nnx: + _, params_for_check, _ = nnx.split(state.model, nnx.Param, ...) + sharding.assert_params_sufficiently_sharded(params_for_check, mesh, config.sharding_tolerance) + else: + sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) return ( init_rng, @@ -699,10 +891,15 @@ def train_loop(config, config_inference, recorder, state=None): token=config.hf_access_token, ) - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_grpo_state(state, reference_params) - state_mesh_shardings = _merge_grpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if config.pure_nnx: + # `reference_model` is set up by init_state_fn as a sibling field — nothing to merge. + if not hasattr(state, "reference_model"): + raise RuntimeError("NNX GRPO state is missing reference_model; check setup_train_loop.") + else: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_grpo_state(state, reference_params) + state_mesh_shardings = _merge_grpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator @@ -710,6 +907,9 @@ def train_loop(config, config_inference, recorder, state=None): data_sharding = sharding.get_input_data_sharding(config, mesh) + # Lazy import: pulls in maxengine and jetstream stubs. 
+ from maxtext.inference import offline_engine # pylint: disable=import-outside-toplevel + inference_engine = offline_engine.OfflineEngine( config=config_inference, mesh=inference_mesh, @@ -724,7 +924,11 @@ def train_loop(config, config_inference, recorder, state=None): metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params["params"]) + if config.pure_nnx: + _, _params_for_metrics, _ = nnx.split(state.model, nnx.Param, ...) + metric_logger.write_setup_info_to_tensorboard(_params_for_metrics) + else: + metric_logger.write_setup_info_to_tensorboard(state.params["params"]) def generation_worker_fn( worker_inference_engine, @@ -848,21 +1052,32 @@ def generation_worker_fn( state, metrics = p_train_step(state, example_batch, train_rng) with jax.profiler.StepTraceAnnotation("transfer data", step_num=step): if step != 0 and step % config.inference_rollouts == 0: - grpo_utils.pathways_reshard( - config_inference, - inference_engine, - {"params": state.params["params"]}, - {"params": state_mesh_shardings.params["params"]}, - mesh, - {"params": inference_state_mesh_shardings.params["params"]}, - ) + if config.pure_nnx: + grpo_utils.pathways_reshard_nnx( + config_inference, + inference_engine, + state.model, + state_mesh_shardings.model, + inference_state_mesh_shardings.model, + ) + else: + grpo_utils.pathways_reshard( + config_inference, + inference_engine, + {"params": state.params["params"]}, + {"params": state_mesh_shardings.params["params"]}, + mesh, + {"params": inference_state_mesh_shardings.params["params"]}, + ) with data_buffer_lock: data_buffer.clear() step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() - state_to_save = _split_grpo_state(state)[0] + # Linen embeds reference in `state.params` and strips it for save; NNX + # holds it as a sibling field on TrainStateNNX so the whole state goes. + state_to_save = state if config.pure_nnx else _split_grpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) if config.dump_hlo and step == start_step: @@ -900,7 +1115,7 @@ def generation_worker_fn( metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = _split_grpo_state(state)[0] + state_to_save = state if config.pure_nnx else _split_grpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) elif checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/experimental/rl/grpo_utils.py b/src/maxtext/experimental/rl/grpo_utils.py index 352a2b3b8d..34d437867e 100644 --- a/src/maxtext/experimental/rl/grpo_utils.py +++ b/src/maxtext/experimental/rl/grpo_utils.py @@ -21,8 +21,9 @@ import jaxtyping from typing import Any, Callable +from flax import nnx + from maxtext.common.common_types import DecoderBlockType -from maxtext.inference.offline_engine import InputData from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -112,6 +113,48 @@ def compute_log_probs( return token_log_probs, intermediate_outputs +def compute_log_probs_nnx( + model, + inputs, + inputs_position, + inputs_segmentation, + completion_segmentation, + config, + is_train=False, +): + """`compute_log_probs` for the NNX path. 
+ + `model` is an `nnx.Module` (carries its own params + RNG state), so there's + no `params` arg. Intermediates are pulled off the model after the forward + via `nnx.state(model, nnx.Intermediate).to_pure_dict()`. + """ + logits = model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=(config.enable_dropout if is_train else False), + ) + intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() + logits = logits / config.decode_sampling_temperature + + targets = inputs[:, 1:] + shifted_completion_segmentation = jax.lax.dynamic_slice( + completion_segmentation, (0, 1), (completion_segmentation.shape[0], completion_segmentation.shape[1] - 1) + ) + shifted_completion_segmentation = jnp.pad( + shifted_completion_segmentation, ((0, 0), (0, 1)), mode="constant", constant_values=0 + ) + mask = shifted_completion_segmentation[..., None] + mask = jnp.broadcast_to(mask, logits.shape) + masked_logits = jnp.where(mask, logits, -jnp.inf) + log_probs = jax.nn.log_softmax(masked_logits, axis=-1) + log_probs = jnp.where(mask, log_probs, -0.0) + log_probs = log_probs[:, :-1, :] + token_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0] + token_log_probs = token_log_probs * shifted_completion_segmentation[:, :-1] + return token_log_probs, intermediate_outputs + + def generate_offline_completions(config, tokenizer_model, inference_engine, data): """Generates completions for a batch of prompts using an offline engine. @@ -125,6 +168,10 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data The input `data` dictionary updated with the generated completions, segmentations, positions, and log-probabilities. """ + # Lazy import: pulls in maxengine and jetstream stubs, which we only want to + # touch when this function is actually called (i.e. during a real GRPO run). + from maxtext.inference.offline_engine import InputData # pylint: disable=import-outside-toplevel + data[config.train_data_columns] = np.asarray( jnp.repeat(data[config.train_data_columns], config.num_generations, axis=0) ) @@ -175,6 +222,30 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data return data +def pathways_reshard_nnx( + config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model +): + """`pathways_reshard` for the NNX path. + + Reshard the policy params onto the inference mesh and push them into the + inference engine. Requires `scan_layers=True` (no NNX-aware unscan helper yet). + """ + if not config.scan_layers: + raise NotImplementedError( + "GRPO + pure_nnx + scan_layers=False not supported yet. " "Use scan_layers=True or pure_nnx=False." + ) + _, policy_params, _ = nnx.split(policy_state_model, nnx.Param, ...) + _, source_param_shardings, _ = nnx.split(source_shardings_model, nnx.Param, ...) + _, dest_param_shardings, _ = nnx.split(destination_shardings_model, nnx.Param, ...) + del source_param_shardings # already encoded on policy_params + with ( + jax.transfer_guard_device_to_host("disallow_explicit"), + jax.transfer_guard_host_to_device("disallow_explicit"), + ): + resharded_params = reshard_pytree(policy_params, dest_param_shardings) + inference_engine.update_params(resharded_params) + + def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings): """Reshards model parameters from training to inference sharding. 
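The functional-gradient pattern this change leans on for NNX modules (see `diff_wrapper` in `_train_step_nnx` earlier in this diff, and the `nnx.split(..., nnx.Param, ...)` calls in `pathways_reshard_nnx` above) can be hard to read inside diff context. Below is a minimal, self-contained sketch of that split / merge / value_and_grad round trip; `SimpleModel` and the squared-logits loss are illustrative stand-ins, not MaxText code, and the PR additionally passes `copy=True` to `nnx.merge` to avoid reusing Variable objects across traces.

import jax
import jax.numpy as jnp
from flax import nnx


class SimpleModel(nnx.Module):
  """Stand-in for the policy nnx.Module."""

  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 2, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)


model = SimpleModel(rngs=nnx.Rngs(0))
# Split into graphdef + differentiable Params + everything else (RNG state, etc.).
graphdef, params, rest = nnx.split(model, nnx.Param, ...)


def loss_fn(params, rest, x):
  # Rebuild a module from the pieces so the forward pass is a pure function of params.
  local_model = nnx.merge(graphdef, params, rest)
  logits = local_model(x)
  return jnp.mean(logits**2)


x = jnp.ones((1, 4))
loss, grads = jax.value_and_grad(loss_fn, argnums=0)(params, rest, x)
# `grads` mirrors the Param state pytree and can be fed to an optimizer or to
# TrainStateNNX.apply_gradients.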
diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 5bb0a87b5a..d270a1e23f 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -32,6 +32,7 @@ from jax.experimental.layout import DeviceLocalLayout as DLL # type: ignore from flax import linen as nn +from flax import nnx from flax import struct from flax.linen import partitioning as nn_partitioning import flax @@ -44,8 +45,10 @@ from maxtext.inference.page_manager import PageManager, PageState from maxtext.multimodal import processor as mm_processor from maxtext.utils import lora_utils +from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from maxtext.common.gcloud_stub import jetstream, is_decoupled from maxtext.common.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE @@ -111,12 +114,33 @@ def __init__(self, config: Any, devices: Any | None = None): devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and Optimizer definition + # Model and Optimizer definition. quant = quantizations.configure_quantization(config) if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") + # We need both PREFILL and AR abstract models because the cache vars inherit + # CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and + # bulk_insert searches for the substring "cache_batch" in the AR-mode names. + # Calling nnx.eval_shape directly (instead of create_nnx_abstract_model) avoids + # the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm". + _create_model = model_creation_utils.get_nnx_create_model_fn(config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL) + _create_model_ar = model_creation_utils.get_nnx_create_model_fn( + config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE + ) + with nn_partitioning.axis_rules(config.logical_axis_rules): + abstract_model = nnx.eval_shape(_create_model) + abstract_model_ar = nnx.eval_shape(_create_model_ar) + self.model = abstract_model + self.model_ar = abstract_model_ar + # 3-way split so JIT bodies can pass (params, cache, rest) separately to + # nnx.merge. `rest` (RNG state etc.) is materialized in load_params. + graphdef, _, _, _ = nnx.split(abstract_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._create_model_fn = _create_model + self._nnx_rest_state = None else: self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.graphdef = None + self._create_model_fn = None self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -142,6 +166,65 @@ def print_stats(self, label: str): max_utils.print_mem_stats(label) max_utils.print_cpu_ram_stats(label) + # NNX cache adapter: bulk_insert / _insert_jit / _maybe_stack_* switch on + # path[-1].key (e.g. "cached_prefill_key"). NNX state would expose ".value" at + # that position, so we convert NNX state <-> plain dict at the JIT boundary + # via to_pure_dict / replace_by_pure_dict. The cache helpers stay unchanged. 
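  # Illustrative aside (not part of the diff): the state <-> pure-dict roundtrip
  # described above, in isolation. `engine_model` stands in for the abstract NNX
  # model; the calls mirror the ones used in the adapter methods below.
  #
  #   _, cache_state, _ = nnx.split(engine_model, nnx.Cache, ...)
  #   cache_dict = cache_state.to_pure_dict()            # plain nested dict of arrays;
  #                                                       # bulk_insert keys like
  #                                                       # "cached_prefill_key" sit at
  #                                                       # path[-1], as with Linen
  #   nnx.replace_by_pure_dict(cache_state, cache_dict)   # write values back into State
  #   model = nnx.merge(graphdef, params, cache_state, rest_state)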
+ + def _nnx_cache_state_template(self, mode: str = MODEL_MODE_PREFILL) -> Any: + """Empty nnx.State template for the model's nnx.Cache vars (PREFILL=batch 1, AR=batch N).""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + return cache_state + + def _nnx_init_cache_dict(self, mode: str = MODEL_MODE_PREFILL) -> dict: + """Zero-filled pure-dict cache matching the abstract NNX model.""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + cache_dict = cache_state.to_pure_dict() + return jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict) + + def _nnx_run_model( + self, + params, + cache_dict, + decoder_input_tokens, + decoder_positions, + *, + decoder_segment_ids=None, + enable_dropout=False, + model_mode, + previous_chunk=None, + true_length=None, + slot=None, + page_state=None, + encoder_images=None, + encoder_image_masks=None, + encoder_audios=None, + ): + """NNX equivalent of `model.apply(..., mutable=["cache"])`. Returns (logits, new_cache_dict).""" + cache_state = self._nnx_cache_state_template(mode=model_mode) + nnx.replace_by_pure_dict(cache_state, cache_dict) + # copy=True avoids reusing Variable objects across traces (TraceContextError), + # mirroring the workaround in train.py's diff_wrapper. + model = nnx.merge(self.graphdef, params, cache_state, self._nnx_rest_state, copy=True) + logits = model( + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + encoder_images=encoder_images, + encoder_image_masks=encoder_image_masks, + encoder_audios=encoder_audios, + enable_dropout=enable_dropout, + model_mode=model_mode, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_cache = nnx.state(model, nnx.Cache).to_pure_dict() + return logits, new_cache + def generate_aot( self, params: Params, decode_state: DecodeState, rng: PRNGKeyType | None = None ): # returns (new_decode_state, result_tokens) @@ -225,6 +308,9 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + return self._load_params_nnx(params=params, rng=rng) + if self.model.quant and self.config.checkpoint_is_quantized: print("Loading from the quantized checkpoint...") self.model.quant.quant_mode = quantizations.get_quant_mode("serve") @@ -232,11 +318,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( self.config, self._mesh, init_state_fn, False ) @@ -245,11 +327,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - if self.config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( @@ -292,10 +370,78 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar return params + def _load_params_nnx(self, params, rng): + """NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings.""" + if self.model.quant is not None: + raise NotImplementedError("pure_nnx + quantization not yet supported. Use pure_nnx=False.") + + if params: + print("Resharding given NNX params") + _, params_abs, _ = nnx.split(self.model, nnx.Param, ...) + target_shardings = jax.tree.map( + lambda x: x.sharding if hasattr(x, "sharding") else None, + params_abs, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + params_state = jax.device_put(params, target_shardings) + # Build a concrete model once to capture a real `rest` (RNG vars) for nnx.merge. + # Wasteful but simple — the from_pretrained branch below avoids this. + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + concrete_model = self._create_model_fn() + graphdef, _, _, rest_state = nnx.split(concrete_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del concrete_model + else: + max_logging.log("Loading NNX params via from_pretrained") + with self._mesh: + nnx_model = model_creation_utils.from_pretrained( + self.config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE + ) + # Refresh graphdef from the concrete loaded model so subsequent merges line up. + graphdef, params_state, _, rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del nnx_model + + self.abstract_params = jax.tree.map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + if isinstance(x, jax.Array) + else None, + params_state, + ) + + self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations_nnx( + self.model, self.config, self._mesh + ) + self.prefill_kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.prefill_kv_cache_annotations, + ) + if self.config.stack_prefill_result_cache: + # With scan_layers=True the NNX cache leaves are already stacked on axis 0, + # so the engine's manual-stack helper (which assumes an unstacked Linen tree) + # doesn't apply. Wiring this up cleanly is a Phase-2 follow-up. + raise NotImplementedError("pure_nnx + stack_prefill_result_cache=True not yet supported.") + # AR-mode abstract model so axis names use CACHE_BATCH (not CACHE_BATCH_PREFILL); + # bulk_insert / _insert_jit search for "cache_batch" in the per-leaf logical axes. + self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations_nnx(self.model_ar, self.config, self._mesh) + self.kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.kv_cache_annotations, + ) + # state_mesh_annotations is unused on the NNX path; callers reading it + # (e.g. set_engine_vars_from_base_engine) need to be NNX-aware first. 
+ self.state_mesh_annotations = None + + self.print_stats("After load_params (NNX)") + return params_state + def load_single_adapter(self, adapter_path): - """ - Load Single adapter from adapter_path. - Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`. + """Load a single LoRA adapter from `adapter_path`. + + Expects `adapter_config.json` plus adapter weights at `/0/items`. + The returned `params` shape matches `self.abstract_params` (NNX or Linen). """ adapter_config_path = os.path.join(adapter_path, "adapter_config.json") adapter_weights_path = os.path.join(adapter_path, "0", "items") @@ -319,19 +465,27 @@ def apply_adapter(self, base_params, adapter_config, adapter_params): lora_rank = int(adapter_config["r"]) lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank - lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor) + if self.config.pure_nnx: + lora_utils.apply_lora_on_base_params_nnx(base_params, adapter_params, lora_scale_factor) + else: + lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor) def unapply_adapter(self, base_params, adapter_config, adapter_params): """Unapply the adapter params from the merged params to get back the base params.""" lora_rank = int(adapter_config["r"]) lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank - lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor) + if self.config.pure_nnx: + lora_utils.unapply_lora_from_base_params_nnx(base_params, adapter_params, lora_scale_factor) + else: + lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor) def quantize_params(self, state, rng: PRNGKeyType | None = None): """Forward pass to quantize decode params.""" if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + quantize_params not yet supported.") self.model.quant.quant_mode = quantizations.get_quant_mode("convert") @@ -486,7 +640,10 @@ def _prefill_jit( if existing_prefix is not None: if not self.use_chunked_prefill: raise ValueError("Using chunked prefill is needed for existing_prefix.") - input_params = params | {"cache": existing_prefix.cache} + # NNX threads existing_prefix.cache via the nnx_cache local below; only + # the Linen path merges cache into input_params (params is a dict there). + if not self.config.pure_nnx: + input_params = params | {"cache": existing_prefix.cache} start_position = existing_prefix.common_prefix_tokens.shape[0] # TODO(yuyanpeng): rename previous_chunk previous_chunk = jnp.expand_dims(existing_prefix.common_prefix_tokens, 0) @@ -518,24 +675,48 @@ def _prefill_jit( sequence_indicator = jnp.expand_dims(one_d_output, 0) rng, new_rng = jax.random.split(rng) - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - flat_logits, new_vars = self.model.apply( - input_params, - input_tokens, - positions, - encoder_images=images, - encoder_image_masks=image_masks, - encoder_audios=audio_values, - decoder_segment_ids=sequence_indicator, - enable_dropout=False, - model_mode=MODEL_MODE_PREFILL, - rngs={"params": new_rng}, - mutable=["cache"], - previous_chunk=previous_chunk, - true_length=true_length, - slot=slot, - page_state=page_state, + if self.config.pure_nnx: + # Prefill always operates on batch=1 (one padded prompt at a time). 
+ nnx_cache = ( + existing_prefix.cache if existing_prefix is not None else self._nnx_init_cache_dict(mode=MODEL_MODE_PREFILL) ) + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_cache_dict = self._nnx_run_model( + params=input_params, + cache_dict=nnx_cache, + decoder_input_tokens=input_tokens, + decoder_positions=positions, + decoder_segment_ids=sequence_indicator, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_vars = self.model.apply( + input_params, + input_tokens, + positions, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + decoder_segment_ids=sequence_indicator, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + rngs={"params": new_rng}, + mutable=["cache"], + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) if return_prompt_logp: prompt_logp = inference_utils.prompt_logprobs_from_prefill(flat_logits, input_tokens, true_length) else: @@ -744,6 +925,9 @@ def _prefill_multisampling_jit( prefilling stage. The number of tokens is specified by num_samples. """ + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_multisampling not yet supported. Use pure_nnx=False.") + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0) @@ -869,6 +1053,9 @@ def prefill_concat( if existing_prefix: raise ValueError("We don't know what to do with existing_prefix") + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_concat not yet supported. 
Use pure_nnx=False.") + if rng is None: rng = jax.random.PRNGKey(0) input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] @@ -1038,17 +1225,30 @@ def _generate_jit( previous_token = decode_state["tokens"] rng, new_rng = jax.random.split(rng) # run one step generation - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - out_logits, new_vars = self.model.apply( - params | {"cache": decode_state["cache"]}, - previous_token, - decode_state["next_pos"], - enable_dropout=False, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - rngs={"params": new_rng}, - mutable=["cache"], - page_state=page_state, - ) + if self.config.pure_nnx: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_cache_dict = self._nnx_run_model( + params=params, + cache_dict=decode_state["cache"], + decoder_input_tokens=previous_token, + decoder_positions=decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_vars = self.model.apply( + params | {"cache": decode_state["cache"]}, + previous_token, + decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + rngs={"params": new_rng}, + mutable=["cache"], + page_state=page_state, + ) out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding) new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings) # sampling tokens @@ -1606,6 +1806,9 @@ def init_decode_state( if self.config.attention == "paged" and self.page_manager is not None: page_state = self.page_manager.get_initial_page_state() # pytype: disable=attribute-error + if self.config.pure_nnx: + return self._init_decode_state_nnx(rng=rng, page_state=page_state) + # pylint: disable=unused-argument def init(abstract_params, page_state): x = jnp.ones( @@ -1699,6 +1902,51 @@ def is_lp(k): zeroed = max_utils.unbox_logicallypartioned(init_state) return zeroed + def _init_decode_state_nnx(self, rng, page_state) -> DecodeState: + """NNX equivalent of init_decode_state. Returns a decode_state dict with a pure-dict cache.""" + del rng, page_state # cache shape comes from the abstract model + batch = int(self.config.per_device_batch_size * self.mesh.size) + vocab = self.config.vocab_size + + # AR-mode cache so the batch dim matches generate's input shape. + cache_dict_abs = self._nnx_init_cache_dict(mode=MODEL_MODE_AUTOREGRESSIVE) + + @functools.partial(jax.jit, out_shardings=(self.kv_cache_shardings,)) + def _init_cache(): + return (jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict_abs),) + + (cache,) = _init_cache() + + # Per-leaf logical axes for bulk_insert's "cache_batch" lookup. Use model_ar + # so segment_id leaves carry CACHE_BATCH (under PREFILL they'd carry + # CACHE_BATCH_PREFILL, which doesn't contain the "cache_batch" substring). + _, cache_state, _ = nnx.split(self.model_ar, nnx.Cache, ...) + + def _logical_axes_for(var): + # Flax 0.12.6 renamed "sharding" to "out_sharding"; older code may still + # use "sharding_names". Try all three. 
+ meta = var.get_metadata() if hasattr(var, "get_metadata") else {} + out = meta.get("out_sharding") or meta.get("sharding") or meta.get("sharding_names") + if out is None: + return () + return (out,) if isinstance(out, str) else tuple(out) + + annotations_state = jax.tree.map( + _logical_axes_for, + cache_state, + is_leaf=lambda v: isinstance(v, nnx.Variable), + ) + self.kv_cache_annotations_named = annotations_state.to_pure_dict() + + return { + "logits": jnp.zeros((batch, 1, vocab), dtype=jnp.float32), + "cache": cache, + "next_pos": jnp.zeros((batch, 1), dtype=jnp.int32), + "generated_tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "token_logp": jnp.zeros((batch, 1), dtype=jnp.float32), + } + @property def max_concurrent_decodes(self) -> int: """Free slots.""" diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 262eb62277..6c6d12419f 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -545,8 +545,14 @@ def pure_layer_fn(state_in, y_in): out = merged_layer(y_in, **kwargs) return out, nnx.state(merged_layer) - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - out, new_state = checkpointed_fn(state, y) + # Linen FP8 ops keep amax_history in mutable Linen scope; jax.checkpoint + # re-traces and hits UnexpectedTracerError. Skip remat for FP8. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + out, new_state = pure_layer_fn(state, y) + else: + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) nnx.update(layer, new_state) return out @@ -667,7 +673,22 @@ def layer_fn(carry, scanned_vars): params = nnx_ensure_scan_leading_axis(params, length) state = nnx_ensure_scan_leading_axis(state, length) - final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) + # Linen FP8 ops keep amax_history in mutable Linen scope; jax.lax.scan + # leaks the tracer and hits UnexpectedTracerError. Use a Python for-loop + # for FP8 instead. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + carry = x_in + per_layer_states = [] + for i in range(length): + current_params = jax.tree.map(lambda x, i=i: x[i], params) + current_state = jax.tree.map(lambda x, i=i: x[i], state) + carry, new_state_i = layer_fn(carry, (current_params, current_state)) + per_layer_states.append(new_state_i) + final_carry = carry + scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states) + else: + final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) returned_kv_stacked = None if scan_axis != 0: diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index 7bb532ae7f..d29edd6e8e 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -26,6 +26,7 @@ from flax.core import FrozenDict from flax.core import meta from flax.nnx import graph +from flax.nnx import tracers as nnx_tracers from flax.nnx import variablelib from flax.nnx.bridge import module as bdg_module from flax.nnx.module import Module @@ -167,6 +168,31 @@ def current_linen_module() -> linen.Module | None: return None +def is_linen_initializing() -> bool: + """Returns True if currently inside a Linen ``init()`` call. 
+ + Used by NNX pipeline modules to short-circuit the scan during init, + where only the output shape/dtype is needed. + """ + module = current_linen_module() + if module is not None and hasattr(module, "is_initializing") and callable(module.is_initializing): + return module.is_initializing() + return False + + +def _refresh_variable_trace_state(module: Module) -> None: + """Resets stale ``_trace_state`` on Variables to unblock downstream ``nnx.split``. + + ``nnx.update`` called with JAX tracer values uses ``_unsafe_bypass_check=True``, + which leaves Variables with a stale ``_trace_state`` from the outer Python + context and breaks ``nnx.split`` with "Cannot extract graph node from different + trace level". Resets ``_trace_state`` on any Variable whose ``_can_update`` is False. + """ + for _, v in nnx.graph.iter_graph(module): + if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access + object.__setattr__(v, "_trace_state", nnx_tracers.TraceState()) + + class ToNNX(Module): """A wrapper to turn any Linen module into an NNX module. @@ -476,6 +502,7 @@ def maybe_unbox(x): warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") nnx.update(module, new_state) + _refresh_variable_trace_state(module) _fix_for_qwix_quantization(module) method_fn = _get_module_method(module, nnx_method) diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index bf91262bf1..645eb05e09 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> return y_flat.reshape(input_shape) -def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): +def Qwen3NextRMSNorm( + num_features: int, + eps: float = 1e-6, + dtype: DType = None, + weight_dtype: DType = None, + shard_mode=None, + kernel_axes=None, + parameter_memory_host_offload=None, + *, + rngs: nnx.Rngs, +): """ Used for input and post attention layernorms in Qwen3NextDecoderLayer. diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py index 9ef0e6dffd..3f9ee1ce29 100644 --- a/src/maxtext/layers/train_state_nnx.py +++ b/src/maxtext/layers/train_state_nnx.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" The NNX Unified TrainState. """ +"""The NNX Unified TrainState.""" from typing import Any @@ -25,20 +25,34 @@ class TrainStateNNX(nnx.Module): This replaces Linen's TrainState for checkpointing. Linen TrainState pytree: - {“params”: {...}, “opt_state”: {}...} + {"params": {...}, "opt_state": {}...} TrainStateNNX state pytree: - {“model”: {...}, “optimizer”: {“opt_state”: {...}} + {"model": {...}, "optimizer": {"opt_state": {...}}} + + For DPO (Direct Preference Optimization), an optional `reference_model` + carries a frozen copy of the same architecture used to compute reference + log-probabilities. Only `model` is updated by `apply_gradients`; the + reference is held alongside so it is sharded, jit-traced, and checkpointed + with the rest of the train state. 
""" - def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): + def __init__( + self, + model: nnx.Module, + optimizer: nnx.Optimizer | None, + reference_model: nnx.Module | None = None, + ): self.model = model self.optimizer = optimizer + if reference_model is not None: + self.reference_model = reference_model def apply_gradients(self, grads: Any): """ Mimics the Linen apply_gradients function. Updates the optimizer state, applies updates to parameters, - and increments the step counter. + and increments the step counter. Only updates `self.model`; + `self.reference_model` (if present) is left untouched. """ if self.optimizer is None: raise RuntimeError( diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index 9401d01d9f..5f4a2f3fb6 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -29,6 +29,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import moe from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations @@ -132,6 +133,8 @@ def __init__( rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + def __call__( self, inputs, @@ -189,7 +192,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index a75cefc291..6fc0e5d2f6 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -71,6 +71,7 @@ def __init__( shard_mode=config.shard_mode, kernel_axes=("norm",), epsilon=config.normalization_layer_epsilon, + parameter_memory_host_offload=config.parameter_memory_host_offload, rngs=rngs, ) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 1b0d4b4cd3..61d968de3d 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -398,6 +398,15 @@ def no_op(self, *args, **kwargs): """A no-op method to allow the model to be used in a lazy context.""" return + def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): + """Computes logits from hidden states; used by vocabulary tiling.""" + return self.decoder.apply_output_head( + shared_embedding=self.token_embedder, + y=hidden_states, + deterministic=deterministic, + model_mode=model_mode, + ) + def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32): """Initializes the KV cache for the Transformer. 
diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py index 09c5b4e079..b743e8d4b7 100644 --- a/src/maxtext/models/olmo3.py +++ b/src/maxtext/models/olmo3.py @@ -30,6 +30,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations from maxtext.layers.attentions import Attention @@ -142,6 +143,7 @@ def __init__( model_mode=model_mode, rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) def __call__( self, @@ -202,7 +204,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/trainers/diloco/diloco.py b/src/maxtext/trainers/diloco/diloco.py index a9ef64631a..39d84a89dc 100644 --- a/src/maxtext/trainers/diloco/diloco.py +++ b/src/maxtext/trainers/diloco/diloco.py @@ -26,6 +26,7 @@ from typing import Any, Callable import drjax +from flax import nnx from flax import struct from flax.training import train_state import jax @@ -153,7 +154,15 @@ def add_diloco_dim(x): momentum=config.diloco_outer_momentum, nesterov=True, ) - outer_opt_state = jax.eval_shape(outer_optimizer.init, abstract_state.params) + # For NNX, model params (Param variables only) live under abstract_state.model; + # for Linen under abstract_state.params. + if config.pure_nnx: + model_params = abstract_state.model.filter(nnx.Param) + model_params_sharding = state_mesh_shardings.model.filter(nnx.Param) + else: + model_params = abstract_state.params + model_params_sharding = state_mesh_shardings.params + outer_opt_state = jax.eval_shape(outer_optimizer.init, model_params) # Create abstract step abstract_step = jax.ShapeDtypeStruct((), jnp.int32) @@ -161,7 +170,7 @@ def add_diloco_dim(x): # Build abstract DiLoCo state diloco_state = DiLoCoTrainState( inner_state=inner_state, - params=abstract_state.params, + params=model_params, outer_opt_state=outer_opt_state, step=abstract_step, ) @@ -171,12 +180,12 @@ def add_diloco_dim(x): # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState()) # We shard the momentum trace the same way as the parameters. outer_opt_state_sharding = ( - optax.TraceState(trace=state_mesh_shardings.params), + optax.TraceState(trace=model_params_sharding), optax.EmptyState(), ) diloco_state_shardings = DiLoCoTrainState( inner_state=inner_state_shardings, - params=state_mesh_shardings.params, + params=model_params_sharding, outer_opt_state=outer_opt_state_sharding, step=None, ) @@ -205,11 +214,15 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: # mesh automatically when jax.set_mesh is used. inner_state = drjax.broadcast(state, mesh=mesh) # Outer state retains a single copy of the model parameters and optimizer state. - outer_params = state.params + # For NNX, model params (Param variables only) live under state.model; + # for Linen under state.params. 
+ outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params outer_opt_state = outer_optimizer.init(outer_params) outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state) + # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step. + step = state.optimizer.step if config.pure_nnx else state.step return ( - DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=state.step), + DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step), outer_opt_state_sharding, ) @@ -244,7 +257,11 @@ def synchronize(state): # Calculate the delta between the current replica's state and the global # state (since last synchronization). broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) - model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) + # For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params. + inner_model_params = ( + nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params + ) + model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params) # Treat the average delta as the outer optimizer's gradient and apply to # the global (outer) model params. averaged_pseudo_grad = drjax.reduce_mean(model_delta) @@ -253,7 +270,27 @@ def synchronize(state): # Replace inner model params with the new global model params. # NOTE: inner optimizer state is retained despite the change in parameters, # see section 6.1 in https://arxiv.org/pdf/2311.08105. - new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh) + if config.pure_nnx: + # For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state). + def replace_nnx_model_params(s, new_params): + non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param)) + new_model = nnx.merge_state(non_param_model, new_params) + # Build result via __setitem__ so nested States are stored as plain dicts + # internally, matching the pytree structure produced by nnx.state(). + # (Passing State objects via the constructor dict literal stores them + # as-is, causing jax.lax.cond to see mismatched pytree structures.) + result = type(s)({}) + result["model"] = new_model + result["optimizer"] = s["optimizer"] + return result + + new_inner_state = drjax.map_fn( + lambda s: replace_nnx_model_params(s, new_outer_params), + state.inner_state, + mesh=mesh, + ) + else: + new_inner_state = drjax.map_fn(lambda s: s.replace(params=new_outer_params), state.inner_state, mesh=mesh) return state.replace( params=new_outer_params, outer_opt_state=new_opt_state, @@ -271,14 +308,16 @@ def diloco_train_step(state, batch, prng): broadcast_rng = drjax.broadcast(prng, mesh=mesh) inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh) avg_metrics = typed_reduce_mean(metrics) + # For NNX, the step counter lives at inner_state.optimizer.step; for Linen at inner_state.step. + new_step = inner_state.optimizer.step[0] if config.pure_nnx else inner_state.step[0] state = state.replace( inner_state=inner_state, - step=inner_state.step[0], + step=new_step, ) # Either synchronize the model, or no-op, depending on whether the current # step falls on the synchronization period. 
state = jax.lax.cond( - inner_state.step[0] % config.diloco_sync_period == 0, + new_step % config.diloco_sync_period == 0, synchronize, lambda x: x, # no-op state, diff --git a/src/maxtext/trainers/post_train/dpo/dpo_utils.py b/src/maxtext/trainers/post_train/dpo/dpo_utils.py index eeda1c1a7f..fd5faa5c9c 100644 --- a/src/maxtext/trainers/post_train/dpo/dpo_utils.py +++ b/src/maxtext/trainers/post_train/dpo/dpo_utils.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp +from flax import nnx + from maxtext.utils import maxtext_utils @@ -148,6 +150,8 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility } return loss, aux @@ -155,3 +159,138 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t def _merge_dpo_state(state, reference_params): """Merge reference parameters back into DPO state.""" return state.replace(params=dict(state.params, reference_params=reference_params)) + + +# NNX DPO has no split/merge counterpart: the Linen path overlays +# `reference_params` inside `state.params`, so it must be peeled off and +# reattached around `apply_gradients`. The NNX path holds the reference as a +# sibling field `TrainStateNNX.reference_model`; `apply_gradients` already +# only touches `self.model`, so no split/merge is needed. + + +def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """NNX DPO loss_fn for both train and eval. + + Signature mirrors the Linen `dpo_loss_fn` so it slots into the same + dispatcher in `gradient_accumulation_loss_and_grad`: + `(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)` + + Differences from the Linen `dpo_loss_fn`: + * `policy_model` is an `nnx.Module` (carries its own params + RNG state). + * `dropout_rng` and `params` are unused for NNX (kept positional for + signature parity; NNX models manage these internally). + * The 6th arg (the `extra_dpo_args[0]`) is a frozen reference + `nnx.Module`, not a `reference_params` pytree. + * Reference forward is wrapped in `jax.lax.stop_gradient`; combined with + `nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows + to the reference's `nnx.Param` leaves. + + Args: + policy_model: Policy `nnx.Module` (the model being trained). + config: Config of parameters. + data: Batch of preference data with `chosen` / `rejected` fields. + dropout_rng: Unused for NNX (kept for signature parity with Linen). + params: Unused for NNX (kept for signature parity with Linen). + reference_model: Frozen reference `nnx.Module` for DPO logratio computation. + is_train: True for train_step and False for eval_step. + + Returns: + loss: DPO preference loss + MoE load balance loss (if applicable). + aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss, + total_weights, moe_lb_loss, reward_accuracy. 
+ """ + del dropout_rng, params # unused for NNX + # decimate proportion of data when per_device_batch_size<1 + if is_train: + for k, v in data.items(): + data[k] = v[: config.micro_batch_size_to_train_on, :] + + # for DPO we don't support packed sequences (they shouldn't be present in the first place) + data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32) + data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32) + data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1) + data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1) + + # concatenated policy/reference forward pass + inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0) + inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0) + inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0) + + logits = policy_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=config.enable_dropout if is_train else False, + ) + intermediate_outputs = nnx.state(policy_model, nnx.Intermediate).to_pure_dict() + + ref_logits = reference_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=False, + ) + ref_logits = jax.lax.stop_gradient(ref_logits) + + # extract token ids, segmentation and logits for chosen and rejected sequences + chosen_ids = data["chosen"][..., 1:] + rejected_ids = data["rejected"][..., 1:] + chosen_segmentation = data["chosen_segmentation"][..., 1:] + rejected_segmentation = data["rejected_segmentation"][..., 1:] + n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab] + chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :] + chosen_ref_logits, rejected_ref_logits = ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :] + + # common subsequence and padding mask + common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S] + valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S] + + # compute logratios from the sequence-reduced observed token log-probability + chosen_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_logratios = chosen_logps - chosen_ref_logps # [B] + + rejected_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_logratios = rejected_logps - rejected_ref_logps # [B] + + # DPO loss from chosen and rejected logratios + LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, 
config.dpo_beta + logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B] + losses = ( # [B] + -jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING) + - jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING + ) + total_loss, total_weights = jnp.mean(losses), losses.shape[0] + loss = total_loss + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios) + aux = { + "intermediate_outputs": intermediate_outputs, + "xent_sum": 0.0, # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility + "dpo_loss": total_loss, # pure preference loss before MoE lb, analogous to lm_loss in pre-training + "total_weights": total_weights, + "moe_lb_loss": moe_lb_loss, + "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility + } + return loss, aux diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 951d10585d..80f01a11aa 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -61,7 +61,7 @@ from maxtext.common.gcloud_stub import vertex_tensorboard_modules from maxtext.common import metric_logger from maxtext.common.metric_logger import record_activation_metrics -from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn +from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn, dpo_loss_fn_nnx from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging @@ -72,7 +72,7 @@ from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad -from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss +from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss, vocab_tiling_nnx_loss _diag_modules = _cloud_diag() diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules @@ -203,9 +203,10 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr intermediate_outputs = intermediates.to_pure_dict() if config.num_vocab_tiling > 1: - raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") - - if (config.use_indexer and not config.indexer_sparse_training) and is_train: + hidden_state_key = ("decoder", "hidden_states") + hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] + xent_sum, total_z_loss = vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train) + elif (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. # The main model parameters are frozen and only the indexer is trained via KL divergence. 
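A small sketch of the sow-and-retrieve pattern the vocab-tiling dispatch in `loss_fn` depends on. `TinyDecoder` and its key name are illustrative; the real lookup key is `("decoder", "hidden_states")` as above.

```python
# Sown values accumulate as a tuple by default, hence the trailing [0] that
# mirrors get_nested_value(intermediate_outputs, key)[0] in loss_fn.
from flax import nnx
import jax.numpy as jnp

class TinyDecoder(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.proj = nnx.Linear(8, 8, rngs=rngs)

  def __call__(self, x):
    h = self.proj(x)
    self.sow(nnx.Intermediate, "hidden_states", h)  # stash final hidden states
    return h

model = TinyDecoder(nnx.Rngs(0))
_ = model(jnp.ones((2, 8)))
intermediates = nnx.state(model, nnx.Intermediate).to_pure_dict()
hidden = intermediates["hidden_states"][0]
```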
xent_sum = 0.0 @@ -322,10 +323,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat params = state.params ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args else: - if config.use_dpo: - raise NotImplementedError("DPO for NNX modules has not been implemented.") state = nnx.merge(model, state) # reconstruct TrainStateNNX - ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] + if config.use_dpo: + # NNX DPO: reference_model is a sibling field on TrainStateNNX (set up by + # init_initial_state when config.use_dpo=True). dpo_loss_fn_nnx mirrors + # the Linen dpo_loss_fn signature, so it slots into the same dispatcher + # with reference_model passed as the single extra_dpo_args entry. + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = (dpo_loss_fn_nnx, state.model, None, None, [state.reference_model]) + else: + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] # --- Gradient computation --- if config.gradient_accumulation_steps > 1: @@ -391,9 +397,14 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ) nnx.update(state.model, curr_params) + # `ga_fn` and `ga_dpo` were set up earlier (loss_fn vs dpo_loss_fn_nnx; + # ga_dpo carries the frozen reference_model when use_dpo, else empty). + _nnx_loss_fn = ga_fn + _nnx_extra_dpo_args = ga_dpo + def diff_wrapper(param, rest, config, data): local_model = nnx.merge(model_graphdef, param, rest, copy=True) - loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + loss, aux = _nnx_loss_fn(local_model, config, data, None, None, *_nnx_extra_dpo_args, is_train=True) _, _, new_rest = nnx.split(local_model, nnx.Param, ...) return loss, (aux, new_rest) @@ -557,7 +568,9 @@ def move(path, value): if config.use_dpo: new_state = _merge_dpo_state(new_state, reference_params) return new_state, metrics - return nnx.state(new_state), metrics + # Drop Intermediates (e.g. sowed max_logits for QK-Clip) before returning; + # they're absent from state_mesh_shardings and would cause a leaf-count mismatch. 
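The `diff_wrapper` above follows the usual NNX split/merge recipe for differentiating with respect to `nnx.Param` only, while the remaining state (RNG counters, sown values) rides along and comes back updated. A self-contained sketch with a toy model and loss, both assumptions rather than MaxText code:

```python
import jax
import jax.numpy as jnp
from flax import nnx

class TinyModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 1, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

model = TinyModel(nnx.Rngs(0))
graphdef, params, rest = nnx.split(model, nnx.Param, ...)

def loss_wrt_params(params, rest, x):
  local_model = nnx.merge(graphdef, params, rest)   # rebuild from the split pieces
  loss = (local_model(x) ** 2).mean()
  _, _, new_rest = nnx.split(local_model, nnx.Param, ...)
  return loss, new_rest

x = jnp.ones((2, 4))
(loss, new_rest), grads = jax.value_and_grad(loss_wrt_params, has_aux=True)(params, rest, x)
```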
+ return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics def eval_step(model, config, state, data, dropout_rng=None): @@ -577,7 +590,10 @@ def eval_step(model, config, state, data, dropout_rng=None): loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) else: state = nnx.merge(model, state) # reconstruct TrainStateNNX - loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) + if config.use_dpo: + loss, aux = dpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + else: + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -704,7 +720,7 @@ def train_loop(config, recorder, state=None): step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): @@ -748,7 +764,7 @@ def train_loop(config, recorder, state=None): metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) if checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 831e97b885..6ba537b94c 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -29,6 +29,7 @@ from flax import nnx from flax.linen import partitioning as nn_partitioning import jax +import jax.numpy as jnp from jax.experimental.serialize_executable import serialize from jax.experimental.topologies import get_topology_desc from jax.sharding import AxisType, Mesh @@ -91,6 +92,27 @@ def get_topology_mesh(config): return topology_mesh +def _collect_nnx_activation_shardings(create_model_fn, config, mesh): + """Runs an abstract NNX forward pass to populate `_ACTIVATION_SHARDINGS_DUMP`. + + `get_abstract_state_nnx` only traces `__init__`; activation shardings need + a forward pass to be collected. 
+ """ + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + + def _nnx_forward(): + model_instance = create_model_fn() + return model_instance( + decoder_input_tokens=jnp.ones(input_shape, dtype=jnp.int32), + decoder_positions=jnp.ones(input_shape, dtype=jnp.int32), + decoder_segment_ids=jnp.ones(input_shape, dtype=jnp.int32), + enable_dropout=False, + ) + + with nn_partitioning.axis_rules(config.logical_axis_rules): + jax.eval_shape(_nnx_forward) + + def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state @@ -128,7 +150,8 @@ def create_train_state_fn(): # For NNX, get_functional_train_with_signature expects the graphdef (static structure), # not the raw model — mirroring how the training loop does nnx.split(train_state). with nn_partitioning.axis_rules(config.logical_axis_rules): - graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh) + abs_train_state = nnx.eval_shape(init_state_fn) + graphdef, _ = nnx.split(abs_train_state) model = graphdef else: # unsharded logical annotations @@ -138,10 +161,16 @@ def create_train_state_fn(): shaped_batch = maxtext_utils.get_shaped_batch(config) if config.pure_nnx: - shaped_train_args = (abstract_state, shaped_batch, None) # NNX doesn't use dropout_rng + shaped_train_args = (abstract_state, shaped_batch) # NNX doesn't use dropout_rng else: shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} + + # Collect NNX activation shardings via an abstract forward pass (must run + # after get_abstract_state, which only traces __init__). + if config.debug_sharding and config.pure_nnx: + _collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh) + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -299,7 +328,9 @@ def main(argv: Sequence[str]) -> None: diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( config, abstract_state, state_mesh_shardings, topology_mesh ) - shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) + # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng. + shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None + shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg) # Wrap train_step with diloco train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, params_shardings) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 2fd14b87a2..ab9918c4b9 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -90,20 +90,14 @@ def slice_ith(input_layers): def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Input and output are both Linen-format (downstream uses Linen tree paths). + # Route to Linen regardless of pure_nnx. 
quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( None, config, mesh, checkpoint_manager, init_state_fn ) @@ -114,12 +108,11 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Model and Optimizer definition. + # LoRA adapters and downstream `_save_decode_checkpoint`/`_possibly_unroll_params` + # are Linen-shaped; use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index e1699647c6..cf84577dbd 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -71,10 +71,16 @@ def _maybe_shard_with_name(inputs, sharding_names): is_nnx = isinstance(model, nnx.Module) - # For more efficient DP/ZeRO-1 + GA - if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: - ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) - grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + # For more efficient DP/ZeRO-1 + GA. + # config.ici_data_parallelism may be -1 (auto-fill: resolved at mesh creation time, but + # the config field remains -1). Treat any value != 1 as "data parallelism is active". + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # jax.lax.scan traces its body with an AbstractMesh where all axis types are Auto, + # which rejects reduced/unreduced PartitionSpec in scan carry tensors (raises ValueError). + # Use plain params_shardings for ga_params and init_grad in the carry. + # The all-reduce for data parallelism is applied to raw_grads after the scan instead. + ga_params_shardings = params_shardings + grad_shardings = params_shardings else: ga_params_shardings = grad_shardings = params_shardings @@ -105,7 +111,7 @@ def accumulate_gradient(acc_grad_and_loss, data): if is_nnx: # Reconstruct the model using the fixed parameters (ga_params) # and the advancing non-parameter state (RNGs) from the carry. 
- local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True) (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) acc_grad_and_loss["rest_state"] = next_rest_state @@ -156,6 +162,11 @@ def reshape_to_microbatch_accumulations(batch_arr): + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps ) raw_grads = grad_and_loss["grad"] + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # Apply unreduced annotation after the scan to trigger all-reduce across data-parallel + # devices (reduced/unreduced cannot be used inside jax.lax.scan carry tensors). + unreduced_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, unreduced_shardings) raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, params_shardings) raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 29fa928656..96f2a5a19e 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -173,19 +173,13 @@ def __init__(self, config: Any, rng: PRNGKeyType): devices_array = maxtext_utils.create_device_mesh(config=config) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and quantization config + # Input and output are both Linen-format (uses DeepSeek*ToLinen layers below). + # Route to Linen regardless of pure_nnx. self.quant = quantizations.configure_quantization(config) - if self.config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) + model = models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 8554d46e3e..4193a613d4 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
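The gradient-accumulation change above keeps plain shardings inside the scan and applies the data-parallel reduction to `raw_grads` afterwards. A shape-level sketch of the scan-accumulate-then-normalize structure, with a toy loss and made-up names rather than the MaxText implementation:

```python
import jax
import jax.numpy as jnp

def tiny_loss(params, batch):
  pred = batch["x"] @ params
  return jnp.mean((pred - batch["y"]) ** 2)

def accumulate(params, microbatches):
  grad_fn = jax.grad(tiny_loss)

  def body(acc, batch):
    # Carry holds the running gradient sum; each step consumes one microbatch.
    return jax.tree.map(jnp.add, acc, grad_fn(params, batch)), None

  zero = jax.tree.map(jnp.zeros_like, params)
  summed, _ = jax.lax.scan(body, zero, microbatches)
  num_steps = jax.tree.leaves(microbatches)[0].shape[0]
  # Normalization (and, in the patch above, the data-parallel all-reduce)
  # happens after the scan rather than inside the carry.
  return jax.tree.map(lambda g: g / num_steps, summed)

params = jnp.zeros((3,))
microbatches = {"x": jnp.ones((4, 2, 3)), "y": jnp.ones((4, 2))}
grads = accumulate(params, microbatches)
```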
-""" Common LoRA utils needed to support LoRA adapters.""" +"""Common LoRA utils needed to support LoRA adapters.""" + +from collections.abc import Mapping from functools import partial import json import os @@ -36,6 +38,10 @@ from maxtext.utils import sharding from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR +# NNX-only imports (`flax.nnx`, `train_state_nnx`, `model_creation_utils`) are +# loaded lazily inside the NNX dispatch branches so the Linen-only flow doesn't +# pull them in. + def apply_lora_on_base_params(base_params, lora_params, lora_scale_factor=1.0): """ @@ -116,8 +122,10 @@ def unapply_lora_recursively(base_params, lora_params, module_name): def load_adapter(config, base_abstract_state_params, adapter_config_path, adapter_weights_path): - """ - Load the LoRA weights into a PyTree and return it. + """Load LoRA weights into a PyTree and return it. + + On the NNX path, `base_abstract_state_params` and the returned `lora_params` + are `nnx.State`-shaped (no outer `{"params": ...}` wrap). """ # Load LoRA weights lora_params = None @@ -135,7 +143,10 @@ def load_adapter(config, base_abstract_state_params, adapter_config_path, adapte if not gcs_utils.gcs_path_exists(f"{adapter_weights_path}/commit_success.txt"): raise FileNotFoundError(f"Failed to read lora_weights from {adapter_weights_path}.") - lora_state, _ = get_lora_abstract_state(base_abstract_state_params, lora_config) + if config.pure_nnx: + lora_state, _ = get_lora_abstract_state_nnx(base_abstract_state_params, lora_config) + else: + lora_state, _ = get_lora_abstract_state(base_abstract_state_params, lora_config) with nn_partitioning.axis_rules(config.logical_axis_rules): lora_params = checkpointing.load_params_from_path( @@ -150,22 +161,12 @@ def load_adapter(config, base_abstract_state_params, adapter_config_path, adapte def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, lora_adapter_path): - """We initialize the model and optimizer state, and optionally load from a - checkpoint as necessary. + """Initialize the LoRA train state and optionally load weights from disk. - Args: - model: the flax model to initialize - tx: the optax.GradientTransformation - config: config object - rng: jax.prng key - mesh: jax.devices() mesh - checkpoint_manager: an Orbax checkpointing.CheckpointManager object - lora_adapter_path: Path of the LoRA adapter which is expected to have - `adapter_config.json` and adapter weights - - Returns: - state: the initialized train state - state_mesh_annotations: the mesh annotations for the train state + Returns `(lora_config, lora_state, lora_state_annotations)`. On the NNX path + `model` is unused (the NNX abstract state is built via + `model_creation_utils.create_nnx_abstract_model`) and `lora_state.params` + is `nnx.State`-shaped; on Linen it is the original `{"params": ...}` tree. """ lora_state = None @@ -175,8 +176,18 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") + # pylint: disable=import-outside-toplevel + from maxtext.layers import train_state_nnx + from maxtext.utils import model_creation_utils + + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh) + + def create_train_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + + init_state_fn = create_train_state_fn else: init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) @@ -185,7 +196,11 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp lora_config = gcs_utils.read_json_from_gcs(lora_config_path) - lora_state, lora_state_annotations = get_lora_abstract_state(unboxed_abstract_state.params, lora_config) + if config.pure_nnx: + base_abstract_params = _nnx_param_subtree(unboxed_abstract_state) + lora_state, lora_state_annotations = get_lora_abstract_state_nnx(base_abstract_params, lora_config) + else: + lora_state, lora_state_annotations = get_lora_abstract_state(unboxed_abstract_state.params, lora_config) lora_weights_path = f"{lora_adapter_path}/0/items" @@ -605,3 +620,165 @@ def _map_to_state(path, variable): nnx.update(trainer.model, abstract_lora_params) max_logging.log(f"LoRA restore complete from '{lora_restore_path}'.") return trainer + + +# NNX-shaped LoRA helpers. The Linen walkers above key on `isinstance(x, dict)` +# and bare leaves; NNX trees use `nnx.State` (Mapping but not dict) and +# Variable-wrapped leaves, so we need separate mirrors. The math (W += B @ A * s) +# is identical. + + +def _is_nnx_branch(x): + return isinstance(x, Mapping) + + +def _nnx_param_subtree(unboxed_abstract_state): + """Drop the outer TrainStateNNX wrapping and return the model substate.""" + return unboxed_abstract_state["model"] if "model" in unboxed_abstract_state else unboxed_abstract_state + + +def apply_lora_on_base_params_nnx(base_params, lora_params, lora_scale_factor=1.0): + """NNX variant of `apply_lora_on_base_params`. Mutates `base_params` in place.""" + + def lora_update_or_base(base_weight, lora_a, lora_b): + if lora_a is not None and lora_b is not None: + return base_weight + jnp.einsum("br,rnd->bnd", lora_b, lora_a) * lora_scale_factor + return base_weight + + def recurse(base_node, lora_node, path): + for name, lora_child in lora_node.items(): + if _is_nnx_branch(lora_child): + recurse(base_node[name], lora_child, f"{path}.{name}") + elif lora_child is not None: + if name not in ("lora_a.kernel", "lora_b.kernel"): + raise ValueError(f"Unexpected non-lora key ({path}.{name}) in lora_params") + lora_b = lora_node["lora_a.kernel"] + lora_a = lora_node["lora_b.kernel"] + base_node["kernel"] = lora_update_or_base(base_node["kernel"], lora_a, lora_b) + return + + recurse(base_params, lora_params, "") + + +def unapply_lora_from_base_params_nnx(base_params, lora_params, lora_scale_factor=1.0): + """NNX-shaped variant of `unapply_lora_from_base_params`. 
Mutates `base_params`.""" + + def lora_update_or_base(base_weight, lora_a, lora_b): + if lora_a is not None and lora_b is not None: + return base_weight - jnp.einsum("br,rnd->bnd", lora_b, lora_a) * lora_scale_factor + return base_weight + + def recurse(base_node, lora_node, path): + for name, lora_child in lora_node.items(): + if _is_nnx_branch(lora_child): + recurse(base_node[name], lora_child, f"{path}.{name}") + elif lora_child is not None: + if name not in ("lora_a.kernel", "lora_b.kernel"): + raise ValueError(f"Unexpected non-lora key ({path}.{name}) in lora_params") + lora_b = lora_node["lora_a.kernel"] + lora_a = lora_node["lora_b.kernel"] + base_node["kernel"] = lora_update_or_base(base_node["kernel"], lora_a, lora_b) + return + + recurse(base_params, lora_params, "") + + +def get_lora_abstract_state_nnx(base_abstract_params, lora_config): + """`get_lora_abstract_state` for the NNX path. + + Walks the abstract `state.model` substate and emits a parallel tree with + `lora_a.kernel` / `lora_b.kernel` leaves at target attention paths and + `None` elsewhere. + """ + other_lora_format_to_jax_format = { + "q_proj": "self_attention.query", + "k_proj": "self_attention.key", + "v_proj": "self_attention.value", + "o_proj": "self_attention.out", + } + + lora_target_modules = [other_lora_format_to_jax_format.get(s, s) for s in lora_config["target_modules"]] + lora_rank = int(lora_config["r"]) + + def get_lora_param_shape(base_array_shape, lora_module): + if len(base_array_shape) > 4: + raise ValueError(f"Unsupported base array shape {base_array_shape} (>4D)") + if lora_module in ("self_attention.query", "self_attention.key", "self_attention.value"): + lora_a_shape = base_array_shape[:-2] + (lora_rank,) + lora_b_shape = (lora_rank,) + base_array_shape[1:] + elif lora_module == "self_attention.out": + lora_a_shape = base_array_shape[:-1] + (lora_rank,) + if len(base_array_shape) == 4: + lora_b_shape = (lora_rank, base_array_shape[1], base_array_shape[-1]) + else: + lora_b_shape = (lora_rank, base_array_shape[-1]) + else: + raise ValueError(f"Unsupported lora_module={lora_module}") + return lora_a_shape, lora_b_shape + + def get_lora_param_sharding(base_param_sharding, lora_module): + if base_param_sharding is None: + return None, None + base_pspec = base_param_sharding.spec + if len(base_pspec) > 4: + raise ValueError("PartitionSpec size > 4 not supported") + if lora_module in ("self_attention.query", "self_attention.key", "self_attention.value"): + lora_a_pspec = jax.sharding.PartitionSpec(*(base_pspec[:-2] + ((),))) + lora_b_pspec = jax.sharding.PartitionSpec(*(((),) + base_pspec[1:])) + elif lora_module == "self_attention.out": + lora_a_pspec = jax.sharding.PartitionSpec(*(base_pspec[:-1] + ((),))) + if len(base_pspec) == 4: + lora_b_pspec = jax.sharding.PartitionSpec((), base_pspec[1], base_pspec[-1]) + else: + lora_b_pspec = jax.sharding.PartitionSpec((), base_pspec[-1]) + else: + raise ValueError(f"Unsupported lora_module={lora_module}") + mesh = base_param_sharding.mesh + mem_kind = base_param_sharding.memory_kind + return ( + jax.sharding.NamedSharding(mesh=mesh, spec=lora_a_pspec, memory_kind=mem_kind), + jax.sharding.NamedSharding(mesh=mesh, spec=lora_b_pspec, memory_kind=mem_kind), + ) + + def module_is_target(module_path): + for tgt in lora_target_modules: + if tgt in module_path: + return tgt + return None + + def add_lora(out_node, base_node, path): + for name, child in base_node.items(): + if _is_nnx_branch(child): + out_node[name] = {} + add_lora(out_node[name], child, 
f"{path}.{name}") + else: + if name not in ("kernel", "scale", "embedding"): + raise ValueError(f"Unexpected key={name} in base abstract params at {path}") + if not isinstance(child, jax.ShapeDtypeStruct): + raise ValueError(f"Unexpected leaf type {type(child).__name__} at {path}.{name}") + target_module = module_is_target(path) + if target_module is not None: + a_shape, b_shape = get_lora_param_shape(child.shape, target_module) + a_sharding, b_sharding = get_lora_param_sharding(child.sharding, target_module) + out_node["lora_a.kernel"] = jax.ShapeDtypeStruct(shape=a_shape, dtype=child.dtype, sharding=a_sharding) + out_node["lora_b.kernel"] = jax.ShapeDtypeStruct(shape=b_shape, dtype=child.dtype, sharding=b_sharding) + else: + out_node[name] = None + + lora_abstract_params = {} + add_lora(lora_abstract_params, base_abstract_params, "") + + unboxed_abstract_lora_state = train_state.TrainState( + step=0, apply_fn=None, params=lora_abstract_params, tx=None, opt_state={} # type: ignore + ) + lora_state_mesh_annotations = train_state.TrainState( + step=0, + apply_fn=None, + params=jax.tree_util.tree_map( + lambda x: x.sharding.spec if x.sharding is not None else None, + lora_abstract_params, + ), + tx=None, # type: ignore + opt_state={}, + ) + return unboxed_abstract_lora_state, lora_state_mesh_annotations diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 0f07f5c14d..6a182cd869 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1724,6 +1724,30 @@ def init_kv_cache(model, config): return state_mesh_annotations +def _nnx_cache_partition_specs(abstract_model, config, mesh): + """Per-leaf PartitionSpec tree for the abstract model's nnx.Cache vars. + + Returned as a pure dict so the engine can wrap it in NamedSharding the same + way it does for the Linen helpers below. + """ + _, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...) + # get_nnx_named_sharding_with_scan_axis reads logical axis rules from the + # active flax partitioning context, so wrap. + with nn_partitioning.axis_rules(config.logical_axis_rules): + named_state = get_nnx_named_sharding_with_scan_axis(cache_state, mesh) + return jax.tree.map(lambda s: s.spec, named_state.to_pure_dict()) + + +def get_prefill_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_prefill_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + +def get_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + def save_quantized_checkpoint_if_configured(config, params): """Save quantized checkpoint if configured""" assert config.quantization, "quantization must be configured" diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 3bd2b186b1..049a084979 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -116,6 +116,7 @@ def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + # The result is an nnx.State with the same structure, where each Param's value holds the mdn result. 
muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) else: # Linen @@ -154,7 +155,7 @@ def get_leaf_info(leaf): print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=True): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -191,6 +192,8 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) + if pure_nnx: + muon_weight_dimension_numbers = {"params": nnx.to_pure_dict(muon_weight_dimension_numbers)} return muon_weight_dimension_numbers diff --git a/src/maxtext/utils/standalone_checkpointer.py b/src/maxtext/utils/standalone_checkpointer.py index ba6b148b04..893fdc531a 100644 --- a/src/maxtext/utils/standalone_checkpointer.py +++ b/src/maxtext/utils/standalone_checkpointer.py @@ -52,18 +52,13 @@ def checkpoint_loop(config, state=None): Returns: """ - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = from_config(config) + # Save/restore exerciser uses Linen-shaped optimizer state via + # add_entropy_to_checkpoint(). Route to Linen regardless of pure_nnx. + model = from_config(config) mesh = model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) _, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index ca90550630..80229b05be 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -225,10 +225,16 @@ def setup_train_loop(config, recorder, devices=None): if config.pure_nnx: # For NNX, the train state is wrapped in the TrainStateNNX module. + # When DPO is enabled, also materialize a frozen reference model alongside + # the policy. Both are constructed by `_create_model_partial()` (which uses + # `config.init_weights_seed`), so the reference starts identical to the + # policy — standard DPO practice. The reference is later overwritten by + # the step-0 checkpoint in `setup_post_setup_state` below. 
def create_train_state_fn(): model = _create_model_partial() optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) - return train_state_nnx.TrainStateNNX(model, optimizer) + reference_model = _create_model_partial() if config.use_dpo else None + return train_state_nnx.TrainStateNNX(model, optimizer, reference_model=reference_model) init_state_fn = create_train_state_fn else: @@ -316,8 +322,6 @@ def create_train_state_fn(): maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: - if config.pure_nnx: - raise NotImplementedError("DPO is not supported yet by NNX models.") abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" @@ -342,9 +346,17 @@ def create_train_state_fn(): except FileNotFoundError: step0_restored = None if step0_restored is not None: - # TODO: For pure_nnx, the dpo state manipulation is different. - reference_params = step0_restored["items"].params["params"] - state = _merge_dpo_state(state, reference_params) + if config.pure_nnx: + # step0_restored["items"] is the flat nnx.State of the step-0 TrainStateNNX + # (typically from a non-DPO pre-training run, so its top-level fields are + # `model` and `optimizer` — no `reference_model`). Copy its `model` substate + # into our current state's `reference_model` slot. + step0_state = step0_restored["items"] + step0_model_substate = step0_state["model"] if "model" in step0_state else step0_state + state["reference_model"] = step0_model_substate + else: + reference_params = step0_restored["items"].params["params"] + state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index e7b155416c..4686ff3d82 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -247,3 +247,108 @@ def _bwd_scan_body(grad_params_acc, chunk_data): ) return total_loss, total_z_loss + + +def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train): + """Computes cross-entropy loss with vocab tiling for NNX models. + + NNX equivalent of ``vocab_tiling_linen_loss``. Scans the vocab dimension + and calls ``model.logits_from_hidden_states`` per chunk. The NNX model + carries its own parameters, so no explicit gather is needed. + + Uses default autograd; a custom_vjp for backward memory savings can be + added later if needed. + + Args: + model: NNX model exposing ``logits_from_hidden_states``. + hidden_states: Final hidden states from the decoder. + data: Dict with ``targets`` and ``targets_segmentation``. + config: Model and training config. + is_train: Whether the model is in training mode. + + Returns: + A tuple ``(total_loss, total_z_loss)``. 
+ """ + labels = data["targets"] + segmentation = data["targets_segmentation"] + deterministic = not config.enable_dropout if is_train else True + model_mode = "train" + + hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length", "activation_embed"), + ) + label_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length"), + ) + reshaped_hidden_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + reshaped_data_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence"), + ) + chunked_hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + chunked_data_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence",), + ) + chunked_logits_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_vocab"), + ) + + _maybe_shard_with_name = functools.partial( + maybe_shard_with_name, + shard_mode=config.shard_mode, + debug_sharding=config.debug_sharding, + extra_stack_level=1, + ) + + def _reshape(inputs, out_shape, out_sharding): + reshape_out_sharding = out_sharding if config.shard_mode == ShardMode.EXPLICIT else None + inputs = jax.lax.reshape(inputs, out_shape, out_sharding=reshape_out_sharding) + return _maybe_shard_with_name(inputs, out_sharding) + + hidden_states = _maybe_shard_with_name(hidden_states, hidden_spec) + labels = _maybe_shard_with_name(labels, label_spec) + segmentation = _maybe_shard_with_name(segmentation, label_spec) + + batch_size, seq_len, emb_dim = hidden_states.shape + vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + + reshaped_hidden_states = _reshape( + hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec + ) + reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + + def _scan_body(accumulators, chunk_data): + loss_accumulator, z_loss_accumulator = accumulators + hidden_chunk, label_chunk, segmentation_chunk = chunk_data + hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) + label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) + segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) + + chunk_logits = model.logits_from_hidden_states(hidden_chunk, deterministic, model_mode) + chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec) + one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size) + chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits( + chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier + ) + + masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0)) + masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0)) + + return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None + + initial_acc = (jnp.zeros((), dtype=hidden_states.dtype), jnp.zeros((), dtype=hidden_states.dtype)) + (total_loss, total_z_loss), _ = jax.lax.scan( + _scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + ) + return total_loss, total_z_loss diff --git a/tests/integration/maxengine_test.py b/tests/integration/maxengine_test.py index eb4a7729d6..0f3329cf35 100644 --- 
a/tests/integration/maxengine_test.py +++ b/tests/integration/maxengine_test.py @@ -23,6 +23,8 @@ from jax.sharding import Mesh import numpy as np import pytest +from flax import nnx +from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL from maxtext.layers import quantizations @@ -30,7 +32,10 @@ pytest.importorskip("jetstream", reason="jetstream not installed") from maxtext.inference.maxengine import maxengine from maxtext.models import models +from maxtext.checkpoint_conversion import linen_nnx_converter +from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from tests.utils.test_helpers import get_test_config_path pytestmark = [pytest.mark.external_serving] @@ -163,6 +168,175 @@ def test_basic_decode(self): self.assertEqual(result_token.data.ndim, 2) self.assertEqual(result_token.data.shape[1], 3) + def _init_nnx_pyconfig(self, **kwargs): + """init_pyconfig with NNX flags on.""" + return self.init_pyconfig(pure_nnx=True, enable_nnx=True, pure_nnx_decoder=True, **kwargs) + + def _build_nnx_params(self, cfg, mesh): + """Materialize an NNX Transformer and return its nnx.Param state.""" + _create_model = model_creation_utils.get_nnx_create_model_fn(cfg, mesh=mesh, model_mode=MODEL_MODE_PREFILL) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + model = _create_model() + _, params_state, _ = nnx.split(model, nnx.Param, ...) + return params_state + + def test_init_nnx(self): + """NNX engine init exposes graphdef + abstract Transformer.""" + cfg = self._init_nnx_pyconfig() + engine = maxengine.MaxEngine(cfg, jax.devices()) + self.assertIsNotNone(engine.graphdef) + self.assertIsNotNone(engine.model) + self.assertEqual(type(engine.model).__name__, "Transformer") + + def test_basic_prefill_nnx(self): + """NNX prefill returns a Linen-shape result dict with finite values.""" + cfg = self._init_nnx_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + params_state = self._build_nnx_params(cfg, mesh) + + input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0]) + true_length = 4 + engine = maxengine.MaxEngine(cfg, jax.devices()) + params = engine.load_params(params=params_state) + prefill_result, first_token = engine.prefill(params=params, padded_tokens=input_tokens, true_length=true_length) + + self.assertEqual(prefill_result["generated_tokens"], jnp.array([0])) + self.assertEqual(prefill_result["tokens"].size, 1) + self.assertTrue(jnp.array_equal(first_token.data.size, 3)) + self.assertEqual(first_token.log_prob.shape, (1, 1)) + self.assertIn("cache", prefill_result) + self.assertIsInstance(prefill_result["cache"], dict) + # Catch silent NaN/inf from a bad nnx.merge or cache round-trip. + self.assertTrue(jnp.all(jnp.isfinite(prefill_result["logits"]))) + cache_leaves, _ = jax.tree.flatten(prefill_result["cache"]) + for leaf in cache_leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}") + # scan_layers=True (default in test config) ⇒ leading axis is num_decoder_layers. + for leaf in cache_leaves: + self.assertEqual(leaf.shape[0], cfg.num_decoder_layers, msg=f"layer-axis mismatch, got shape={leaf.shape}") + + def test_basic_decode_nnx(self): + """NNX prefill → insert → 4 generate steps. 
Verifies next_pos advances and logits stay finite.""" + cfg = self._init_nnx_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + params_state = self._build_nnx_params(cfg, mesh) + + input_tokens = jnp.array([1, 306, 5360, 304]) + engine = maxengine.MaxEngine(cfg, jax.devices()) + params = engine.load_params(params=params_state) + decode_state = engine.init_decode_state() + prefill_result, _ = engine.prefill(params=params, padded_tokens=input_tokens, true_length=4) + decode_state = engine.insert(prefill_result, decode_state, slot=0) + + # 4 steps is enough to catch off-by-one cache pointer bugs. + initial_next_pos = int(decode_state["next_pos"][0, 0]) + for step in range(4): + decode_state, result_token = engine.generate(params=params, decode_state=decode_state) + self.assertEqual(result_token.log_prob.ndim, 2) + self.assertEqual(result_token.log_prob.shape[1], 1) + self.assertEqual(result_token.data.ndim, 2) + self.assertEqual(result_token.data.shape[1], 3) + self.assertTrue(jnp.all(jnp.isfinite(decode_state["logits"]))) + self.assertEqual( + int(decode_state["next_pos"][0, 0]), + initial_next_pos + step + 1, + msg=f"next_pos didn't advance at step {step}", + ) + + def test_quantize_raises_for_nnx(self): + """pure_nnx + quantization raises NotImplementedError.""" + cfg = self._init_nnx_pyconfig(quantization="int8") + engine = maxengine.MaxEngine(cfg, jax.devices()) + with self.assertRaises(NotImplementedError): + engine.load_params(rng=self.rng) + + def test_lora_load_single_adapter_reaches_loader_on_nnx(self): + """pure_nnx + LoRA: load_single_adapter dispatches to the NNX loader. + + With a nonexistent path the loader raises FileNotFoundError (not + NotImplementedError, which would mean the dispatch never reached the loader). + """ + cfg = self._init_nnx_pyconfig() + engine = maxengine.MaxEngine(cfg, jax.devices()) + with self.assertRaises(FileNotFoundError): + engine.load_single_adapter("/nonexistent/adapter/path") + + def _linen_params_to_nnx_state(self, linen_params, abstract_nnx_model): + """Convert Linen params → NNX nnx.Param state via linen_nnx_converter so both engines share weights.""" + nnx_dict_wrapped = linen_nnx_converter.convert_linen_to_nnx({"params": linen_params}, scan_layers=True)["model"] + # pylint: disable=protected-access + nnx_pure = linen_nnx_converter._strip_value_wrappers(nnx_dict_wrapped) + _, params_state, _ = nnx.split(abstract_nnx_model, nnx.Param, ...) + nnx.replace_by_pure_dict(params_state, nnx_pure) + return params_state + + def test_linen_nnx_parity_prefill(self): + """Same weights → same prefill output across Linen and NNX engines. + + A failure here means the NNX forward pass diverges from Linen on identical + weights (cache plumbing, nnx.merge wiring, or Transformer.__call__). + """ + cfg_linen = self.init_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg_linen) + mesh = Mesh(devices_array, cfg_linen.mesh_axes) + + # Linen: init params, run prefill. 
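The weight bridge in `_linen_params_to_nnx_state` hinges on `nnx.replace_by_pure_dict`, which overwrites an abstract Param state from a plain nested dict. A toy sketch; the module and weights here are illustrative, not the converter itself:

```python
import jax.numpy as jnp
from flax import nnx

class TinyModel(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(2, 2, rngs=rngs)

abstract_model = nnx.eval_shape(lambda: TinyModel(nnx.Rngs(0)))
_, params_state, _ = nnx.split(abstract_model, nnx.Param, ...)
pure_weights = {
    "linear": {
        "kernel": jnp.eye(2),
        "bias": jnp.zeros((2,)),
    }
}
nnx.replace_by_pure_dict(params_state, pure_weights)  # in-place leaf replacement
```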
+ quant = quantizations.configure_quantization(cfg_linen) + linen_model = models.transformer_as_linen(config=cfg_linen, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + ids, decoder_segment_ids, decoder_positions = self.get_data() + linen_vars = linen_model.init( + {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + ) + # Linen.init wraps leaves in LogicallyPartitioned (which has a `.value` + # attribute); unbox so the converter's {value:} wrapper detector doesn't + # mistake them for already-wrapped NNX leaves. + linen_vars = max_utils.unbox_logicallypartioned(linen_vars) + + input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0]) + true_length = 4 + linen_engine = maxengine.MaxEngine(cfg_linen, jax.devices()) + linen_params = linen_engine.load_params(params=linen_vars) + linen_prefill, linen_first_token = linen_engine.prefill( + params=linen_params, padded_tokens=input_tokens, true_length=true_length + ) + + # NNX: bridge Linen weights, run prefill on the same prompt. + cfg_nnx = self._init_nnx_pyconfig() + nnx_engine = maxengine.MaxEngine(cfg_nnx, jax.devices()) + nnx_params_state = self._linen_params_to_nnx_state(linen_vars["params"], nnx_engine.model) + nnx_params = nnx_engine.load_params(params=nnx_params_state) + nnx_prefill, nnx_first_token = nnx_engine.prefill( + params=nnx_params, padded_tokens=input_tokens, true_length=true_length + ) + + # Tolerance is loose because the test config uses bf16 compute, where + # accumulation order between Linen-scan and NNX-scan drifts by ~0.05. + # Greedy match below is the behavioral check that actually matters. + linen_logits = np.asarray(linen_prefill["logits"]) + nnx_logits = np.asarray(nnx_prefill["logits"]) + self.assertEqual(linen_logits.shape, nnx_logits.shape) + np.testing.assert_allclose( + linen_logits, + nnx_logits, + rtol=0.05, + atol=0.1, + err_msg="Linen vs NNX prefill logits diverge beyond bf16 tolerance.", + ) + self.assertEqual( + int(linen_first_token.data[0, 0]), + int(nnx_first_token.data[0, 0]), + msg="Linen and NNX disagreed on greedy first token with identical weights.", + ) + linen_cache_leaves, _ = jax.tree.flatten(linen_prefill["cache"]) + nnx_cache_leaves, _ = jax.tree.flatten(nnx_prefill["cache"]) + self.assertEqual(len(linen_cache_leaves), len(nnx_cache_leaves)) + @pytest.mark.skip(reason="Can only pass on CPU.") def test_chunked_prefill(self): """Test identical result between chunked prefill with single and multiple chunked. diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py index d11f9658a7..05a7fcffec 100644 --- a/tests/integration/setup_train_loop_nnx_test.py +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -126,15 +126,6 @@ def test_pure_nnx_setup_param_only_split_matches_model(self): del model - def test_pure_nnx_dpo_raises_not_implemented(self): - """The use_dpo branch (train_utils.py:319-320) must raise for NNX.""" - # use_dpo requires a few prerequisites; the simplest is to set the flag and - # let setup_train_loop reach the NotImplementedError check before the more - # involved DPO path runs. 
- config = _tiny_nnx_pyconfig(use_dpo=True, packing=False) - with self.assertRaises(NotImplementedError): - setup_train_loop(config, recorder=None) - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/dpo_nnx_test.py b/tests/unit/dpo_nnx_test.py new file mode 100644 index 0000000000..461c3cb2aa --- /dev/null +++ b/tests/unit/dpo_nnx_test.py @@ -0,0 +1,215 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""NNX DPO unit tests. + +Covers the NNX-native DPO surface: + * `TrainStateNNX(model, optimizer, reference_model=...)` — reference model + sits alongside policy and is not touched by `apply_gradients`. + * `dpo_loss_fn_nnx(policy, config, data, None, None, reference, is_train)` — + aux structure, identical-model invariant (loss = log(2), reward_accuracy = 0.5). +""" + +import math +import types +import unittest + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx +from maxtext.trainers.post_train.dpo import dpo_utils + + +class _MockTransformer(nnx.Module): + """Tiny NNX transformer-shaped module for DPO tests. + + Accepts the same keyword args that `dpo_loss_fn_nnx` passes: + `decoder_input_tokens`, `decoder_positions`, `decoder_segment_ids`, + `enable_dropout`. Other args are tolerated via **kwargs. + """ + + def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs) + self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions=None, + decoder_segment_ids=None, + enable_dropout=False, + **kwargs, + ): + del decoder_positions, decoder_segment_ids, enable_dropout, kwargs + return self.proj(self.embed(decoder_input_tokens)) + + +def _make_dpo_config(**overrides): + """Build the minimal config surface that `dpo_loss_fn_nnx` reads.""" + base = { + "dpo_label_smoothing": 0.0, + "dpo_beta": 0.1, + "enable_dropout": False, + "num_experts": 1, + "micro_batch_size_to_train_on": 2, + } + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _make_dpo_batch(batch_size=2, seq_len=5): + """Build a tiny DPO-shaped batch. + + `chosen` and `rejected` share the first 2 tokens (common prefix is masked + out in the loss), differ at positions 2 and 3, and are padded at position 4. 
+ """ + chosen = jnp.array([[1, 2, 3, 4, 0]] * batch_size, dtype=jnp.int32) + rejected = jnp.array([[1, 2, 5, 6, 0]] * batch_size, dtype=jnp.int32) + positions = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32), (batch_size, 1)) + segmentation = jnp.array([[1, 1, 1, 1, 0]] * batch_size, dtype=jnp.int32) + return { + "chosen": chosen, + "rejected": rejected, + "chosen_position": positions, + "rejected_position": positions, + "chosen_segmentation": segmentation, + "rejected_segmentation": segmentation, + } + + +class TestTrainStateNNXWithReferenceModel(unittest.TestCase): + """`TrainStateNNX(reference_model=...)` semantics.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(1)) + self.tx = optax.adam(1e-3) + + def test_init_with_reference(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + self.assertIs(state.model, self.policy) + self.assertIs(state.reference_model, self.reference) + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_reference_omits_attribute(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer) + self.assertFalse(hasattr(state, "reference_model")) + + def test_apply_gradients_does_not_touch_reference(self): + """Gradient update on policy must leave reference model bit-identical.""" + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + + ref_kernel_before = jnp.asarray(state.reference_model.proj.kernel.value).copy() + + def policy_loss(m): + return jnp.mean(m(jnp.array([[1, 2]])) ** 2) + + grads = nnx.grad(policy_loss)(state.model) + state.apply_gradients(grads) + + ref_kernel_after = jnp.asarray(state.reference_model.proj.kernel.value) + self.assertTrue(jnp.array_equal(ref_kernel_before, ref_kernel_after)) + + +class TestDPOLossFnNNX(unittest.TestCase): + """`dpo_loss_fn_nnx` numerical and structural sanity checks.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + # Reference initialized with the same seed to make policy and reference + # bit-identical at construction time. + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.config = _make_dpo_config() + self.data = _make_dpo_batch() + + def test_aux_has_expected_keys(self): + _, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + expected_keys = { + "intermediate_outputs", + "xent_sum", + "dpo_loss", + "total_weights", + "moe_lb_loss", + "reward_accuracy", + "indexer_loss", + "mtp_loss", + } + self.assertEqual(set(aux.keys()), expected_keys) + self.assertEqual(aux["xent_sum"], 0.0) + self.assertEqual(aux["moe_lb_loss"], 0.0) # num_experts=1 + self.assertEqual(aux["total_weights"], self.data["chosen"].shape[0]) + + def test_identical_policy_and_reference_yields_log2_loss(self): + """When policy == reference, all logratios are 0; with label_smoothing=0 + the per-example loss is `-log(sigmoid(0)) = log(2)`. `reward_accuracy` + uses strict `chosen > rejected`, so equal logratios score 0.0 (no example + is strictly preferred). 
+ """ + loss, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + self.assertAlmostEqual(float(loss), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["dpo_loss"]), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["reward_accuracy"]), 0.0, places=4) + + def test_dropout_rng_and_params_args_are_unused(self): + """The 4th and 5th positional args are signature-compat slots for the + Linen dispatcher; passing arbitrary values must not affect the result. + """ + loss_a, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + loss_b, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, + self.config, + dict(self.data), + jax.random.PRNGKey(123), # dropout_rng — unused + {"params": "garbage"}, # params — unused + self.reference, + is_train=True, + ) + self.assertAlmostEqual(float(loss_a), float(loss_b), places=6) + + def test_value_and_grad_argnums0_only_diffs_policy(self): + """`nnx.value_and_grad(..., argnums=0)` over the policy should produce + finite grads on policy params and not require reference grads. + """ + + def _loss(policy_module): + loss, _ = dpo_utils.dpo_loss_fn_nnx( + policy_module, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + return loss + + grad_fn = nnx.value_and_grad(_loss, argnums=0) + loss, grads = grad_fn(self.policy) + self.assertTrue(jnp.isfinite(loss)) + # Grads is an nnx.State of the policy's nnx.Param leaves; check at least one + # leaf is finite and non-trivially shaped. + leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/grpo_nnx_test.py b/tests/unit/grpo_nnx_test.py new file mode 100644 index 0000000000..6f72c43723 --- /dev/null +++ b/tests/unit/grpo_nnx_test.py @@ -0,0 +1,231 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for `grpo_loss_fn_nnx`, `compute_log_probs_nnx`, plus a small +Linen-path regression block (the repo's existing Linen GRPO integration test +is TPU-only).""" + +import types +import unittest + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx + +from maxtext.experimental.rl import grpo_trainer +from maxtext.experimental.rl import grpo_utils + + +class _MockTransformer(nnx.Module): + """Tiny NNX module that responds to the kwargs `compute_log_probs_nnx` uses.""" + + def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs) + self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions=None, + decoder_segment_ids=None, + enable_dropout=False, + **kwargs, + ): + del decoder_positions, decoder_segment_ids, enable_dropout, kwargs + return self.proj(self.embed(decoder_input_tokens)) + + +def _make_grpo_config(**overrides): + """Minimal config namespace covering every field `grpo_loss_fn_nnx` reads.""" + base = { + "train_data_columns": "prompt", + "num_generations": 2, + "grpo_epsilon": 0.2, + "grpo_beta": 0.1, + "num_experts": 1, + "decode_sampling_temperature": 1.0, + "enable_dropout": False, + "use_dpo": False, + } + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _make_grpo_batch(B=2, G=2, S=6): + """Minimal GRPO batch: `B` prompts, `G` generations each (total `B*G`), seq length `S`.""" + total = B * G + prompts = jnp.tile(jnp.arange(S, dtype=jnp.int32), (total, 1)) + return { + "prompt_completions": prompts, + "prompt_completions_position": prompts, + "prompt_completions_segmentation": jnp.ones((total, S), dtype=jnp.int32), + "ar_completions_segmentation": jnp.array([[0, 0, 1, 1, 1, 0]] * total, dtype=jnp.int32), + "completions_logprobs": None, # off-policy + } + + +class TestGrpoLossFnNnx(unittest.TestCase): + """Behavior of `grpo_loss_fn_nnx` on a synthetic policy / reference pair.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) # identical seed + self.config = _make_grpo_config() + self.data = _make_grpo_batch() + + def test_aux_structure_matches_linen(self): + """`grpo_loss_fn_nnx` returns the same `LossAux` dataclass shape as `grpo_loss_fn`.""" + loss, aux = grpo_trainer.grpo_loss_fn_nnx( + self.policy, self.config, self.data, None, None, self.reference, is_train=True + ) + self.assertIsInstance(aux, grpo_trainer.LossAux) + for field in ( + "total_loss", + "avg_reward", + "avg_reward_std", + "avg_advantage", + "completion_length", + "moe_lb_loss", + "total_weights", + ): + self.assertTrue(hasattr(aux, field), f"aux missing field {field}") + self.assertTrue(jnp.isfinite(loss)) + + def test_unused_dropout_rng_and_params_args_are_ignored(self): + """`dropout_rng` and `params` are positional placeholders only — values shouldn't matter.""" + a = grpo_trainer.grpo_loss_fn_nnx(self.policy, self.config, self.data, None, None, self.reference, is_train=True) + b = grpo_trainer.grpo_loss_fn_nnx( + self.policy, self.config, self.data, jax.random.key(99), {"junk": 1}, self.reference, is_train=True + ) + np.testing.assert_allclose(np.asarray(a[0]), np.asarray(b[0]), rtol=1e-6) + + def test_identical_policy_and_reference_zero_kl(self): + """Identical policy and reference → per-token KL is zero, so `aux.avg_kl ≈ 0`.""" + cfg = 
_make_grpo_config(grpo_beta=0.5) + _, aux = grpo_trainer.grpo_loss_fn_nnx(self.policy, cfg, self.data, None, None, self.reference, is_train=True) + self.assertIsNotNone(aux.avg_kl) + np.testing.assert_allclose(np.asarray(aux.avg_kl), 0.0, atol=1e-5) + + def test_grpo_beta_zero_avg_kl_is_none(self): + cfg = _make_grpo_config(grpo_beta=0.0) + _, aux = grpo_trainer.grpo_loss_fn_nnx(self.policy, cfg, self.data, None, None, self.reference, is_train=True) + self.assertIsNone(aux.avg_kl) + + def test_value_and_grad_flows_only_to_policy(self): + """`nnx.value_and_grad` over the policy yields finite grads; reference is left alone.""" + + def loss_only(policy_model): + loss, _ = grpo_trainer.grpo_loss_fn_nnx( + policy_model, self.config, self.data, None, None, self.reference, is_train=True + ) + return loss + + # nnx.value_and_grad returns (value, grad_state) where grad_state holds nnx.Param leaves. + _, grads = nnx.value_and_grad(loss_only, argnums=0)(self.policy) + leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(np.all(np.isfinite(np.asarray(leaf))), "policy grad has non-finite entries") + + +class TestComputeLogProbsNnx(unittest.TestCase): + """Shape contract of `compute_log_probs_nnx`.""" + + def test_returns_correct_shape(self): + config = _make_grpo_config() + data = _make_grpo_batch() + model = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + log_probs, _ = grpo_utils.compute_log_probs_nnx( + model, + data["prompt_completions"], + data["prompt_completions_position"], + data["prompt_completions_segmentation"], + data["ar_completions_segmentation"], + config, + is_train=False, + ) + # Inputs are [B, S] → log_probs are [B, S-1]. + self.assertEqual(log_probs.shape, (data["prompt_completions"].shape[0], data["prompt_completions"].shape[1] - 1)) + + +# --------------------------------------------------------------------------- +# Linen-path regression smoke tests +# --------------------------------------------------------------------------- + + +class _MockLinenTransformer(nn.Module): + """Tiny Linen module that responds to the same `model.apply(...)` shape Linen `compute_log_probs` uses.""" + + vocab_size: int + embed_dim: int + + @nn.compact + def __call__(self, inputs, positions, decoder_segment_ids=None, enable_dropout=False): + del positions, decoder_segment_ids, enable_dropout + embed = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim, name="embed")(inputs) + return nn.Dense(features=self.vocab_size, name="proj")(embed) + + +class TestLinenGrpoRegression(unittest.TestCase): + """Smoke test that the Linen `grpo_loss_fn` and `compute_log_probs` still run + end-to-end with `pure_nnx=False`-style inputs.""" + + def setUp(self): + self.config = _make_grpo_config() + self.config.pure_nnx = False # explicit Linen mode + self.config.gradient_accumulation_steps = 1 + self.data = _make_grpo_batch() + self.model = _MockLinenTransformer(vocab_size=8, embed_dim=4) + rng = jax.random.key(0) + inputs = self.data["prompt_completions"] + self.params = self.model.init(rng, inputs, inputs, decoder_segment_ids=jnp.ones_like(inputs), enable_dropout=False) + self.reference_params = jax.tree_util.tree_map(jnp.copy, self.params) + + def test_linen_grpo_loss_fn_still_runs(self): + """Linen `grpo_loss_fn` returns a finite loss + a `LossAux`.""" + loss, aux = grpo_trainer.grpo_loss_fn( + self.model, + self.config, + self.data, + jax.random.key(1), + self.params, + self.reference_params["params"], # Linen 
reference_params is the inner subtree + is_train=True, + ) + self.assertTrue(jnp.isfinite(loss)) + self.assertTrue(hasattr(aux, "total_loss")) + self.assertTrue(hasattr(aux, "moe_lb_loss")) + self.assertTrue(hasattr(aux, "total_weights")) + + def test_linen_compute_log_probs_still_runs(self): + """Linen `compute_log_probs` produces shape `[B, S-1]`.""" + log_probs, _ = grpo_utils.compute_log_probs( + self.model, + self.params, + self.data["prompt_completions"], + self.data["prompt_completions_position"], + self.data["prompt_completions_segmentation"], + self.data["ar_completions_segmentation"], + self.config, + is_train=False, + rngs={"dropout": jax.random.key(2), "params": jax.random.key(3)}, + ) + S = self.data["prompt_completions"].shape[1] + self.assertEqual(log_probs.shape, (self.data["prompt_completions"].shape[0], S - 1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/lora_utils_nnx_test.py b/tests/unit/lora_utils_nnx_test.py new file mode 100644 index 0000000000..e0e8cbb529 --- /dev/null +++ b/tests/unit/lora_utils_nnx_test.py @@ -0,0 +1,293 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-shaped LoRA helpers in `lora_utils`, plus a small +Linen regression block.""" + +import unittest + +import jax +import jax.numpy as jnp +import numpy as np + +from maxtext.utils.lora_utils import ( + apply_lora_on_base_params, + apply_lora_on_base_params_nnx, + get_lora_abstract_state_nnx, + unapply_lora_from_base_params, + unapply_lora_from_base_params_nnx, +) + + +# --------------------------------------------------------------------------- +# Fake abstract state builders (mirror the NNX vs. 
Linen tree shapes) +# --------------------------------------------------------------------------- + + +def _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4, dtype=jnp.float32): + """Tiny NNX-shaped abstract state for one attention block.""" + + def _sds(shape): + return jax.ShapeDtypeStruct(shape=shape, dtype=dtype, sharding=None) + + return { + "decoder": { + "layers": { + "self_attention": { + "query": {"kernel": _sds((emb, num_heads, head_dim))}, + "key": {"kernel": _sds((emb, num_heads, head_dim))}, + "value": {"kernel": _sds((emb, num_heads, head_dim))}, + "out": {"kernel": _sds((emb, num_heads, head_dim))}, + }, + "mlp": {"wi": {"kernel": _sds((emb, 4 * emb))}}, + }, + "shared_embedding": {"embedding": _sds((100, emb))}, + }, + } + + +def _make_linen_attention_abstract(emb=8, num_heads=2, head_dim=4, dtype=jnp.float32): + """Linen-shaped equivalent (with the `{"params": ...}` outer wrap).""" + return {"params": _make_nnx_attention_abstract(emb, num_heads, head_dim, dtype)} + + +def _lora_config(rank=4, alpha=8.0, target_modules=("q_proj", "v_proj")): + return { + "r": rank, + "lora_alpha": alpha, + "target_modules": list(target_modules), + } + + +# --------------------------------------------------------------------------- +# get_lora_abstract_state_nnx +# --------------------------------------------------------------------------- + + +class TestGetLoraAbstractStateNnx(unittest.TestCase): + """`get_lora_abstract_state_nnx` shape, sharding, and error-path coverage.""" + + def test_lora_shapes_for_query_and_value(self): + abs_params = _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4) + state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(rank=4)) + attn = state.params["decoder"]["layers"]["self_attention"] + + a = attn["query"]["lora_a.kernel"] + b = attn["query"]["lora_b.kernel"] + self.assertEqual(a.shape, (8, 4)) + self.assertEqual(b.shape, (4, 2, 4)) + self.assertEqual(a.dtype, jnp.float32) + self.assertEqual(b.dtype, jnp.float32) + + a = attn["value"]["lora_a.kernel"] + b = attn["value"]["lora_b.kernel"] + self.assertEqual(a.shape, (8, 4)) + self.assertEqual(b.shape, (4, 2, 4)) + + def test_non_target_modules_emit_none_leaves(self): + abs_params = _make_nnx_attention_abstract() + state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(target_modules=("q_proj",))) + attn = state.params["decoder"]["layers"]["self_attention"] + self.assertIn("lora_a.kernel", attn["query"]) + self.assertIsNone(attn["key"]["kernel"]) + self.assertIsNone(attn["value"]["kernel"]) + self.assertIsNone(attn["out"]["kernel"]) + self.assertIsNone(state.params["decoder"]["layers"]["mlp"]["wi"]["kernel"]) + self.assertIsNone(state.params["decoder"]["shared_embedding"]["embedding"]) + + def test_o_proj_has_distinct_shape(self): + abs_params = _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4) + state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(rank=3, target_modules=("o_proj",))) + out = state.params["decoder"]["layers"]["self_attention"]["out"] + a = out["lora_a.kernel"] + b = out["lora_b.kernel"] + # 3D base (emb, num_heads, head_dim) → lora_a.shape = (..., r), lora_b = (r, last) + self.assertEqual(a.shape, (8, 2, 3)) + self.assertEqual(b.shape, (3, 4)) + + def test_unsupported_leaf_type_raises(self): + bad = {"decoder": {"layers": {"self_attention": {"query": {"kernel": jnp.zeros((4, 2, 2))}}}}} + with self.assertRaises(ValueError): + get_lora_abstract_state_nnx(bad, _lora_config()) + + def test_unexpected_leaf_name_raises(self): + bad = 
{"decoder": {"layers": {"self_attention": {"query": {"weight": jax.ShapeDtypeStruct((4, 2), jnp.float32)}}}}} + with self.assertRaises(ValueError): + get_lora_abstract_state_nnx(bad, _lora_config()) + + # Linen-vs-NNX numerical parity is covered by TestApplyLoraNnx.test_numerical_parity_with_linen_apply. + + +# --------------------------------------------------------------------------- +# apply / unapply on NNX-shape pure dicts +# --------------------------------------------------------------------------- + + +def _concrete_base(rng=None, emb=4, num_heads=2, head_dim=3): + """Concrete arrays mirroring the abstract structure used above (NNX-shape).""" + if rng is None: + rng = jax.random.key(0) + k1, k2, k3, k4, k5, k6 = jax.random.split(rng, 6) + shape_attn = (emb, num_heads, head_dim) + return { + "decoder": { + "layers": { + "self_attention": { + "query": {"kernel": jax.random.normal(k1, shape_attn)}, + "key": {"kernel": jax.random.normal(k2, shape_attn)}, + "value": {"kernel": jax.random.normal(k3, shape_attn)}, + "out": {"kernel": jax.random.normal(k4, shape_attn)}, + }, + "mlp": {"wi": {"kernel": jax.random.normal(k5, (emb, 4 * emb))}}, + }, + "shared_embedding": {"embedding": jax.random.normal(k6, (100, emb))}, + }, + } + + +def _build_lora_params(base, lora_config_dict, rng): + """Build a concrete LoRA tree (random arrays) matching `base`.""" + abs_tree = jax.tree_util.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=None), base) + lora_state, _ = get_lora_abstract_state_nnx(abs_tree, lora_config_dict) + + def _to_concrete(leaf, rng_key): + if leaf is None: + return None + return jax.random.normal(rng_key, leaf.shape, leaf.dtype) + + leaves, tree = jax.tree_util.tree_flatten(lora_state.params, is_leaf=lambda x: x is None) + rngs = jax.random.split(rng, max(1, len(leaves))) + out_leaves = [_to_concrete(l, r) for l, r in zip(leaves, rngs)] + return jax.tree_util.tree_unflatten(tree, out_leaves) + + +class TestApplyLoraNnx(unittest.TestCase): + """`apply_lora_on_base_params_nnx` round-trip and Linen-vs-NNX parity.""" + + def test_apply_then_unapply_is_identity(self): + rng = jax.random.key(42) + base_orig = _concrete_base(rng) + base = jax.tree_util.tree_map(jnp.copy, base_orig) + lora = _build_lora_params(base, _lora_config(rank=2, target_modules=("q_proj", "v_proj")), jax.random.key(7)) + apply_lora_on_base_params_nnx(base, lora, lora_scale_factor=0.5) + # query/value kernels were modified + self.assertFalse( + jnp.allclose( + base["decoder"]["layers"]["self_attention"]["query"]["kernel"], + base_orig["decoder"]["layers"]["self_attention"]["query"]["kernel"], + ) + ) + # key/out are untouched + np.testing.assert_array_equal( + np.asarray(base["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + ) + np.testing.assert_array_equal( + np.asarray(base["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + ) + unapply_lora_from_base_params_nnx(base, lora, lora_scale_factor=0.5) + np.testing.assert_allclose( + np.asarray(base["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + np.testing.assert_allclose( + np.asarray(base["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + rtol=1e-5, + 
atol=1e-6, + ) + + def test_numerical_parity_with_linen_apply(self): + """Same base+lora numbers → same kernel after apply, on either tree shape.""" + rng = jax.random.key(123) + base_nnx = _concrete_base(rng) + base_linen = {"params": jax.tree_util.tree_map(jnp.copy, base_nnx)} + lora = _build_lora_params(base_nnx, _lora_config(rank=2, target_modules=("q_proj",)), jax.random.key(5)) + apply_lora_on_base_params_nnx(base_nnx, lora, lora_scale_factor=0.7) + apply_lora_on_base_params(base_linen, {"params": lora}, lora_scale_factor=0.7) + np.testing.assert_allclose( + np.asarray(base_nnx["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + np.asarray(base_linen["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + rtol=1e-6, + ) + + def test_apply_with_unexpected_lora_key_raises(self): + base = _concrete_base() + bad = {"decoder": {"layers": {"self_attention": {"query": {"unexpected": jnp.zeros((4, 2))}}}}} + with self.assertRaises(ValueError): + apply_lora_on_base_params_nnx(base, bad) + + +class TestLinenLoraRegression(unittest.TestCase): + """Smoke tests for the Linen apply / unapply helpers (no other unit test exercises them).""" + + def _linen_pair(self, rng=None): + """Build a Linen-shape (with `{"params": ...}` outer wrapper) base + lora pair.""" + if rng is None: + rng = jax.random.key(99) + base_inner = _concrete_base(rng) + base = {"params": jax.tree_util.tree_map(jnp.copy, base_inner)} + lora_inner = _build_lora_params( + base_inner, + _lora_config(rank=2, target_modules=("q_proj", "v_proj")), + jax.random.key(7), + ) + lora = {"params": lora_inner} + return base, lora + + def test_linen_apply_then_unapply_is_identity(self): + base, lora = self._linen_pair() + base_orig = jax.tree_util.tree_map(jnp.copy, base) + apply_lora_on_base_params(base, lora, lora_scale_factor=0.5) + unapply_lora_from_base_params(base, lora, lora_scale_factor=0.5) + np.testing.assert_allclose( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + np.testing.assert_allclose( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + + def test_linen_apply_only_modifies_target_modules(self): + base, lora = self._linen_pair() + base_orig = jax.tree_util.tree_map(jnp.copy, base) + apply_lora_on_base_params(base, lora, lora_scale_factor=1.0) + # query and value are targets — must change. + self.assertFalse( + jnp.allclose( + base["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"], + base_orig["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"], + ) + ) + # key and out are non-target — must be untouched. 
+ np.testing.assert_array_equal( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + ) + np.testing.assert_array_equal( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index 3495b4c557..4340d4e22a 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -154,13 +154,6 @@ def test_indexer_dense_warmup_skips_xent(self): self.assertEqual(float(aux["xent_sum"]), 0.0) self.assertEqual(float(loss), 0.0) - def test_vocab_tiling_raises_not_implemented(self): - cfg, ts = _build_state() - cfg.num_vocab_tiling = 4 - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) - class TestTrainStepNNX(unittest.TestCase): """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" @@ -181,16 +174,6 @@ def test_train_step_returns_state_and_metrics(self): self.assertIn("learning/param_norm", metrics["scalar"]) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) - def test_train_step_dpo_raises_for_nnx(self): - cfg, ts = _build_state() - cfg.use_dpo = True - state_graphdef, state_pure = nnx.split(ts) - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.train_step( - state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data - ) - def test_train_step_increments_optimizer_step(self): cfg, ts = _build_state() state_graphdef, state_pure = nnx.split(ts)
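# ---------------------------------------------------------------------------
# Aside on the LoRA apply/unapply identity exercised in lora_utils_nnx_test.py
# (illustrative sketch, NOT part of the diff): under the conventional additive
# low-rank update, "apply" adds scale * (A @ B) onto the base kernel and
# "unapply" subtracts the same product, so the round trip recovers the base
# weights up to float rounding — which is what test_apply_then_unapply_is_identity
# asserts. The repo's helpers also handle 3-D (emb, num_heads, head_dim)
# attention kernels via reshapes; this sketch keeps everything 2-D, and the
# names base_kernel, lora_a, lora_b, scale are hypothetical.
import numpy as np

rng = np.random.default_rng(0)
emb, out_dim, r, scale = 8, 6, 2, 0.5
base_kernel = rng.normal(size=(emb, out_dim))
lora_a = rng.normal(size=(emb, r))
lora_b = rng.normal(size=(r, out_dim))

applied = base_kernel + scale * (lora_a @ lora_b)    # apply-style additive update
restored = applied - scale * (lora_a @ lora_b)       # unapply reverses the same delta
np.testing.assert_allclose(restored, base_kernel, rtol=1e-12, atol=1e-12)
# ---------------------------------------------------------------------------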