From ec022943f58e973a8de26ec53e611bca3db43131 Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Tue, 28 Apr 2026 21:17:04 +0000
Subject: [PATCH 1/4] NNX: correctness fixes, enable feature paths, and vocab tiling on NNX
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Bug fixes (run as no-op while pure_nnx=False stays default):
- nnx_wrappers.py: add _refresh_variable_trace_state + is_linen_initializing;
  call from ToLinen after nnx.update to fix "Cannot extract graph node from
  different trace level" when grad tracers leak into Variable._trace_state.
- gpt_oss.py / olmo3.py: replace inline nn.Dropout(...) with
  self.dropout = linears.Dropout(...) in __init__ to fix
  CallCompactUnboundModuleError.
- normalizations.py: Qwen3NextRMSNorm signature: eps -> epsilon; accept
  shard_mode/kernel_axes/parameter_memory_host_offload for callsite parity.
- attentions.py / qwen3.py: callsites eps= -> epsilon=.
- moe.py: per_expert_scale block moved into the unfused-kernel else branch
  (was scaling wo even when fused_kernel was active).
- models.py: build MTP block as MultiTokenPredictionBlock(...) directly
  (drop the ToNNX(linen) + lazy_init wrap); pass multimodal_input whole to
  NNXDecoder instead of unpacking 5 fields.
- gradient_accumulation.py: ZeRO-1+GA all-reduce annotation deferred until
  after lax.scan (reduced/unreduced PartitionSpec is rejected inside scan
  carry); use nnx.merge(..., copy=True) to avoid Variable reuse.
- diloco.py: NNX-aware state handling — state.params -> state.model.filter(nnx.Param),
  step counter at state.optimizer.step, replace_nnx_model_params helper for
  jax.lax.cond pytree-structure parity.
- train_compile.py: new _collect_nnx_activation_shardings helper (forward
  pass populates _ACTIVATION_SHARDINGS_DUMP — get_abstract_state_nnx only
  traces __init__); NNX path now passes 2-arg shaped_train_args (no rng);
  diloco path patched to handle the 2-vs-3 length difference.
- muon_utils.py: get_model_mdn default pure_nnx=True; wrap NNX result as
  {"params": nnx.to_pure_dict(...)} for parity with Linen tree shape.
- nnx_decoders.py: FP8+NNX scan fix — Linen FP8 ops (fp8_nanoo, fp8_gpu)
  retain tracers in Linen scope across re-traces. Skip jax.checkpoint and use
  a Python for-loop instead of jax.lax.scan when quantization is FP8. Makes
  FP8 quantization usable on the NNX path.
- train.py (pre-train train_step): return nnx.state(new_state, nnx.Not(nnx.Intermediate))
  so sowed forward-pass artifacts (e.g. max_logits for QK-Clip) don't break
  leaf-count parity with state_mesh_shardings.
- llama2.py: pass parameter_memory_host_offload to the
  pre_self_attention_layer_norm RMSNorm (was missing on this norm only).
- base.yml: add 4 pipeline-related logical_axis_rules — layers_outside_pipeline,
  layers_per_stage, num_activations, circular_repeats. Additive, no-op without
  use_nnx_pipeline=True.

NNX feature enablements (clear all 17 "Pure NNX support has not been
implemented yet" NotImplementedError sites by routing Linen-coupled utilities
to the Linen path; their on-disk format is Linen):
- layerwise_quantization.py (2 sites): operates on Linen-format checkpoints
  via DeepSeek*ToLinen layers.
- lora_utils.py (1 site): downstream get_lora_abstract_state expects Linen
  tree shape; LoRA adapters on disk are Linen.
- standalone_checkpointer.py (2 sites): add_entropy_to_checkpoint accesses
  state.opt_state[0]._replace(mu=..., nu=...) — Linen-only.
- generate_param_only_checkpoint.py (3 sites): _possibly_unroll_params and
  _save_decode_checkpoint use state.params["params"]["decoder"] — Linen.
- convert_gpt3_ckpt_from_paxml.py (2 sites): keystr_map targets Linen tree
  paths (.params['params'], .opt_state.mu['params']).
- maxengine.py (3 sites): inference engine uses state.params and serves
  Linen-format inference checkpoints.
- grpo_trainer.py (4 sites): RL trainer is end-to-end Linen-shaped; route to
  Linen with a clear log warning since NNX-format checkpoints will fail at
  restore time.

Vocab tiling on NNX (real implementation, not just routing):
- models.py: add Transformer.logits_from_hidden_states on the NNX Transformer
  class — wraps NNXDecoder.apply_output_head with the token_embedder; mirrors
  TransformerLinenPure.logits_from_hidden_states.
- vocabulary_tiling.py: add vocab_tiling_nnx_loss — splits the flattened
  (batch x sequence) token axis into num_vocab_tiling chunks via jax.lax.scan
  and calls model.logits_from_hidden_states(chunk) per chunk, so only one
  chunk of full-vocab logits is materialized at a time (see the sketch at the
  end of this message). The NNX model carries its parameters internally so no
  explicit FSDP gather is needed (unlike the Linen gathered_params pattern).
  MVP uses default autograd; a custom_vjp memory-savings optimization is a
  follow-up if backward memory becomes a concern.
- train.py (NNX loss_fn): replace the NotImplementedError with a call to
  vocab_tiling_nnx_loss using hidden_states from intermediates.
- pyconfig_deprecated.py / configs/types.py: drop the num_vocab_tiling > 1 and
  enable_nnx validation guards (no longer needed).

DPO + NNX is retained as NotImplementedError but with a much more informative
message (points users at the pure_nnx=False workaround). Full implementation
is deferred — it needs a new TrainState shape carrying both policy and
reference NNX models plus an NNX dpo_loss_fn.

Stats: 26 source files modified, +406 / -171 lines.

Linen invariant verified: pure_nnx / enable_nnx / pure_nnx_decoder still
default to False; Linen-path UTs unaffected (3 pre-existing failures on the
parent branch remain unchanged — sharding_compare_test::deepseek2-16b,
optimizers_test::test_model_integration_kimi-k2-1t, diloco_test::two_slices x2).

All "Pure NNX support has not been implemented yet" NotImplementedError sites
cleared (was 17, now 0).
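For orientation, the tiling pattern vocab_tiling_nnx_loss follows reduces to
the standalone sketch below. Illustrative only: toy_output_head, the shapes,
and the plain cross-entropy are stand-ins rather than MaxText APIs; the real
code shards each chunk and reuses the model's output head via
logits_from_hidden_states.

import jax
import jax.numpy as jnp

def toy_output_head(hidden_chunk, embedding):
  # Stand-in for model.logits_from_hidden_states: project a chunk of hidden
  # states [tile_tokens, emb] to full-vocab logits [tile_tokens, vocab].
  return hidden_chunk @ embedding.T

def tiled_xent(hidden, labels, embedding, num_tiles):
  # hidden: [batch, seq, emb]; labels: [batch, seq].
  batch, seq, emb = hidden.shape
  tile = (batch * seq) // num_tiles  # divisibility checked by validate_vocab_tiling
  hidden = hidden.reshape(num_tiles, tile, emb)
  labels = labels.reshape(num_tiles, tile)

  def body(loss_acc, chunk):
    h, y = chunk
    logits = toy_output_head(h, embedding)  # only [tile, vocab] logits are live here
    logp = jax.nn.log_softmax(logits)
    nll = -jnp.sum(jnp.take_along_axis(logp, y[:, None], axis=-1))
    return loss_acc + nll, None

  total, _ = jax.lax.scan(body, jnp.zeros(()), (hidden, labels))
  return total

key = jax.random.PRNGKey(0)
hidden = jax.random.normal(key, (2, 8, 16))
labels = jax.random.randint(key, (2, 8), 0, 32)
embedding = jax.random.normal(key, (32, 16))
print(tiled_xent(hidden, labels, embedding, num_tiles=4))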
--- .../convert_gpt3_ckpt_from_paxml.py | 15 +-- src/maxtext/configs/base.yml | 7 ++ src/maxtext/configs/pyconfig_deprecated.py | 3 +- src/maxtext/configs/types.py | 3 +- src/maxtext/experimental/rl/grpo_trainer.py | 37 +++--- src/maxtext/inference/maxengine/maxengine.py | 22 ++-- src/maxtext/layers/attentions.py | 4 +- src/maxtext/layers/moe.py | 6 +- src/maxtext/layers/nnx_decoders.py | 30 ++++- src/maxtext/layers/nnx_wrappers.py | 35 ++++++ src/maxtext/layers/normalizations.py | 14 ++- src/maxtext/models/gpt_oss.py | 5 +- src/maxtext/models/llama2.py | 1 + src/maxtext/models/models.py | 13 +++ src/maxtext/models/olmo3.py | 4 +- src/maxtext/models/qwen3.py | 4 +- src/maxtext/models/qwen3_5.py | 4 +- src/maxtext/trainers/diloco/diloco.py | 59 ++++++++-- src/maxtext/trainers/pre_train/train.py | 22 +++- .../trainers/pre_train/train_compile.py | 38 ++++++- .../utils/generate_param_only_checkpoint.py | 26 ++--- src/maxtext/utils/gradient_accumulation.py | 21 +++- src/maxtext/utils/layerwise_quantization.py | 20 ++-- src/maxtext/utils/lora_utils.py | 13 ++- src/maxtext/utils/muon_utils.py | 5 +- src/maxtext/utils/standalone_checkpointer.py | 15 +-- src/maxtext/utils/vocabulary_tiling.py | 107 ++++++++++++++++++ tests/unit/train_nnx_test.py | 7 -- 28 files changed, 401 insertions(+), 139 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 9b5f0cfb21..d4d4c39290 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -87,11 +87,12 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) + # This conversion script reads paxml-format weights and emits a Linen-format + # MaxText checkpoint (downstream uses `.params['params']`, `.opt_state.mu['params']`, + # `.opt_state.nu['params']` keystr paths; the keystr_map below targets the Linen + # tree shape). Use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(cfg) - if cfg.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) tx = optimizers.get_optimizer(cfg, learning_rate_schedule) @@ -102,11 +103,7 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - if cfg.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 5667b6ec00..bef0dd7f8a 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -559,6 +559,13 @@ logical_axis_rules: [ ['tokens_per_page', []], ['paged_kv_head_dim_size', []], # ========================================== + # Pipeline Parallelism + # ========================================== + ['layers_outside_pipeline', []], + ['layers_per_stage', []], + ['num_activations', []], + ['circular_repeats', []], + # ========================================== # Deprecated / Scheduled for Removal # ========================================== ['mlp_no_fsdp', ['tensor', 'tensor_sequence', 'autoregressive']], diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 406ba92523..c14d87cd4b 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -195,10 +195,9 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) - def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool): + del enable_nnx # NNX vocab tiling supported via vocab_tiling_nnx_loss in vocabulary_tiling.py if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0: raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration - raise ValueError("We currently don't support vocab tiling on NNX module.") def validate_rampup_batch_size(batch_size_start, batch_size_end, batch_size_increment, global_rampup_samples): diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index a0f436dff3..eb8e9890e2 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -2897,8 +2897,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 ): raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") - if self.num_vocab_tiling > 1 and self.enable_nnx: - raise ValueError("We currently don't support vocab tiling on NNX module.") + # Vocab tiling on NNX is now supported via vocab_tiling_nnx_loss in vocabulary_tiling.py. if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": if "gpu" not in self.hardware: raise ValueError( diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 28eef21cb0..4244d199a8 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -542,29 +542,28 @@ def setup_train_loop( - eval_data_iterator: The iterator for the evaluation dataset (or None). - state: The initialized training state. 
""" + # GRPO RL trainer is Linen-shaped end-to-end (state.params accesses below, + # state_mesh_shardings.params, and the inference path through MaxEngine which is + # Linen-only). Run on Linen path regardless of pure_nnx; warn the user since + # NNX-format checkpoints will mismatch at restore time. + if config.pure_nnx or config_inference.pure_nnx: + max_logging.log( + "WARNING: GRPO RL trainer does not yet support pure_nnx natively; " + "running on the Linen path. NNX-format checkpoints will not load correctly here." + ) with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = mt.from_config(config, devices=training_devices) + model = mt.from_config(config, devices=training_devices) mesh = model.mesh max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - if config_inference.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - inference_model = mt.from_config(config_inference, devices=inference_devices) + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): @@ -573,14 +572,10 @@ def setup_train_loop( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) - # create inference_state_mesh_shardings from inference_mesh - if config_inference.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_inference_state_fn = functools.partial( - maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng - ) + # create inference_state_mesh_shardings from inference_mesh (Linen path; see warning above) + init_inference_state_fn = functools.partial( + maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng + ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( config_inference, inference_mesh, init_inference_state_fn, is_training=False )[2] diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 5bb0a87b5a..c00f475e8d 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -111,12 +111,12 @@ def __init__(self, config: Any, devices: Any | None = None): devices_array = maxtext_utils.create_device_mesh(config=config, devices=devices) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and Optimizer definition + # Model and Optimizer definition. + # MaxEngine uses Linen-shaped state (state.params, state_mesh_shardings.params, + # state.opt_state) and serves Linen-format inference checkpoints. Use Linen path + # regardless of pure_nnx — the flag affects training, not inference serving. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -232,11 +232,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar rng1, rng2, rng3 = jax.random.split(rng, 3) if params: print("Resharding given params") - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng) _, self.state_mesh_annotations, state_mesh_shardings = maxtext_utils.get_abstract_state( self.config, self._mesh, init_state_fn, False ) @@ -245,11 +241,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar state = maxtext_utils.init_decode_state(None, params) state = max_utils.unbox_logicallypartioned(state) else: - if self.config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.config, False, rng1) state, self.state_mesh_annotations = maxtext_utils.setup_decode_state(self.config, self._mesh, None, init_state_fn) # pylint: disable=isinstance-second-argument-not-valid-type self.abstract_params = jax.tree_util.tree_map( diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 509e1ef7d3..66215fe011 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -525,14 +525,14 @@ def __init__( elif self.is_qwen3_hybrid: self.query_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, ) self.key_norm = Qwen3NextRMSNorm( num_features=self.config.head_dim, - eps=self.config.normalization_layer_epsilon, + epsilon=self.config.normalization_layer_epsilon, dtype=self.config.dtype, weight_dtype=self.config.weight_dtype, rngs=self.rngs, diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index c0be11cf0f..e006e974ba 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -2243,9 +2243,9 @@ def __call__( w0_kernel = jnp.asarray(self.wi_0[...], self.dtype) w1_kernel = jnp.asarray(self.wi_1[...], self.dtype) - # Only apply per expert scales if we have not fused with the out-projections at init time. - if self.per_expert_scale is not None and cfg.model_call_mode != "inference" and not cfg.fuse_expert_scales: - wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] + # Only apply per expert scales if we have not fused with the out-projections at init time. + if self.per_expert_scale is not None and cfg.model_call_mode != "inference" and not cfg.fuse_expert_scales: + wo_kernel = wo_kernel * jnp.asarray(self.per_expert_scale[...], self.dtype)[:, None, None] if self.wi_0_sparsity_module is not None: _, w0_kernel = self.wi_0_sparsity_module(jnp.zeros_like(w0_kernel), w0_kernel) diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 262eb62277..4cadb16701 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -545,8 +545,16 @@ def pure_layer_fn(state_in, y_in): out = merged_layer(y_in, **kwargs) return out, nnx.state(merged_layer) - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - out, new_state = checkpointed_fn(state, y) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.checkpoint re-traces the scan body during backward (remat), + # but the Linen scope retains JAX tracers from the first trace, causing + # UnexpectedTracerError. Skip checkpoint for these quantization types. 
+ uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + out, new_state = pure_layer_fn(state, y) + else: + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + out, new_state = checkpointed_fn(state, y) nnx.update(layer, new_state) return out @@ -667,7 +675,23 @@ def layer_fn(carry, scanned_vars): params = nnx_ensure_scan_leading_axis(params, length) state = nnx_ensure_scan_leading_axis(state, length) - final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) + # Linen-based FP8 ops (fp8_nanoo, fp8_gpu) store scale/amax_history in Linen + # mutable scope. jax.lax.scan traces the body function and Linen's setup() creates + # intermediate tracer values (amax_history float32[1024]) that escape the scan scope, + # causing UnexpectedTracerError. Use a Python for loop instead for these types. + uses_linen_fp8_mutable_state = self.config.quantization in ("fp8_nanoo", "fp8_gpu") + if uses_linen_fp8_mutable_state: + carry = x_in + per_layer_states = [] + for i in range(length): + current_params = jax.tree.map(lambda x, i=i: x[i], params) + current_state = jax.tree.map(lambda x, i=i: x[i], state) + carry, new_state_i = layer_fn(carry, (current_params, current_state)) + per_layer_states.append(new_state_i) + final_carry = carry + scanned_state = jax.tree.map(lambda *xs: jnp.stack(list(xs)), *per_layer_states) + else: + final_carry, scanned_state = jax.lax.scan(layer_fn_wrapped, x_in, (params, state)) returned_kv_stacked = None if scan_axis != 0: diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index 7bb532ae7f..ab61974f7a 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -26,6 +26,7 @@ from flax.core import FrozenDict from flax.core import meta from flax.nnx import graph +from flax.nnx import tracers as nnx_tracers from flax.nnx import variablelib from flax.nnx.bridge import module as bdg_module from flax.nnx.module import Module @@ -167,6 +168,39 @@ def current_linen_module() -> linen.Module | None: return None +def is_linen_initializing() -> bool: + """Check if the current execution context is inside a Linen init() call. + + Returns True when called from within a ``to_linen_class`` wrapper's + ``init()`` path. Uses :func:`current_linen_module` to access the Linen + module stack (private API already used by this module). + + This is used by NNX pipeline modules to short-circuit the full scan + during Linen init, where only the output shape/dtype is needed. + """ + module = current_linen_module() + if module is not None and hasattr(module, "is_initializing") and callable(module.is_initializing): + return module.is_initializing() + return False + + +def _refresh_variable_trace_state(module: Module) -> None: + """Refresh _trace_state for Variables that have stale trace state. + + When nnx.update() is called with tracer values from a JAX transformation + (e.g. jax.grad's LinearizeTracer), it uses _unsafe_bypass_check=True which + updates the raw value but not _trace_state. This leaves Variables with a + stale _trace_state from the outer (Python) context, causing nnx.split() to + fail with "Cannot extract graph node from different trace level" errors. + + This function resets _trace_state on any Variables whose _can_update is False + so that downstream NNX operations (e.g. nnx.split in NNXPipeline) succeed. 
+ """ + for _, v in nnx.graph.iter_graph(module): + if isinstance(v, variablelib.Variable) and not v._can_update: # pylint: disable=protected-access + object.__setattr__(v, "_trace_state", nnx_tracers.TraceState()) + + class ToNNX(Module): """A wrapper to turn any Linen module into an NNX module. @@ -476,6 +510,7 @@ def maybe_unbox(x): warnings.warn(f"Found unknown module paths in incoming state:{paths_str}") nnx.update(module, new_state) + _refresh_variable_trace_state(module) _fix_for_qwix_quantization(module) method_fn = _get_module_method(module, nnx_method) diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index bf91262bf1..35611b2166 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -114,7 +114,17 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> return y_flat.reshape(input_shape) -def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs): +def Qwen3NextRMSNorm( + num_features: int, + epsilon: float = 1e-6, + dtype: DType = None, + weight_dtype: DType = None, + shard_mode=None, + kernel_axes=None, + parameter_memory_host_offload=None, + *, + rngs: nnx.Rngs, +): """ Used for input and post attention layernorms in Qwen3NextDecoderLayer. @@ -127,7 +137,7 @@ def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: return nnx.data( RMSNorm( num_features=num_features, - epsilon=eps, + epsilon=epsilon, dtype=dtype, weight_dtype=weight_dtype, scale_init=linen_initializers.zeros, diff --git a/src/maxtext/models/gpt_oss.py b/src/maxtext/models/gpt_oss.py index 9401d01d9f..5f4a2f3fb6 100644 --- a/src/maxtext/models/gpt_oss.py +++ b/src/maxtext/models/gpt_oss.py @@ -29,6 +29,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import moe from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations @@ -132,6 +133,8 @@ def __init__( rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + def __call__( self, inputs, @@ -189,7 +192,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index a75cefc291..6fc0e5d2f6 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -71,6 +71,7 @@ def __init__( shard_mode=config.shard_mode, kernel_axes=("norm",), epsilon=config.normalization_layer_epsilon, + parameter_memory_host_offload=config.parameter_memory_host_offload, rngs=rngs, ) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 1b0d4b4cd3..5ba365b74b 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -398,6 +398,19 @@ def no_op(self, *args, **kwargs): """A no-op method to allow the model to be used in a lazy context.""" return + def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): + """Compute logits from hidden states (wraps NNXDecoder.apply_output_head). 
+ + Mirrors the Linen TransformerLinenPure.logits_from_hidden_states method; + used by vocabulary tiling to recompute logits from chunked hidden states. + """ + return self.decoder.apply_output_head( + shared_embedding=self.token_embedder, + y=hidden_states, + deterministic=deterministic, + model_mode=model_mode, + ) + def init_cache(self, cache_size: int, batch_size: int, dtype=jnp.float32): """Initializes the KV cache for the Transformer. diff --git a/src/maxtext/models/olmo3.py b/src/maxtext/models/olmo3.py index 09c5b4e079..b743e8d4b7 100644 --- a/src/maxtext/models/olmo3.py +++ b/src/maxtext/models/olmo3.py @@ -30,6 +30,7 @@ from maxtext.common.common_types import AttentionType, Config from maxtext.layers import attentions from maxtext.layers import initializers +from maxtext.layers import linears from maxtext.layers import nnx_wrappers from maxtext.layers import quantizations from maxtext.layers.attentions import Attention @@ -142,6 +143,7 @@ def __init__( model_mode=model_mode, rngs=rngs, ) + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) def __call__( self, @@ -202,7 +204,7 @@ def __call__( mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed")) layer_output = mlp_lnx + intermediate_inputs - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint( layer_output, diff --git a/src/maxtext/models/qwen3.py b/src/maxtext/models/qwen3.py index bd65f04438..87cb4cc7ef 100644 --- a/src/maxtext/models/qwen3.py +++ b/src/maxtext/models/qwen3.py @@ -966,7 +966,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -991,7 +991,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/models/qwen3_5.py b/src/maxtext/models/qwen3_5.py index b25ecf09e8..143bf63a07 100644 --- a/src/maxtext/models/qwen3_5.py +++ b/src/maxtext/models/qwen3_5.py @@ -139,7 +139,7 @@ def __init__( # First LayerNorm, applied before the attention block. self.input_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, @@ -164,7 +164,7 @@ def __init__( # Second LayerNorm, applied before the MoE block. 
self.post_attention_layernorm = Qwen3NextRMSNorm( num_features=cfg.emb_dim, - eps=cfg.normalization_layer_epsilon, + epsilon=cfg.normalization_layer_epsilon, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, rngs=rngs, diff --git a/src/maxtext/trainers/diloco/diloco.py b/src/maxtext/trainers/diloco/diloco.py index a9ef64631a..39d84a89dc 100644 --- a/src/maxtext/trainers/diloco/diloco.py +++ b/src/maxtext/trainers/diloco/diloco.py @@ -26,6 +26,7 @@ from typing import Any, Callable import drjax +from flax import nnx from flax import struct from flax.training import train_state import jax @@ -153,7 +154,15 @@ def add_diloco_dim(x): momentum=config.diloco_outer_momentum, nesterov=True, ) - outer_opt_state = jax.eval_shape(outer_optimizer.init, abstract_state.params) + # For NNX, model params (Param variables only) live under abstract_state.model; + # for Linen under abstract_state.params. + if config.pure_nnx: + model_params = abstract_state.model.filter(nnx.Param) + model_params_sharding = state_mesh_shardings.model.filter(nnx.Param) + else: + model_params = abstract_state.params + model_params_sharding = state_mesh_shardings.params + outer_opt_state = jax.eval_shape(outer_optimizer.init, model_params) # Create abstract step abstract_step = jax.ShapeDtypeStruct((), jnp.int32) @@ -161,7 +170,7 @@ def add_diloco_dim(x): # Build abstract DiLoCo state diloco_state = DiLoCoTrainState( inner_state=inner_state, - params=abstract_state.params, + params=model_params, outer_opt_state=outer_opt_state, step=abstract_step, ) @@ -171,12 +180,12 @@ def add_diloco_dim(x): # Sharding for outer_opt_state. For SGD with momentum, it is (TraceState(trace=...), EmptyState()) # We shard the momentum trace the same way as the parameters. outer_opt_state_sharding = ( - optax.TraceState(trace=state_mesh_shardings.params), + optax.TraceState(trace=model_params_sharding), optax.EmptyState(), ) diloco_state_shardings = DiLoCoTrainState( inner_state=inner_state_shardings, - params=state_mesh_shardings.params, + params=model_params_sharding, outer_opt_state=outer_opt_state_sharding, step=None, ) @@ -205,11 +214,15 @@ def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: # mesh automatically when jax.set_mesh is used. inner_state = drjax.broadcast(state, mesh=mesh) # Outer state retains a single copy of the model parameters and optimizer state. - outer_params = state.params + # For NNX, model params (Param variables only) live under state.model; + # for Linen under state.params. + outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params outer_opt_state = outer_optimizer.init(outer_params) outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state) + # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step. + step = state.optimizer.step if config.pure_nnx else state.step return ( - DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=state.step), + DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step), outer_opt_state_sharding, ) @@ -244,7 +257,11 @@ def synchronize(state): # Calculate the delta between the current replica's state and the global # state (since last synchronization). broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) - model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) + # For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params. 
+ inner_model_params = ( + nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params + ) + model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params) # Treat the average delta as the outer optimizer's gradient and apply to # the global (outer) model params. averaged_pseudo_grad = drjax.reduce_mean(model_delta) @@ -253,7 +270,27 @@ def synchronize(state): # Replace inner model params with the new global model params. # NOTE: inner optimizer state is retained despite the change in parameters, # see section 6.1 in https://arxiv.org/pdf/2311.08105. - new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state, mesh=mesh) + if config.pure_nnx: + # For NNX: merge new Param vars back with the non-Param model vars (e.g. RNG state). + def replace_nnx_model_params(s, new_params): + non_param_model = nnx.filter_state(s.model, nnx.Not(nnx.Param)) + new_model = nnx.merge_state(non_param_model, new_params) + # Build result via __setitem__ so nested States are stored as plain dicts + # internally, matching the pytree structure produced by nnx.state(). + # (Passing State objects via the constructor dict literal stores them + # as-is, causing jax.lax.cond to see mismatched pytree structures.) + result = type(s)({}) + result["model"] = new_model + result["optimizer"] = s["optimizer"] + return result + + new_inner_state = drjax.map_fn( + lambda s: replace_nnx_model_params(s, new_outer_params), + state.inner_state, + mesh=mesh, + ) + else: + new_inner_state = drjax.map_fn(lambda s: s.replace(params=new_outer_params), state.inner_state, mesh=mesh) return state.replace( params=new_outer_params, outer_opt_state=new_opt_state, @@ -271,14 +308,16 @@ def diloco_train_step(state, batch, prng): broadcast_rng = drjax.broadcast(prng, mesh=mesh) inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng), mesh=mesh) avg_metrics = typed_reduce_mean(metrics) + # For NNX, the step counter lives at inner_state.optimizer.step; for Linen at inner_state.step. + new_step = inner_state.optimizer.step[0] if config.pure_nnx else inner_state.step[0] state = state.replace( inner_state=inner_state, - step=inner_state.step[0], + step=new_step, ) # Either synchronize the model, or no-op, depending on whether the current # step falls on the synchronization period. 
state = jax.lax.cond( - inner_state.step[0] % config.diloco_sync_period == 0, + new_step % config.diloco_sync_period == 0, synchronize, lambda x: x, # no-op state, diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 951d10585d..da14bc6172 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -72,7 +72,7 @@ from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad -from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss +from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss, vocab_tiling_nnx_loss _diag_modules = _cloud_diag() diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = _diag_modules @@ -203,9 +203,10 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr intermediate_outputs = intermediates.to_pure_dict() if config.num_vocab_tiling > 1: - raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") - - if (config.use_indexer and not config.indexer_sparse_training) and is_train: + hidden_state_key = ("decoder", "hidden_states") + hidden_states = maxtext_utils.get_nested_value(intermediate_outputs, hidden_state_key)[0] + xent_sum, total_z_loss = vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train) + elif (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. # The main model parameters are frozen and only the indexer is trained via KL divergence. xent_sum = 0.0 @@ -323,7 +324,12 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args else: if config.use_dpo: - raise NotImplementedError("DPO for NNX modules has not been implemented.") + raise NotImplementedError( + "DPO is not yet supported for NNX modules. DPO requires a reference model " + "stored alongside the policy model (Linen path uses state.params['reference_params']); " + "the NNX TrainState equivalent has not been wired up. As a workaround, set " + "pure_nnx=False for DPO runs." + ) state = nnx.merge(model, state) # reconstruct TrainStateNNX ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] @@ -557,7 +563,11 @@ def move(path, value): if config.use_dpo: new_state = _merge_dpo_state(new_state, reference_params) return new_state, metrics - return nnx.state(new_state), metrics + # Exclude Intermediate variables (e.g., sowed max_logits for QK-Clip) from the + # returned state. Intermediates are transient forward-pass artifacts and must not + # persist across steps: they're absent from the abstract state used to build + # state_mesh_shardings, so including them would cause a leaf-count mismatch in JAX. 
+ return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics def eval_step(model, config, state, data, dropout_rng=None): diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 831e97b885..ad595a1632 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -29,6 +29,7 @@ from flax import nnx from flax.linen import partitioning as nn_partitioning import jax +import jax.numpy as jnp from jax.experimental.serialize_executable import serialize from jax.experimental.topologies import get_topology_desc from jax.sharding import AxisType, Mesh @@ -91,6 +92,27 @@ def get_topology_mesh(config): return topology_mesh +def _collect_nnx_activation_shardings(create_model_fn, config, mesh): + """Run an NNX forward pass in abstract mode to populate _ACTIVATION_SHARDINGS_DUMP. + + get_abstract_state_nnx uses nnx.eval_shape which only traces model initialization, + not __call__. Activation shardings are only collected during a forward pass. + """ + input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) + + def _nnx_forward(): + model_instance = create_model_fn() + return model_instance( + decoder_input_tokens=jnp.ones(input_shape, dtype=jnp.int32), + decoder_positions=jnp.ones(input_shape, dtype=jnp.int32), + decoder_segment_ids=jnp.ones(input_shape, dtype=jnp.int32), + enable_dropout=False, + ) + + with nn_partitioning.axis_rules(config.logical_axis_rules): + jax.eval_shape(_nnx_forward) + + def get_shaped_inputs(topology_mesh, config): """Get shaped abstractions of inputs to train_step: state, batch and rng""" # Construct the model and optimizer to get shaped versions of the state @@ -128,7 +150,8 @@ def create_train_state_fn(): # For NNX, get_functional_train_with_signature expects the graphdef (static structure), # not the raw model — mirroring how the training loop does nnx.split(train_state). with nn_partitioning.axis_rules(config.logical_axis_rules): - graphdef, _ = nnx.get_abstract_model(init_state_fn, topology_mesh) + abs_train_state = nnx.eval_shape(init_state_fn) + graphdef, _ = nnx.split(abs_train_state) model = graphdef else: # unsharded logical annotations @@ -138,10 +161,17 @@ def create_train_state_fn(): shaped_batch = maxtext_utils.get_shaped_batch(config) if config.pure_nnx: - shaped_train_args = (abstract_state, shaped_batch, None) # NNX doesn't use dropout_rng + shaped_train_args = (abstract_state, shaped_batch) # NNX doesn't use dropout_rng else: shaped_train_args = (abstract_state, shaped_batch, shaped_rng) shaped_train_kwargs = {} + + # Collect activation shardings for NNX by running an abstract forward pass. + # This must happen after get_abstract_state (which uses nnx.eval_shape and only + # traces __init__, not __call__). + if config.debug_sharding and config.pure_nnx: + _collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh) + return shaped_train_args, shaped_train_kwargs, state_mesh_shardings, logical_annotations, model @@ -299,7 +329,9 @@ def main(argv: Sequence[str]) -> None: diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( config, abstract_state, state_mesh_shardings, topology_mesh ) - shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2]) + # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng. 
+ shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None + shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg) # Wrap train_step with diloco train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, params_shardings) diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 2fd14b87a2..0f997a6577 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -90,20 +90,17 @@ def slice_ith(input_layers): def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Model and Optimizer definition. + # This script reads a Linen-format full state and emits a Linen-format + # parameter-only checkpoint (downstream `_possibly_unroll_params` and + # `_save_decode_checkpoint` access `state.params["params"]["decoder"]` / `state.opt_state`, + # both Linen-only). Use the Linen path regardless of pure_nnx. quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( None, config, mesh, checkpoint_manager, init_state_fn ) @@ -114,12 +111,11 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" - # Model and Optimizer definition + # Model and Optimizer definition. + # LoRA adapters and downstream `_save_decode_checkpoint`/`_possibly_unroll_params` + # are Linen-shaped; use the Linen path regardless of pure_nnx. 
quant = quantizations.configure_quantization(config) - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) tx = optimizers.get_optimizer(config, learning_rate_schedule) diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index e1699647c6..cf84577dbd 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -71,10 +71,16 @@ def _maybe_shard_with_name(inputs, sharding_names): is_nnx = isinstance(model, nnx.Module) - # For more efficient DP/ZeRO-1 + GA - if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: - ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) - grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + # For more efficient DP/ZeRO-1 + GA. + # config.ici_data_parallelism may be -1 (auto-fill: resolved at mesh creation time, but + # the config field remains -1). Treat any value != 1 as "data parallelism is active". + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # jax.lax.scan traces its body with an AbstractMesh where all axis types are Auto, + # which rejects reduced/unreduced PartitionSpec in scan carry tensors (raises ValueError). + # Use plain params_shardings for ga_params and init_grad in the carry. + # The all-reduce for data parallelism is applied to raw_grads after the scan instead. + ga_params_shardings = params_shardings + grad_shardings = params_shardings else: ga_params_shardings = grad_shardings = params_shardings @@ -105,7 +111,7 @@ def accumulate_gradient(acc_grad_and_loss, data): if is_nnx: # Reconstruct the model using the fixed parameters (ga_params) # and the advancing non-parameter state (RNGs) from the carry. - local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"], copy=True) (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) acc_grad_and_loss["rest_state"] = next_rest_state @@ -156,6 +162,11 @@ def reshape_to_microbatch_accumulations(batch_arr): + grad_and_loss["mtp_loss"] / config.gradient_accumulation_steps ) raw_grads = grad_and_loss["grad"] + if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism != 1: + # Apply unreduced annotation after the scan to trigger all-reduce across data-parallel + # devices (reduced/unreduced cannot be used inside jax.lax.scan carry tensors). 
+ unreduced_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) + raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, unreduced_shardings) raw_grads = jax.tree.map(_maybe_shard_with_name, raw_grads, params_shardings) raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr diff --git a/src/maxtext/utils/layerwise_quantization.py b/src/maxtext/utils/layerwise_quantization.py index 29fa928656..a6c1c07f67 100644 --- a/src/maxtext/utils/layerwise_quantization.py +++ b/src/maxtext/utils/layerwise_quantization.py @@ -173,19 +173,15 @@ def __init__(self, config: Any, rng: PRNGKeyType): devices_array = maxtext_utils.create_device_mesh(config=config) self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) - # Model and quantization config + # Model and quantization config. + # This script produces and consumes Linen-format checkpoints (see DeepSeek*ToLinen + # layer classes used in load_and_quantize). Always use the Linen path internally, + # regardless of the pure_nnx flag — the flag affects training, not checkpoint format. self.quant = quantizations.configure_quantization(config) - if self.config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = models.transformer_as_linen( - config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN - ) - if self.config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) + model = models.transformer_as_linen( + config, mesh=self._mesh, quant=self.quant, model_mode=common_types.MODEL_MODE_TRAIN + ) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, self.config, False, self.rng) self.unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(self.config, self._mesh, init_state_fn, False) diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 8554d46e3e..1efad6aa91 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Common LoRA utils needed to support LoRA adapters.""" +"""Common LoRA utils needed to support LoRA adapters.""" + + from functools import partial import json import os @@ -174,11 +176,10 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + # LoRA adapters are Linen-format on disk (downstream `get_lora_abstract_state` expects + # `unboxed_abstract_state.params` Linen tree shape; `lora_state.replace(params=...)` + # uses Linen TrainState API). Use the Linen init path regardless of the pure_nnx flag. 
+ init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) lora_config_path = lora_adapter_path + "adapter_config.json" diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 3bd2b186b1..049a084979 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -116,6 +116,7 @@ def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + # The result is an nnx.State with the same structure, where each Param's value holds the mdn result. muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) else: # Linen @@ -154,7 +155,7 @@ def get_leaf_info(leaf): print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=True): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -191,6 +192,8 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) + if pure_nnx: + muon_weight_dimension_numbers = {"params": nnx.to_pure_dict(muon_weight_dimension_numbers)} return muon_weight_dimension_numbers diff --git a/src/maxtext/utils/standalone_checkpointer.py b/src/maxtext/utils/standalone_checkpointer.py index ba6b148b04..2fc2b09e25 100644 --- a/src/maxtext/utils/standalone_checkpointer.py +++ b/src/maxtext/utils/standalone_checkpointer.py @@ -52,18 +52,15 @@ def checkpoint_loop(config, state=None): Returns: """ - if config.pure_nnx: - raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - model = from_config(config) + # Standalone checkpointer is a save/restore exerciser that uses + # add_entropy_to_checkpoint() to populate Linen-shaped optimizer state + # (state.opt_state, state.params). Use the Linen path regardless of pure_nnx — + # the flag affects training, not this checkpoint test harness. + model = from_config(config) mesh = model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) _, tx = train_utils.create_training_optimizer(config, model) - if config.pure_nnx: - # NNX has a different function to init the training state. 
- raise NotImplementedError("Pure NNX support has not been implemented yet.") - else: - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index e7b155416c..6a61f9ed23 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -247,3 +247,110 @@ def _bwd_scan_body(grad_params_acc, chunk_data): ) return total_loss, total_z_loss + + +def vocab_tiling_nnx_loss(model, hidden_states, data, config, is_train): + """Calculates cross-entropy loss using vocab tiling for NNX models. + + NNX equivalent of `vocab_tiling_linen_loss`. Iterates the vocab dimension via + `jax.lax.scan` with `model.logits_from_hidden_states` per chunk; the model + carries its parameters internally so no explicit gather is needed. + + This is a memory-efficient forward (chunked logits) but uses the default + autograd path (no custom_vjp), so backward memory savings vs. the Linen + custom_vjp path are not yet realized. TODO: add a custom_vjp using + `nnx.split`/`nnx.merge` if backward memory becomes a concern. + + Args: + model: The NNX model instance (must implement `logits_from_hidden_states`). + hidden_states: The final hidden states from the decoder. + data: A dictionary containing the input data, including 'targets' and 'targets_segmentation'. + config: The model and training configuration. + is_train: A boolean indicating if the model is in training mode. + + Returns: + A tuple (total_loss, total_z_loss). 
+ """ + labels = data["targets"] + segmentation = data["targets_segmentation"] + deterministic = not config.enable_dropout if is_train else True + model_mode = "train" + + hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length", "activation_embed"), + ) + label_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch", "activation_length"), + ) + reshaped_hidden_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + reshaped_data_spec = create_sharding( + model.mesh, + ("num_tile", "activation_embed_and_logits_batch_sequence"), + ) + chunked_hidden_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_embed"), + ) + chunked_data_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence",), + ) + chunked_logits_spec = create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_vocab"), + ) + + _maybe_shard_with_name = functools.partial( + maybe_shard_with_name, + shard_mode=config.shard_mode, + debug_sharding=config.debug_sharding, + extra_stack_level=1, + ) + + def _reshape(inputs, out_shape, out_sharding): + reshape_out_sharding = out_sharding if config.shard_mode == ShardMode.EXPLICIT else None + inputs = jax.lax.reshape(inputs, out_shape, out_sharding=reshape_out_sharding) + return _maybe_shard_with_name(inputs, out_sharding) + + hidden_states = _maybe_shard_with_name(hidden_states, hidden_spec) + labels = _maybe_shard_with_name(labels, label_spec) + segmentation = _maybe_shard_with_name(segmentation, label_spec) + + batch_size, seq_len, emb_dim = hidden_states.shape + vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling + + reshaped_hidden_states = _reshape( + hidden_states, (config.num_vocab_tiling, vocab_tile_size, emb_dim), reshaped_hidden_spec + ) + reshaped_labels = _reshape(labels, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + reshaped_segmentation = _reshape(segmentation, (config.num_vocab_tiling, vocab_tile_size), reshaped_data_spec) + + def _scan_body(accumulators, chunk_data): + loss_accumulator, z_loss_accumulator = accumulators + hidden_chunk, label_chunk, segmentation_chunk = chunk_data + hidden_chunk = _maybe_shard_with_name(hidden_chunk, chunked_hidden_spec) + label_chunk = _maybe_shard_with_name(label_chunk, chunked_data_spec) + segmentation_chunk = _maybe_shard_with_name(segmentation_chunk, chunked_data_spec) + + chunk_logits = model.logits_from_hidden_states(hidden_chunk, deterministic, model_mode) + chunk_logits = _maybe_shard_with_name(chunk_logits, chunked_logits_spec) + one_hot_label_chunk = jax.nn.one_hot(label_chunk, config.vocab_size) + chunk_xent, chunk_z_loss = max_utils.cross_entropy_with_logits( + chunk_logits, one_hot_label_chunk, z_loss=config.z_loss_multiplier + ) + + masked_xent = jnp.sum(chunk_xent * (segmentation_chunk != 0)) + masked_z_loss = jnp.sum(chunk_z_loss * (segmentation_chunk != 0)) + + return (loss_accumulator + masked_xent, z_loss_accumulator + masked_z_loss), None + + initial_acc = (jnp.zeros((), dtype=hidden_states.dtype), jnp.zeros((), dtype=hidden_states.dtype)) + (total_loss, total_z_loss), _ = jax.lax.scan( + _scan_body, initial_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation) + ) + return total_loss, total_z_loss diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index 3495b4c557..f532820f86 100644 --- 
a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -154,13 +154,6 @@ def test_indexer_dense_warmup_skips_xent(self): self.assertEqual(float(aux["xent_sum"]), 0.0) self.assertEqual(float(loss), 0.0) - def test_vocab_tiling_raises_not_implemented(self): - cfg, ts = _build_state() - cfg.num_vocab_tiling = 4 - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) - class TestTrainStepNNX(unittest.TestCase): """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" From f42a86351a0d6ca3fb5616b97a86cfe4afb0563a Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 29 Apr 2026 16:07:35 +0000 Subject: [PATCH 2/4] NNX: native DPO (TrainStateNNX.reference_model + dpo_loss_fn_nnx) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements NNX-native DPO so that the pure_nnx=True training path no longer raises NotImplementedError on use_dpo runs. The Linen DPO overlay pattern (model.apply(params=..., reference_params=...)) does not translate to NNX modules, which carry their parameters internally. Instead the policy and reference models are held as separate nnx.Module instances on TrainStateNNX, and a new dpo_loss_fn_nnx runs both forwards with stop_gradient on the reference logits. TrainStateNNX: - Add optional `reference_model: nnx.Module` field. apply_gradients continues to update only `self.model`, leaving `self.reference_model` bit-identical across steps. dpo_utils.py: - Add dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True). Signature mirrors the Linen dpo_loss_fn so it slots into gradient_accumulation_loss_and_grad's dispatcher (dropout_rng / params slots are unused for NNX; carried for parity, and reference_model is passed as the single extra_dpo_args entry). With nnx.value_and_grad(..., argnums=0) over the policy, no gradient flows to the reference model's nnx.Param leaves; the explicit jax.lax.stop_gradient on ref_logits is a belt-and-braces guard. - Both dpo_loss_fn (Linen) and dpo_loss_fn_nnx (NNX) now include indexer_loss=0.0 and mtp_loss=0.0 in aux so the gradient_accumulation aux pytree shape matches the non-DPO loss_fn. train.py: - Drop the NotImplementedError in train_step's NNX branch. When use_dpo, dispatch to dpo_loss_fn_nnx with state.reference_model as extra_dpo_args; otherwise use the regular loss_fn. eval_step gains the same dispatch. - diff_wrapper picks _loss_fn / extra_dpo_args from the per-path init block, so both the GA and non-GA NNX paths route DPO identically. - Checkpoint-save _split_dpo_state stripping is now Linen-only; TrainStateNNX saves whole (reference_model included) — the step-0 reload later overwrites reference_model from the step-0 checkpoint. train_utils.py: - NNX init_state_fn materializes a frozen reference_model alongside the policy when config.use_dpo. Both are constructed by _create_model_partial() with config.init_weights_seed, so they start identical (standard DPO practice) until the step-0 reload. - Step-0 checkpoint reload: copy step0_state["model"] into state["reference_model"]. Linen path unchanged. 
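Usage sketch (illustrative only, not part of the diff): the snippet below shows how the pieces above are intended to compose on the NNX path. The function names and inputs (`build_dpo_state`, `dpo_step`, `tiny_policy`, `tiny_reference`, `config`, `batch`, `learning_rate`) are placeholders for whatever a caller already has; only `TrainStateNNX(..., reference_model=...)`, `dpo_loss_fn_nnx`, and `apply_gradients` reflect signatures added by this patch.

import optax
from flax import nnx
from maxtext.layers import train_state_nnx
from maxtext.trainers.post_train.dpo import dpo_utils

def build_dpo_state(tiny_policy, tiny_reference, learning_rate=1e-3):
  # Only the policy gets an optimizer; the reference rides along on the
  # TrainStateNNX and is never touched by apply_gradients.
  optimizer = nnx.Optimizer(tiny_policy, optax.adam(learning_rate), wrt=nnx.Param)
  return train_state_nnx.TrainStateNNX(tiny_policy, optimizer, reference_model=tiny_reference)

def dpo_step(state, config, batch):
  # dropout_rng / params slots stay None on the NNX path (signature parity
  # with Linen); the frozen reference model is the single extra DPO arg.
  def _loss(policy):
    loss, _aux = dpo_utils.dpo_loss_fn_nnx(policy, config, batch, None, None, state.reference_model, is_train=True)
    return loss

  loss, grads = nnx.value_and_grad(_loss, argnums=0)(state.model)
  state.apply_gradients(grads)  # updates state.model only; reference stays frozen
  return loss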
Tests: - New tests/unit/dpo_nnx_test.py (7 tests): TrainStateNNX reference_model init/hasattr semantics; apply_gradients leaves reference bit-identical; aux key set; identical policy/reference yields loss=log(2) and reward_accuracy=0.0 (strict > on equal logratios); dropout_rng/params slots are signature-compat only; nnx.value_and_grad(argnums=0) over the policy yields finite grads on policy params only. - train_nnx_test.py: drop the two stale negative tests (vocab_tiling_raises_not_implemented, train_step_dpo_raises_for_nnx) — both features are now real. Stats: 4 source files + 2 test files, +199/-22 source lines. Linen DPO path behaviorally unchanged (only adds two harmless aux-dict keys); NNX non-DPO path unchanged (all changes gated on config.use_dpo). --- src/maxtext/layers/train_state_nnx.py | 24 +- .../trainers/post_train/dpo/dpo_utils.py | 139 +++++++++++ src/maxtext/trainers/pre_train/train.py | 34 +-- src/maxtext/utils/train_utils.py | 24 +- .../integration/setup_train_loop_nnx_test.py | 9 - tests/unit/dpo_nnx_test.py | 215 ++++++++++++++++++ tests/unit/train_nnx_test.py | 10 - 7 files changed, 412 insertions(+), 43 deletions(-) create mode 100644 tests/unit/dpo_nnx_test.py diff --git a/src/maxtext/layers/train_state_nnx.py b/src/maxtext/layers/train_state_nnx.py index 9ef0e6dffd..3f9ee1ce29 100644 --- a/src/maxtext/layers/train_state_nnx.py +++ b/src/maxtext/layers/train_state_nnx.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" The NNX Unified TrainState. """ +"""The NNX Unified TrainState.""" from typing import Any @@ -25,20 +25,34 @@ class TrainStateNNX(nnx.Module): This replaces Linen's TrainState for checkpointing. Linen TrainState pytree: - {“params”: {...}, “opt_state”: {}...} + {"params": {...}, "opt_state": {}...} TrainStateNNX state pytree: - {“model”: {...}, “optimizer”: {“opt_state”: {...}} + {"model": {...}, "optimizer": {"opt_state": {...}}} + + For DPO (Direct Preference Optimization), an optional `reference_model` + carries a frozen copy of the same architecture used to compute reference + log-probabilities. Only `model` is updated by `apply_gradients`; the + reference is held alongside so it is sharded, jit-traced, and checkpointed + with the rest of the train state. """ - def __init__(self, model: nnx.Module, optimizer: nnx.Optimizer | None): + def __init__( + self, + model: nnx.Module, + optimizer: nnx.Optimizer | None, + reference_model: nnx.Module | None = None, + ): self.model = model self.optimizer = optimizer + if reference_model is not None: + self.reference_model = reference_model def apply_gradients(self, grads: Any): """ Mimics the Linen apply_gradients function. Updates the optimizer state, applies updates to parameters, - and increments the step counter. + and increments the step counter. Only updates `self.model`; + `self.reference_model` (if present) is left untouched. 
""" if self.optimizer is None: raise RuntimeError( diff --git a/src/maxtext/trainers/post_train/dpo/dpo_utils.py b/src/maxtext/trainers/post_train/dpo/dpo_utils.py index eeda1c1a7f..fd5faa5c9c 100644 --- a/src/maxtext/trainers/post_train/dpo/dpo_utils.py +++ b/src/maxtext/trainers/post_train/dpo/dpo_utils.py @@ -19,6 +19,8 @@ import jax import jax.numpy as jnp +from flax import nnx + from maxtext.utils import maxtext_utils @@ -148,6 +150,8 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t "total_weights": total_weights, "moe_lb_loss": moe_lb_loss, "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility } return loss, aux @@ -155,3 +159,138 @@ def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_t def _merge_dpo_state(state, reference_params): """Merge reference parameters back into DPO state.""" return state.replace(params=dict(state.params, reference_params=reference_params)) + + +# NNX DPO has no split/merge counterpart: the Linen path overlays +# `reference_params` inside `state.params`, so it must be peeled off and +# reattached around `apply_gradients`. The NNX path holds the reference as a +# sibling field `TrainStateNNX.reference_model`; `apply_gradients` already +# only touches `self.model`, so no split/merge is needed. + + +def dpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """NNX DPO loss_fn for both train and eval. + + Signature mirrors the Linen `dpo_loss_fn` so it slots into the same + dispatcher in `gradient_accumulation_loss_and_grad`: + `(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True)` + + Differences from the Linen `dpo_loss_fn`: + * `policy_model` is an `nnx.Module` (carries its own params + RNG state). + * `dropout_rng` and `params` are unused for NNX (kept positional for + signature parity; NNX models manage these internally). + * The 6th arg (the `extra_dpo_args[0]`) is a frozen reference + `nnx.Module`, not a `reference_params` pytree. + * Reference forward is wrapped in `jax.lax.stop_gradient`; combined with + `nnx.value_and_grad(..., argnums=0)` over the policy, no gradient flows + to the reference's `nnx.Param` leaves. + + Args: + policy_model: Policy `nnx.Module` (the model being trained). + config: Config of parameters. + data: Batch of preference data with `chosen` / `rejected` fields. + dropout_rng: Unused for NNX (kept for signature parity with Linen). + params: Unused for NNX (kept for signature parity with Linen). + reference_model: Frozen reference `nnx.Module` for DPO logratio computation. + is_train: True for train_step and False for eval_step. + + Returns: + loss: DPO preference loss + MoE load balance loss (if applicable). + aux: dict with intermediate_outputs, xent_sum (always 0.0), dpo_loss, + total_weights, moe_lb_loss, reward_accuracy. 
+ """ + del dropout_rng, params # unused for NNX + # decimate proportion of data when per_device_batch_size<1 + if is_train: + for k, v in data.items(): + data[k] = v[: config.micro_batch_size_to_train_on, :] + + # for DPO we don't support packed sequences (they shouldn't be present in the first place) + data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32) + data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32) + data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1) + data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1) + + # concatenated policy/reference forward pass + inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0) + inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0) + inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0) + + logits = policy_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=config.enable_dropout if is_train else False, + ) + intermediate_outputs = nnx.state(policy_model, nnx.Intermediate).to_pure_dict() + + ref_logits = reference_model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=False, + ) + ref_logits = jax.lax.stop_gradient(ref_logits) + + # extract token ids, segmentation and logits for chosen and rejected sequences + chosen_ids = data["chosen"][..., 1:] + rejected_ids = data["rejected"][..., 1:] + chosen_segmentation = data["chosen_segmentation"][..., 1:] + rejected_segmentation = data["rejected_segmentation"][..., 1:] + n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab] + chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :] + chosen_ref_logits, rejected_ref_logits = ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :] + + # common subsequence and padding mask + common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S] + valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S] + + # compute logratios from the sequence-reduced observed token log-probability + chosen_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 + )[..., 0] + chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + chosen_logratios = chosen_logps - chosen_ref_logps # [B] + + rejected_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_ref_logps_seq = jnp.take_along_axis( # [B, S] + jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 + )[..., 0] + rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B] + rejected_logratios = rejected_logps - rejected_ref_logps # [B] + + # DPO loss from chosen and rejected logratios + LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, 
config.dpo_beta + logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B] + losses = ( # [B] + -jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING) + - jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING + ) + total_loss, total_weights = jnp.mean(losses), losses.shape[0] + loss = total_loss + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios) + aux = { + "intermediate_outputs": intermediate_outputs, + "xent_sum": 0.0, # DPO has no per-token cross-entropy sum; set to 0 for train_step compatibility + "dpo_loss": total_loss, # pure preference loss before MoE lb, analogous to lm_loss in pre-training + "total_weights": total_weights, + "moe_lb_loss": moe_lb_loss, + "reward_accuracy": reward_accuracy, + "indexer_loss": 0.0, # for gradient_accumulation aux pytree compatibility + "mtp_loss": 0.0, # for gradient_accumulation aux pytree compatibility + } + return loss, aux diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index da14bc6172..60fcfd0f12 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -61,7 +61,7 @@ from maxtext.common.gcloud_stub import vertex_tensorboard_modules from maxtext.common import metric_logger from maxtext.common.metric_logger import record_activation_metrics -from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn +from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn, dpo_loss_fn_nnx from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging @@ -323,15 +323,15 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat params = state.params ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args else: - if config.use_dpo: - raise NotImplementedError( - "DPO is not yet supported for NNX modules. DPO requires a reference model " - "stored alongside the policy model (Linen path uses state.params['reference_params']); " - "the NNX TrainState equivalent has not been wired up. As a workaround, set " - "pure_nnx=False for DPO runs." - ) state = nnx.merge(model, state) # reconstruct TrainStateNNX - ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] + if config.use_dpo: + # NNX DPO: reference_model is a sibling field on TrainStateNNX (set up by + # init_initial_state when config.use_dpo=True). dpo_loss_fn_nnx mirrors + # the Linen dpo_loss_fn signature, so it slots into the same dispatcher + # with reference_model passed as the single extra_dpo_args entry. + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = (dpo_loss_fn_nnx, state.model, None, None, [state.reference_model]) + else: + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] # --- Gradient computation --- if config.gradient_accumulation_steps > 1: @@ -397,9 +397,14 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ) nnx.update(state.model, curr_params) + # `ga_fn` and `ga_dpo` were set up earlier (loss_fn vs dpo_loss_fn_nnx; + # ga_dpo carries the frozen reference_model when use_dpo, else empty). 
+ _nnx_loss_fn = ga_fn + _nnx_extra_dpo_args = ga_dpo + def diff_wrapper(param, rest, config, data): local_model = nnx.merge(model_graphdef, param, rest, copy=True) - loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + loss, aux = _nnx_loss_fn(local_model, config, data, None, None, *_nnx_extra_dpo_args, is_train=True) _, _, new_rest = nnx.split(local_model, nnx.Param, ...) return loss, (aux, new_rest) @@ -587,7 +592,10 @@ def eval_step(model, config, state, data, dropout_rng=None): loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) else: state = nnx.merge(model, state) # reconstruct TrainStateNNX - loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) + if config.use_dpo: + loss, aux = dpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + else: + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -714,7 +722,7 @@ def train_loop(config, recorder, state=None): step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): @@ -758,7 +766,7 @@ def train_loop(config, recorder, state=None): metric_logger_instance.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + state_to_save = state if not (config.use_dpo and not config.pure_nnx) else _split_dpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) if checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index ca90550630..80229b05be 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -225,10 +225,16 @@ def setup_train_loop(config, recorder, devices=None): if config.pure_nnx: # For NNX, the train state is wrapped in the TrainStateNNX module. + # When DPO is enabled, also materialize a frozen reference model alongside + # the policy. Both are constructed by `_create_model_partial()` (which uses + # `config.init_weights_seed`), so the reference starts identical to the + # policy — standard DPO practice. The reference is later overwritten by + # the step-0 checkpoint in `setup_post_setup_state` below. 
def create_train_state_fn(): model = _create_model_partial() optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) - return train_state_nnx.TrainStateNNX(model, optimizer) + reference_model = _create_model_partial() if config.use_dpo else None + return train_state_nnx.TrainStateNNX(model, optimizer, reference_model=reference_model) init_state_fn = create_train_state_fn else: @@ -316,8 +322,6 @@ def create_train_state_fn(): maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: - if config.pure_nnx: - raise NotImplementedError("DPO is not supported yet by NNX models.") abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" @@ -342,9 +346,17 @@ def create_train_state_fn(): except FileNotFoundError: step0_restored = None if step0_restored is not None: - # TODO: For pure_nnx, the dpo state manipulation is different. - reference_params = step0_restored["items"].params["params"] - state = _merge_dpo_state(state, reference_params) + if config.pure_nnx: + # step0_restored["items"] is the flat nnx.State of the step-0 TrainStateNNX + # (typically from a non-DPO pre-training run, so its top-level fields are + # `model` and `optimizer` — no `reference_model`). Copy its `model` substate + # into our current state's `reference_model` slot. + step0_state = step0_restored["items"] + step0_model_substate = step0_state["model"] if "model" in step0_state else step0_state + state["reference_model"] = step0_model_substate + else: + reference_params = step0_restored["items"].params["params"] + state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py index d11f9658a7..05a7fcffec 100644 --- a/tests/integration/setup_train_loop_nnx_test.py +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -126,15 +126,6 @@ def test_pure_nnx_setup_param_only_split_matches_model(self): del model - def test_pure_nnx_dpo_raises_not_implemented(self): - """The use_dpo branch (train_utils.py:319-320) must raise for NNX.""" - # use_dpo requires a few prerequisites; the simplest is to set the flag and - # let setup_train_loop reach the NotImplementedError check before the more - # involved DPO path runs. - config = _tiny_nnx_pyconfig(use_dpo=True, packing=False) - with self.assertRaises(NotImplementedError): - setup_train_loop(config, recorder=None) - if __name__ == "__main__": unittest.main() diff --git a/tests/unit/dpo_nnx_test.py b/tests/unit/dpo_nnx_test.py new file mode 100644 index 0000000000..461c3cb2aa --- /dev/null +++ b/tests/unit/dpo_nnx_test.py @@ -0,0 +1,215 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""NNX DPO unit tests. + +Covers the NNX-native DPO surface: + * `TrainStateNNX(model, optimizer, reference_model=...)` — reference model + sits alongside policy and is not touched by `apply_gradients`. + * `dpo_loss_fn_nnx(policy, config, data, None, None, reference, is_train)` — + aux structure, identical-model invariant (loss = log(2), reward_accuracy = 0.5). +""" + +import math +import types +import unittest + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx +from maxtext.trainers.post_train.dpo import dpo_utils + + +class _MockTransformer(nnx.Module): + """Tiny NNX transformer-shaped module for DPO tests. + + Accepts the same keyword args that `dpo_loss_fn_nnx` passes: + `decoder_input_tokens`, `decoder_positions`, `decoder_segment_ids`, + `enable_dropout`. Other args are tolerated via **kwargs. + """ + + def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs) + self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions=None, + decoder_segment_ids=None, + enable_dropout=False, + **kwargs, + ): + del decoder_positions, decoder_segment_ids, enable_dropout, kwargs + return self.proj(self.embed(decoder_input_tokens)) + + +def _make_dpo_config(**overrides): + """Build the minimal config surface that `dpo_loss_fn_nnx` reads.""" + base = { + "dpo_label_smoothing": 0.0, + "dpo_beta": 0.1, + "enable_dropout": False, + "num_experts": 1, + "micro_batch_size_to_train_on": 2, + } + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _make_dpo_batch(batch_size=2, seq_len=5): + """Build a tiny DPO-shaped batch. + + `chosen` and `rejected` share the first 2 tokens (common prefix is masked + out in the loss), differ at positions 2 and 3, and are padded at position 4. 
+ """ + chosen = jnp.array([[1, 2, 3, 4, 0]] * batch_size, dtype=jnp.int32) + rejected = jnp.array([[1, 2, 5, 6, 0]] * batch_size, dtype=jnp.int32) + positions = jnp.tile(jnp.arange(seq_len, dtype=jnp.int32), (batch_size, 1)) + segmentation = jnp.array([[1, 1, 1, 1, 0]] * batch_size, dtype=jnp.int32) + return { + "chosen": chosen, + "rejected": rejected, + "chosen_position": positions, + "rejected_position": positions, + "chosen_segmentation": segmentation, + "rejected_segmentation": segmentation, + } + + +class TestTrainStateNNXWithReferenceModel(unittest.TestCase): + """`TrainStateNNX(reference_model=...)` semantics.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(1)) + self.tx = optax.adam(1e-3) + + def test_init_with_reference(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + self.assertIs(state.model, self.policy) + self.assertIs(state.reference_model, self.reference) + self.assertEqual(state.optimizer.step.value, 0) + + def test_init_without_reference_omits_attribute(self): + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer) + self.assertFalse(hasattr(state, "reference_model")) + + def test_apply_gradients_does_not_touch_reference(self): + """Gradient update on policy must leave reference model bit-identical.""" + optimizer = nnx.Optimizer(self.policy, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.policy, optimizer, reference_model=self.reference) + + ref_kernel_before = jnp.asarray(state.reference_model.proj.kernel.value).copy() + + def policy_loss(m): + return jnp.mean(m(jnp.array([[1, 2]])) ** 2) + + grads = nnx.grad(policy_loss)(state.model) + state.apply_gradients(grads) + + ref_kernel_after = jnp.asarray(state.reference_model.proj.kernel.value) + self.assertTrue(jnp.array_equal(ref_kernel_before, ref_kernel_after)) + + +class TestDPOLossFnNNX(unittest.TestCase): + """`dpo_loss_fn_nnx` numerical and structural sanity checks.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + # Reference initialized with the same seed to make policy and reference + # bit-identical at construction time. + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.config = _make_dpo_config() + self.data = _make_dpo_batch() + + def test_aux_has_expected_keys(self): + _, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + expected_keys = { + "intermediate_outputs", + "xent_sum", + "dpo_loss", + "total_weights", + "moe_lb_loss", + "reward_accuracy", + "indexer_loss", + "mtp_loss", + } + self.assertEqual(set(aux.keys()), expected_keys) + self.assertEqual(aux["xent_sum"], 0.0) + self.assertEqual(aux["moe_lb_loss"], 0.0) # num_experts=1 + self.assertEqual(aux["total_weights"], self.data["chosen"].shape[0]) + + def test_identical_policy_and_reference_yields_log2_loss(self): + """When policy == reference, all logratios are 0; with label_smoothing=0 + the per-example loss is `-log(sigmoid(0)) = log(2)`. `reward_accuracy` + uses strict `chosen > rejected`, so equal logratios score 0.0 (no example + is strictly preferred). 
+ """ + loss, aux = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + self.assertAlmostEqual(float(loss), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["dpo_loss"]), math.log(2.0), places=4) + self.assertAlmostEqual(float(aux["reward_accuracy"]), 0.0, places=4) + + def test_dropout_rng_and_params_args_are_unused(self): + """The 4th and 5th positional args are signature-compat slots for the + Linen dispatcher; passing arbitrary values must not affect the result. + """ + loss_a, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + loss_b, _ = dpo_utils.dpo_loss_fn_nnx( + self.policy, + self.config, + dict(self.data), + jax.random.PRNGKey(123), # dropout_rng — unused + {"params": "garbage"}, # params — unused + self.reference, + is_train=True, + ) + self.assertAlmostEqual(float(loss_a), float(loss_b), places=6) + + def test_value_and_grad_argnums0_only_diffs_policy(self): + """`nnx.value_and_grad(..., argnums=0)` over the policy should produce + finite grads on policy params and not require reference grads. + """ + + def _loss(policy_module): + loss, _ = dpo_utils.dpo_loss_fn_nnx( + policy_module, self.config, dict(self.data), None, None, self.reference, is_train=True + ) + return loss + + grad_fn = nnx.value_and_grad(_loss, argnums=0) + loss, grads = grad_fn(self.policy) + self.assertTrue(jnp.isfinite(loss)) + # Grads is an nnx.State of the policy's nnx.Param leaves; check at least one + # leaf is finite and non-trivially shaped. + leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py index f532820f86..4340d4e22a 100644 --- a/tests/unit/train_nnx_test.py +++ b/tests/unit/train_nnx_test.py @@ -174,16 +174,6 @@ def test_train_step_returns_state_and_metrics(self): self.assertIn("learning/param_norm", metrics["scalar"]) self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) - def test_train_step_dpo_raises_for_nnx(self): - cfg, ts = _build_state() - cfg.use_dpo = True - state_graphdef, state_pure = nnx.split(ts) - data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) - with self.assertRaises(NotImplementedError): - pre_train.train_step( - state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data - ) - def test_train_step_increments_optimizer_step(self): cfg, ts = _build_state() state_graphdef, state_pure = nnx.split(ts) From cfc5fc4bbb3dc79c7bbacfafc77a2b0b38dd8577 Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Tue, 5 May 2026 20:53:02 +0000 Subject: [PATCH 3/4] NNX: native MaxEngine inference (drop route-to-Linen path in maxengine.py) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR5 audited maxengine.py and routed the inference path to the Linen implementation regardless of pure_nnx, with a comment block explaining that "the flag affects training, not inference serving." That kept the Linen serving path unchanged but meant pure_nnx=True users silently got the Linen engine. 
This change replaces the route with a real NNX flow: when config.pure_nnx=True, the engine builds an NNX Transformer, splits out (params, cache, rest) with nnx.split, and at every JIT body merges the model concretely with nnx.merge to run the forward pass. Linen is preserved byte-for-byte; every NNX edit is gated `if config.pure_nnx:` and pure_nnx=False is still the default. maxengine.py (__init__): - Build two abstract NNX Transformers on the NNX path: self.model with model_mode=PREFILL (batch=1, single padded prompt) and self.model_ar with model_mode=AUTOREGRESSIVE (batch=micro_batch_size_to_train_on, decode_state shape). Both are needed because NNX cache vars inherit CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and bulk_insert searches for the substring "cache_batch" in the AR-mode logical-axes tuple. nnx.eval_shape is called directly inside nn_partitioning.axis_rules rather than through create_nnx_abstract_model to avoid the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm" (same reason get_abstract_state_nnx avoids set_mesh). - Cache the graphdef from a 3-way nnx.split(Param, Cache, ...) so JIT bodies can pass (params, cache, rest) separately to nnx.merge. The rest slot (RNG vars etc.) is materialized concretely in load_params. maxengine.py (cache adapter + _nnx_run_model): - bulk_insert / _insert_jit / _maybe_*_prefill_result_cache walk the cache via tree_map_with_path and switch on path[-1].key (the cache variable name like "cached_prefill_key"). Linen mutable cache is a plain nested dict. NNX Cache state would expose a ".value" accessor at that position. Bridge via nnx.State.to_pure_dict() (after the model run) and nnx.replace_by_pure_dict (before nnx.merge), so the cache plumbing helpers see the same shape on both paths. - Add _nnx_run_model: nnx.merge(graphdef, params, cache, rest, copy=True) -> model(...) -> nnx.state(model, nnx.Cache).to_pure_dict(). copy=True avoids reusing Variable objects across traces (TraceContextError), mirroring train.py's diff_wrapper workaround. - Add _nnx_cache_state_template / _nnx_init_cache_dict helpers parametrised by mode so prefill (batch 1) and decode_state (batch N) pull from the right abstract model. maxengine.py (load_params): - New _load_params_nnx: accepts user-provided NNX-shape params or loads via from_pretrained. For user-provided params, materializes a concrete model once via _create_model_fn() to capture a real rest state for nnx.merge (wasteful but simple; the from_pretrained branch avoids this). Refreshes self.graphdef from the concrete model so subsequent merges line up exactly. - Builds self.abstract_params, populates self.prefill_kv_cache_annotations and self.kv_cache_annotations (using model_ar for the latter so bulk_insert's substring lookup hits), wraps both into NamedSharding. - pure_nnx + quantization, pure_nnx + LoRA, pure_nnx + stack_prefill_result_cache=True, pure_nnx + prefill_multisampling, and pure_nnx + prefill_concat raise NotImplementedError for now; the Linen path is the workaround. AOT compilation (aot_compile / _compile_generate_and_get_layouts) is not gated and may work as-is; not exercised by tests yet. maxengine.py (init_decode_state, _prefill_jit, _generate_jit): - _init_decode_state_nnx zero-initializes a pure-dict cache from model_ar (so the leading batch dim matches generate's input shape) and builds kv_cache_annotations_named per leaf by reading nnx.Cache.metadata. Tries "out_sharding", "sharding", and "sharding_names" because Flax 0.12.6 renamed these. 
- _prefill_jit / _generate_jit add an `if config.pure_nnx:` branch that calls _nnx_run_model in place of self.model.apply with mutable=["cache"]. existing_prefix.cache is threaded as a pure-dict cache directly (no params|{"cache":...} dict-merge — params is an nnx.State, not a dict). maxtext_utils.py: - New get_prefill_kv_cache_annotations_nnx / get_kv_cache_annotations_nnx that mirror the Linen helpers' return shape (per-leaf PartitionSpec tree). Both delegate to _nnx_cache_partition_specs which extracts nnx.Cache state via nnx.split, calls get_nnx_named_sharding_with_scan_axis inside nn_partitioning.axis_rules so logical axes ("layers", "cache_batch", "norm", ...) resolve to physical mesh axes, and converts the result to a pure-dict tree. tests/unit/maxengine_test.py: - New tests: test_init_nnx, test_basic_prefill_nnx (with NaN/inf and per-layer cache shape checks), test_basic_decode_nnx (4-step generate with next_pos advancement check), test_quantize_raises_for_nnx, test_lora_raises_for_nnx. - New test_linen_nnx_parity_prefill: bridges Linen-init params into the NNX engine via linen_nnx_converter (convert_linen_to_nnx -> _strip_value_wrappers -> nnx.replace_by_pure_dict) and asserts the NNX engine's prefill matches Linen on the same weights — logits within bf16 tolerance (rtol=0.05, atol=0.1; the test config uses bf16 compute) and exact greedy first-token argmax. - Existing Linen tests untouched. Test summary: 9 passed, 1 skipped (test_chunked_prefill is a pre-existing CPU-only skip). bash lint.sh: codespell + pylint + pyink all green. --- src/maxtext/inference/maxengine/maxengine.py | 317 +++++++++++++++++-- src/maxtext/utils/maxtext_utils.py | 24 ++ tests/integration/maxengine_test.py | 170 ++++++++++ 3 files changed, 478 insertions(+), 33 deletions(-) diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index c00f475e8d..5bd220f4e1 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -32,6 +32,7 @@ from jax.experimental.layout import DeviceLocalLayout as DLL # type: ignore from flax import linen as nn +from flax import nnx from flax import struct from flax.linen import partitioning as nn_partitioning import flax @@ -44,8 +45,10 @@ from maxtext.inference.page_manager import PageManager, PageState from maxtext.multimodal import processor as mm_processor from maxtext.utils import lora_utils +from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from maxtext.common.gcloud_stub import jetstream, is_decoupled from maxtext.common.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE @@ -112,11 +115,32 @@ def __init__(self, config: Any, devices: Any | None = None): self._mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) # Model and Optimizer definition. - # MaxEngine uses Linen-shaped state (state.params, state_mesh_shardings.params, - # state.opt_state) and serves Linen-format inference checkpoints. Use Linen path - # regardless of pure_nnx — the flag affects training, not inference serving. 
quant = quantizations.configure_quantization(config) - self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + if config.pure_nnx: + # We need both PREFILL and AR abstract models because the cache vars inherit + # CACHE_BATCH_PREFILL vs CACHE_BATCH from the construction model_mode, and + # bulk_insert searches for the substring "cache_batch" in the AR-mode names. + # Calling nnx.eval_shape directly (instead of create_nnx_abstract_model) avoids + # the jax.set_mesh wrap that trips Flax 0.12.6 on logical-only axes like "norm". + _create_model = model_creation_utils.get_nnx_create_model_fn(config, mesh=self._mesh, model_mode=MODEL_MODE_PREFILL) + _create_model_ar = model_creation_utils.get_nnx_create_model_fn( + config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE + ) + with nn_partitioning.axis_rules(config.logical_axis_rules): + abstract_model = nnx.eval_shape(_create_model) + abstract_model_ar = nnx.eval_shape(_create_model_ar) + self.model = abstract_model + self.model_ar = abstract_model_ar + # 3-way split so JIT bodies can pass (params, cache, rest) separately to + # nnx.merge. `rest` (RNG state etc.) is materialized in load_params. + graphdef, _, _, _ = nnx.split(abstract_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._create_model_fn = _create_model + self._nnx_rest_state = None + else: + self.model = models.transformer_as_linen(config, mesh=self._mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + self.graphdef = None + self._create_model_fn = None self.replicated_sharding = jax.sharding.NamedSharding(self._mesh, P(None)) self.abstract_params = None @@ -142,6 +166,65 @@ def print_stats(self, label: str): max_utils.print_mem_stats(label) max_utils.print_cpu_ram_stats(label) + # NNX cache adapter: bulk_insert / _insert_jit / _maybe_stack_* switch on + # path[-1].key (e.g. "cached_prefill_key"). NNX state would expose ".value" at + # that position, so we convert NNX state <-> plain dict at the JIT boundary + # via to_pure_dict / replace_by_pure_dict. The cache helpers stay unchanged. + + def _nnx_cache_state_template(self, mode: str = MODEL_MODE_PREFILL) -> Any: + """Empty nnx.State template for the model's nnx.Cache vars (PREFILL=batch 1, AR=batch N).""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + return cache_state + + def _nnx_init_cache_dict(self, mode: str = MODEL_MODE_PREFILL) -> dict: + """Zero-filled pure-dict cache matching the abstract NNX model.""" + src = self.model if mode == MODEL_MODE_PREFILL else self.model_ar + _, cache_state, _ = nnx.split(src, nnx.Cache, ...) + cache_dict = cache_state.to_pure_dict() + return jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict) + + def _nnx_run_model( + self, + params, + cache_dict, + decoder_input_tokens, + decoder_positions, + *, + decoder_segment_ids=None, + enable_dropout=False, + model_mode, + previous_chunk=None, + true_length=None, + slot=None, + page_state=None, + encoder_images=None, + encoder_image_masks=None, + encoder_audios=None, + ): + """NNX equivalent of `model.apply(..., mutable=["cache"])`. Returns (logits, new_cache_dict).""" + cache_state = self._nnx_cache_state_template(mode=model_mode) + nnx.replace_by_pure_dict(cache_state, cache_dict) + # copy=True avoids reusing Variable objects across traces (TraceContextError), + # mirroring the workaround in train.py's diff_wrapper. 
+ model = nnx.merge(self.graphdef, params, cache_state, self._nnx_rest_state, copy=True) + logits = model( + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + encoder_images=encoder_images, + encoder_image_masks=encoder_image_masks, + encoder_audios=encoder_audios, + enable_dropout=enable_dropout, + model_mode=model_mode, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_cache = nnx.state(model, nnx.Cache).to_pure_dict() + return logits, new_cache + def generate_aot( self, params: Params, decode_state: DecodeState, rng: PRNGKeyType | None = None ): # returns (new_decode_state, result_tokens) @@ -225,6 +308,9 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + return self._load_params_nnx(params=params, rng=rng) + if self.model.quant and self.config.checkpoint_is_quantized: print("Loading from the quantized checkpoint...") self.model.quant.quant_mode = quantizations.get_quant_mode("serve") @@ -284,11 +370,80 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar return params + def _load_params_nnx(self, params, rng): + """NNX equivalent of load_params: returns an nnx.Param state and populates KV cache shardings.""" + if self.model.quant is not None: + raise NotImplementedError("pure_nnx + quantization not yet supported. Use pure_nnx=False.") + + if params: + print("Resharding given NNX params") + _, params_abs, _ = nnx.split(self.model, nnx.Param, ...) + target_shardings = jax.tree.map( + lambda x: x.sharding if hasattr(x, "sharding") else None, + params_abs, + is_leaf=lambda x: isinstance(x, nnx.Variable), + ) + params_state = jax.device_put(params, target_shardings) + # Build a concrete model once to capture a real `rest` (RNG vars) for nnx.merge. + # Wasteful but simple — the from_pretrained branch below avoids this. + with nn_partitioning.axis_rules(self.config.logical_axis_rules): + concrete_model = self._create_model_fn() + graphdef, _, _, rest_state = nnx.split(concrete_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del concrete_model + else: + max_logging.log("Loading NNX params via from_pretrained") + with self._mesh: + nnx_model = model_creation_utils.from_pretrained( + self.config, mesh=self._mesh, model_mode=MODEL_MODE_AUTOREGRESSIVE + ) + # Refresh graphdef from the concrete loaded model so subsequent merges line up. + graphdef, params_state, _, rest_state = nnx.split(nnx_model, nnx.Param, nnx.Cache, ...) + self.graphdef = graphdef + self._nnx_rest_state = rest_state + del nnx_model + + self.abstract_params = jax.tree.map( + lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding) + if isinstance(x, jax.Array) + else None, + params_state, + ) + + self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations_nnx( + self.model, self.config, self._mesh + ) + self.prefill_kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.prefill_kv_cache_annotations, + ) + if self.config.stack_prefill_result_cache: + # With scan_layers=True the NNX cache leaves are already stacked on axis 0, + # so the engine's manual-stack helper (which assumes an unstacked Linen tree) + # doesn't apply. Wiring this up cleanly is a Phase-2 follow-up. 
+ raise NotImplementedError("pure_nnx + stack_prefill_result_cache=True not yet supported.") + # AR-mode abstract model so axis names use CACHE_BATCH (not CACHE_BATCH_PREFILL); + # bulk_insert / _insert_jit search for "cache_batch" in the per-leaf logical axes. + self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations_nnx(self.model_ar, self.config, self._mesh) + self.kv_cache_shardings = jax.tree.map( + lambda x: jax.sharding.NamedSharding(self._mesh, x), + self.kv_cache_annotations, + ) + # state_mesh_annotations is unused on the NNX path; callers reading it + # (e.g. set_engine_vars_from_base_engine) need to be NNX-aware first. + self.state_mesh_annotations = None + + self.print_stats("After load_params (NNX)") + return params_state + def load_single_adapter(self, adapter_path): """ Load Single adapter from adapter_path. Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`. """ + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + LoRA not yet supported. Use pure_nnx=False.") adapter_config_path = os.path.join(adapter_path, "adapter_config.json") adapter_weights_path = os.path.join(adapter_path, "0", "items") @@ -324,6 +479,8 @@ def quantize_params(self, state, rng: PRNGKeyType | None = None): """Forward pass to quantize decode params.""" if rng is None: rng = jax.random.PRNGKey(0) + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + quantize_params not yet supported.") self.model.quant.quant_mode = quantizations.get_quant_mode("convert") @@ -478,7 +635,10 @@ def _prefill_jit( if existing_prefix is not None: if not self.use_chunked_prefill: raise ValueError("Using chunked prefill is needed for existing_prefix.") - input_params = params | {"cache": existing_prefix.cache} + # NNX threads existing_prefix.cache via the nnx_cache local below; only + # the Linen path merges cache into input_params (params is a dict there). + if not self.config.pure_nnx: + input_params = params | {"cache": existing_prefix.cache} start_position = existing_prefix.common_prefix_tokens.shape[0] # TODO(yuyanpeng): rename previous_chunk previous_chunk = jnp.expand_dims(existing_prefix.common_prefix_tokens, 0) @@ -510,24 +670,48 @@ def _prefill_jit( sequence_indicator = jnp.expand_dims(one_d_output, 0) rng, new_rng = jax.random.split(rng) - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - flat_logits, new_vars = self.model.apply( - input_params, - input_tokens, - positions, - encoder_images=images, - encoder_image_masks=image_masks, - encoder_audios=audio_values, - decoder_segment_ids=sequence_indicator, - enable_dropout=False, - model_mode=MODEL_MODE_PREFILL, - rngs={"params": new_rng}, - mutable=["cache"], - previous_chunk=previous_chunk, - true_length=true_length, - slot=slot, - page_state=page_state, + if self.config.pure_nnx: + # Prefill always operates on batch=1 (one padded prompt at a time). 
+ nnx_cache = ( + existing_prefix.cache if existing_prefix is not None else self._nnx_init_cache_dict(mode=MODEL_MODE_PREFILL) ) + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_cache_dict = self._nnx_run_model( + params=input_params, + cache_dict=nnx_cache, + decoder_input_tokens=input_tokens, + decoder_positions=positions, + decoder_segment_ids=sequence_indicator, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + flat_logits, new_vars = self.model.apply( + input_params, + input_tokens, + positions, + encoder_images=images, + encoder_image_masks=image_masks, + encoder_audios=audio_values, + decoder_segment_ids=sequence_indicator, + enable_dropout=False, + model_mode=MODEL_MODE_PREFILL, + rngs={"params": new_rng}, + mutable=["cache"], + previous_chunk=previous_chunk, + true_length=true_length, + slot=slot, + page_state=page_state, + ) if return_prompt_logp: prompt_logp = inference_utils.prompt_logprobs_from_prefill(flat_logits, input_tokens, true_length) else: @@ -736,6 +920,9 @@ def _prefill_multisampling_jit( prefilling stage. The number of tokens is specified by num_samples. """ + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_multisampling not yet supported. Use pure_nnx=False.") + input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] positions = jnp.expand_dims(jnp.arange(0, input_tokens.shape[1]), 0) @@ -861,6 +1048,9 @@ def prefill_concat( if existing_prefix: raise ValueError("We don't know what to do with existing_prefix") + if self.config.pure_nnx: + raise NotImplementedError("pure_nnx + prefill_concat not yet supported. 
Use pure_nnx=False.") + if rng is None: rng = jax.random.PRNGKey(0) input_tokens = jnp.expand_dims(padded_tokens, 0) # [BATCH, SEQUENCE] @@ -1030,17 +1220,30 @@ def _generate_jit( previous_token = decode_state["tokens"] rng, new_rng = jax.random.split(rng) # run one step generation - with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - out_logits, new_vars = self.model.apply( - params | {"cache": decode_state["cache"]}, - previous_token, - decode_state["next_pos"], - enable_dropout=False, - model_mode=MODEL_MODE_AUTOREGRESSIVE, - rngs={"params": new_rng}, - mutable=["cache"], - page_state=page_state, - ) + if self.config.pure_nnx: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_cache_dict = self._nnx_run_model( + params=params, + cache_dict=decode_state["cache"], + decoder_input_tokens=previous_token, + decoder_positions=decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + page_state=page_state, + ) + new_vars = {"cache": new_cache_dict} + else: + with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + out_logits, new_vars = self.model.apply( + params | {"cache": decode_state["cache"]}, + previous_token, + decode_state["next_pos"], + enable_dropout=False, + model_mode=MODEL_MODE_AUTOREGRESSIVE, + rngs={"params": new_rng}, + mutable=["cache"], + page_state=page_state, + ) out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding) new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings) # sampling tokens @@ -1598,6 +1801,9 @@ def init_decode_state( if self.config.attention == "paged" and self.page_manager is not None: page_state = self.page_manager.get_initial_page_state() # pytype: disable=attribute-error + if self.config.pure_nnx: + return self._init_decode_state_nnx(rng=rng, page_state=page_state) + # pylint: disable=unused-argument def init(abstract_params, page_state): x = jnp.ones( @@ -1691,6 +1897,51 @@ def is_lp(k): zeroed = max_utils.unbox_logicallypartioned(init_state) return zeroed + def _init_decode_state_nnx(self, rng, page_state) -> DecodeState: + """NNX equivalent of init_decode_state. Returns a decode_state dict with a pure-dict cache.""" + del rng, page_state # cache shape comes from the abstract model + batch = int(self.config.per_device_batch_size * self.mesh.size) + vocab = self.config.vocab_size + + # AR-mode cache so the batch dim matches generate's input shape. + cache_dict_abs = self._nnx_init_cache_dict(mode=MODEL_MODE_AUTOREGRESSIVE) + + @functools.partial(jax.jit, out_shardings=(self.kv_cache_shardings,)) + def _init_cache(): + return (jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), cache_dict_abs),) + + (cache,) = _init_cache() + + # Per-leaf logical axes for bulk_insert's "cache_batch" lookup. Use model_ar + # so segment_id leaves carry CACHE_BATCH (under PREFILL they'd carry + # CACHE_BATCH_PREFILL, which doesn't contain the "cache_batch" substring). + _, cache_state, _ = nnx.split(self.model_ar, nnx.Cache, ...) + + def _logical_axes_for(var): + # Flax 0.12.6 renamed "sharding" to "out_sharding"; older code may still + # use "sharding_names". Try all three. 
+ meta = var.get_metadata() if hasattr(var, "get_metadata") else {} + out = meta.get("out_sharding") or meta.get("sharding") or meta.get("sharding_names") + if out is None: + return () + return (out,) if isinstance(out, str) else tuple(out) + + annotations_state = jax.tree.map( + _logical_axes_for, + cache_state, + is_leaf=lambda v: isinstance(v, nnx.Variable), + ) + self.kv_cache_annotations_named = annotations_state.to_pure_dict() + + return { + "logits": jnp.zeros((batch, 1, vocab), dtype=jnp.float32), + "cache": cache, + "next_pos": jnp.zeros((batch, 1), dtype=jnp.int32), + "generated_tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "tokens": jnp.zeros((batch, 1), dtype=jnp.int32), + "token_logp": jnp.zeros((batch, 1), dtype=jnp.float32), + } + @property def max_concurrent_decodes(self) -> int: """Free slots.""" diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 0f07f5c14d..6a182cd869 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1724,6 +1724,30 @@ def init_kv_cache(model, config): return state_mesh_annotations +def _nnx_cache_partition_specs(abstract_model, config, mesh): + """Per-leaf PartitionSpec tree for the abstract model's nnx.Cache vars. + + Returned as a pure dict so the engine can wrap it in NamedSharding the same + way it does for the Linen helpers below. + """ + _, cache_state, _ = nnx.split(abstract_model, nnx.Cache, ...) + # get_nnx_named_sharding_with_scan_axis reads logical axis rules from the + # active flax partitioning context, so wrap. + with nn_partitioning.axis_rules(config.logical_axis_rules): + named_state = get_nnx_named_sharding_with_scan_axis(cache_state, mesh) + return jax.tree.map(lambda s: s.spec, named_state.to_pure_dict()) + + +def get_prefill_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_prefill_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + +def get_kv_cache_annotations_nnx(abstract_model, config, mesh): + """NNX equivalent of get_kv_cache_annotations.""" + return _nnx_cache_partition_specs(abstract_model, config, mesh) + + def save_quantized_checkpoint_if_configured(config, params): """Save quantized checkpoint if configured""" assert config.quantization, "quantization must be configured" diff --git a/tests/integration/maxengine_test.py b/tests/integration/maxengine_test.py index eb4a7729d6..42730c2d76 100644 --- a/tests/integration/maxengine_test.py +++ b/tests/integration/maxengine_test.py @@ -23,6 +23,8 @@ from jax.sharding import Mesh import numpy as np import pytest +from flax import nnx +from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL from maxtext.layers import quantizations @@ -30,7 +32,10 @@ pytest.importorskip("jetstream", reason="jetstream not installed") from maxtext.inference.maxengine import maxengine from maxtext.models import models +from maxtext.checkpoint_conversion import linen_nnx_converter +from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import model_creation_utils from tests.utils.test_helpers import get_test_config_path pytestmark = [pytest.mark.external_serving] @@ -163,6 +168,171 @@ def test_basic_decode(self): self.assertEqual(result_token.data.ndim, 2) self.assertEqual(result_token.data.shape[1], 3) + def _init_nnx_pyconfig(self, **kwargs): + """init_pyconfig with NNX 
flags on.""" + return self.init_pyconfig(pure_nnx=True, enable_nnx=True, pure_nnx_decoder=True, **kwargs) + + def _build_nnx_params(self, cfg, mesh): + """Materialize an NNX Transformer and return its nnx.Param state.""" + _create_model = model_creation_utils.get_nnx_create_model_fn(cfg, mesh=mesh, model_mode=MODEL_MODE_PREFILL) + with nn_partitioning.axis_rules(cfg.logical_axis_rules): + model = _create_model() + _, params_state, _ = nnx.split(model, nnx.Param, ...) + return params_state + + def test_init_nnx(self): + """NNX engine init exposes graphdef + abstract Transformer.""" + cfg = self._init_nnx_pyconfig() + engine = maxengine.MaxEngine(cfg, jax.devices()) + self.assertIsNotNone(engine.graphdef) + self.assertIsNotNone(engine.model) + self.assertEqual(type(engine.model).__name__, "Transformer") + + def test_basic_prefill_nnx(self): + """NNX prefill returns a Linen-shape result dict with finite values.""" + cfg = self._init_nnx_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + params_state = self._build_nnx_params(cfg, mesh) + + input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0]) + true_length = 4 + engine = maxengine.MaxEngine(cfg, jax.devices()) + params = engine.load_params(params=params_state) + prefill_result, first_token = engine.prefill(params=params, padded_tokens=input_tokens, true_length=true_length) + + self.assertEqual(prefill_result["generated_tokens"], jnp.array([0])) + self.assertEqual(prefill_result["tokens"].size, 1) + self.assertTrue(jnp.array_equal(first_token.data.size, 3)) + self.assertEqual(first_token.log_prob.shape, (1, 1)) + self.assertIn("cache", prefill_result) + self.assertIsInstance(prefill_result["cache"], dict) + # Catch silent NaN/inf from a bad nnx.merge or cache round-trip. + self.assertTrue(jnp.all(jnp.isfinite(prefill_result["logits"]))) + cache_leaves, _ = jax.tree.flatten(prefill_result["cache"]) + for leaf in cache_leaves: + self.assertTrue(jnp.all(jnp.isfinite(leaf)), msg=f"non-finite cache leaf, shape={leaf.shape}") + # scan_layers=True (default in test config) ⇒ leading axis is num_decoder_layers. + for leaf in cache_leaves: + self.assertEqual(leaf.shape[0], cfg.num_decoder_layers, msg=f"layer-axis mismatch, got shape={leaf.shape}") + + def test_basic_decode_nnx(self): + """NNX prefill → insert → 4 generate steps. Verifies next_pos advances and logits stay finite.""" + cfg = self._init_nnx_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg) + mesh = Mesh(devices_array, cfg.mesh_axes) + params_state = self._build_nnx_params(cfg, mesh) + + input_tokens = jnp.array([1, 306, 5360, 304]) + engine = maxengine.MaxEngine(cfg, jax.devices()) + params = engine.load_params(params=params_state) + decode_state = engine.init_decode_state() + prefill_result, _ = engine.prefill(params=params, padded_tokens=input_tokens, true_length=4) + decode_state = engine.insert(prefill_result, decode_state, slot=0) + + # 4 steps is enough to catch off-by-one cache pointer bugs. 
+ initial_next_pos = int(decode_state["next_pos"][0, 0]) + for step in range(4): + decode_state, result_token = engine.generate(params=params, decode_state=decode_state) + self.assertEqual(result_token.log_prob.ndim, 2) + self.assertEqual(result_token.log_prob.shape[1], 1) + self.assertEqual(result_token.data.ndim, 2) + self.assertEqual(result_token.data.shape[1], 3) + self.assertTrue(jnp.all(jnp.isfinite(decode_state["logits"]))) + self.assertEqual( + int(decode_state["next_pos"][0, 0]), + initial_next_pos + step + 1, + msg=f"next_pos didn't advance at step {step}", + ) + + def test_quantize_raises_for_nnx(self): + """pure_nnx + quantization raises NotImplementedError.""" + cfg = self._init_nnx_pyconfig(quantization="int8") + engine = maxengine.MaxEngine(cfg, jax.devices()) + with self.assertRaises(NotImplementedError): + engine.load_params(rng=self.rng) + + def test_lora_raises_for_nnx(self): + """pure_nnx + LoRA raises NotImplementedError.""" + cfg = self._init_nnx_pyconfig() + engine = maxengine.MaxEngine(cfg, jax.devices()) + with self.assertRaises(NotImplementedError): + engine.load_single_adapter("/nonexistent/adapter/path") + + def _linen_params_to_nnx_state(self, linen_params, abstract_nnx_model): + """Convert Linen params → NNX nnx.Param state via linen_nnx_converter so both engines share weights.""" + nnx_dict_wrapped = linen_nnx_converter.convert_linen_to_nnx({"params": linen_params}, scan_layers=True)["model"] + # pylint: disable=protected-access + nnx_pure = linen_nnx_converter._strip_value_wrappers(nnx_dict_wrapped) + _, params_state, _ = nnx.split(abstract_nnx_model, nnx.Param, ...) + nnx.replace_by_pure_dict(params_state, nnx_pure) + return params_state + + def test_linen_nnx_parity_prefill(self): + """Same weights → same prefill output across Linen and NNX engines. + + A failure here means the NNX forward pass diverges from Linen on identical + weights (cache plumbing, nnx.merge wiring, or Transformer.__call__). + """ + cfg_linen = self.init_pyconfig() + devices_array = maxtext_utils.create_device_mesh(cfg_linen) + mesh = Mesh(devices_array, cfg_linen.mesh_axes) + + # Linen: init params, run prefill. + quant = quantizations.configure_quantization(cfg_linen) + linen_model = models.transformer_as_linen(config=cfg_linen, mesh=mesh, quant=quant, model_mode=MODEL_MODE_PREFILL) + ids, decoder_segment_ids, decoder_positions = self.get_data() + linen_vars = linen_model.init( + {"params": self.rng, "aqt": self.rng, "dropout": self.rng}, + ids, + decoder_positions, + decoder_segment_ids, + enable_dropout=False, + ) + # Linen.init wraps leaves in LogicallyPartitioned (which has a `.value` + # attribute); unbox so the converter's {value:} wrapper detector doesn't + # mistake them for already-wrapped NNX leaves. + linen_vars = max_utils.unbox_logicallypartioned(linen_vars) + + input_tokens = jnp.array([1, 306, 5360, 304, 0, 0, 0, 0]) + true_length = 4 + linen_engine = maxengine.MaxEngine(cfg_linen, jax.devices()) + linen_params = linen_engine.load_params(params=linen_vars) + linen_prefill, linen_first_token = linen_engine.prefill( + params=linen_params, padded_tokens=input_tokens, true_length=true_length + ) + + # NNX: bridge Linen weights, run prefill on the same prompt. 
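+    # Build the NNX engine first: its abstract Transformer is the target
+    # structure for the Linen -> NNX weight bridge below.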
+ cfg_nnx = self._init_nnx_pyconfig() + nnx_engine = maxengine.MaxEngine(cfg_nnx, jax.devices()) + nnx_params_state = self._linen_params_to_nnx_state(linen_vars["params"], nnx_engine.model) + nnx_params = nnx_engine.load_params(params=nnx_params_state) + nnx_prefill, nnx_first_token = nnx_engine.prefill( + params=nnx_params, padded_tokens=input_tokens, true_length=true_length + ) + + # Tolerance is loose because the test config uses bf16 compute, where + # accumulation order between Linen-scan and NNX-scan drifts by ~0.05. + # Greedy match below is the behavioral check that actually matters. + linen_logits = np.asarray(linen_prefill["logits"]) + nnx_logits = np.asarray(nnx_prefill["logits"]) + self.assertEqual(linen_logits.shape, nnx_logits.shape) + np.testing.assert_allclose( + linen_logits, + nnx_logits, + rtol=0.05, + atol=0.1, + err_msg="Linen vs NNX prefill logits diverge beyond bf16 tolerance.", + ) + self.assertEqual( + int(linen_first_token.data[0, 0]), + int(nnx_first_token.data[0, 0]), + msg="Linen and NNX disagreed on greedy first token with identical weights.", + ) + linen_cache_leaves, _ = jax.tree.flatten(linen_prefill["cache"]) + nnx_cache_leaves, _ = jax.tree.flatten(nnx_prefill["cache"]) + self.assertEqual(len(linen_cache_leaves), len(nnx_cache_leaves)) + @pytest.mark.skip(reason="Can only pass on CPU.") def test_chunked_prefill(self): """Test identical result between chunked prefill with single and multiple chunked. From 6c65652868b8a9f48ef943b5872a9dab3e8ef60a Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 6 May 2026 14:52:22 +0000 Subject: [PATCH 4/4] NNX: native LoRA + GRPO (drop maxengine LoRA carve-out, drop GRPO pure_nnx warning) --- src/maxtext/experimental/rl/grpo_trainer.py | 324 ++++++++++++++++--- src/maxtext/experimental/rl/grpo_utils.py | 73 ++++- src/maxtext/inference/maxengine/maxengine.py | 19 +- src/maxtext/utils/lora_utils.py | 224 +++++++++++-- tests/integration/maxengine_test.py | 10 +- tests/unit/grpo_nnx_test.py | 231 +++++++++++++ tests/unit/lora_utils_nnx_test.py | 293 +++++++++++++++++ 7 files changed, 1087 insertions(+), 87 deletions(-) create mode 100644 tests/unit/grpo_nnx_test.py create mode 100644 tests/unit/lora_utils_nnx_test.py diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 4244d199a8..b788ccd13e 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -54,9 +54,12 @@ from jax import random from flax.linen import partitioning as nn_partitioning +from flax import nnx from flax import struct from flax.nnx import TrainState +from maxtext.layers import train_state_nnx + from cloud_tpu_diagnostics import diagnostic from cloud_tpu_diagnostics.configuration import debug_configuration from cloud_tpu_diagnostics.configuration import diagnostic_configuration @@ -85,11 +88,12 @@ from maxtext.experimental.rl import grpo_utils from maxtext.common.metric_logger import MetricLogger from maxtext.common.vertex_tensorboard import VertexTensorboardManager -from maxtext.inference import offline_engine from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging from maxtext.utils import max_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils from maxtext.utils import maxtext_utils from maxtext.utils import sharding from maxtext.utils import train_utils @@ -335,34 +339,190 @@ def grpo_loss_fn(model, config, data, dropout_rng, params, 
reference_params, is_ return loss, aux +def grpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """GRPO loss for the NNX path. + + Signature matches the Linen `grpo_loss_fn` so callers can dispatch on the + same shape. `dropout_rng` and `params` are unused (NNX models carry these + themselves); `reference_model` is the frozen reference `nnx.Module`. The + reference forward is wrapped in `stop_gradient` so grads only flow into the + policy. Returns `(loss, LossAux)`. + """ + del dropout_rng, params # NNX models carry these themselves + + prompt_with_completions = data[f"{config.train_data_columns}_completions"] + prompt_completions_position = data[f"{config.train_data_columns}_completions_position"] + prompt_completions_segmentation = data[f"{config.train_data_columns}_completions_segmentation"] + completions_segmentation = data["ar_completions_segmentation"] + + token_logps_policy, intermediate_outputs = grpo_utils.compute_log_probs_nnx( + policy_model, + prompt_with_completions, + prompt_completions_position, + prompt_completions_segmentation, + completions_segmentation, + config, + is_train=is_train, + ) + + completion_target_segmentation = data["ar_completions_segmentation"][..., 1:] + valid_seq_mask = completion_target_segmentation != 0 + + rewards = grpo_utils.dummy_reward_len(valid_seq_mask) + rewards = jnp.array(rewards) + + G = config.num_generations + rewards_grouped = rewards.reshape(-1, G) + group_mean = jnp.mean(rewards_grouped, axis=1) + group_std = jnp.std(rewards_grouped, axis=1) + repeated_group_mean = jnp.repeat(group_mean, G) + repeated_group_std = jnp.repeat(group_std, G) + advantages = (rewards - repeated_group_mean) / (repeated_group_std + EPS) + advantages_exp = advantages[:, None] + + if data["completions_logprobs"] is None: # off-policy + old_per_token_logps = jax.lax.stop_gradient(token_logps_policy) + else: # on-policy + old_per_token_logps = data["completions_logprobs"] + + policy_diff = token_logps_policy - old_per_token_logps + coef_1 = jnp.exp(policy_diff) + coef_2 = jnp.clip(coef_1, 1 - config.grpo_epsilon, 1 + config.grpo_epsilon) + loss_tokens = -jnp.minimum(coef_1 * advantages_exp, coef_2 * advantages_exp) + + if config.grpo_beta != 0.0: + token_logps_ref, _ = grpo_utils.compute_log_probs_nnx( + reference_model, + prompt_with_completions, + prompt_completions_position, + prompt_completions_segmentation, + completions_segmentation, + config, + is_train=False, + ) + token_logps_ref = jax.lax.stop_gradient(token_logps_ref) + token_diff_logps_ref_policy = token_logps_ref - token_logps_policy + per_token_kl = jnp.exp(token_diff_logps_ref_policy) - token_diff_logps_ref_policy - 1 + per_token_kl = per_token_kl * valid_seq_mask + loss_tokens += config.grpo_beta * per_token_kl + + loss_per_example = jnp.sum(loss_tokens * valid_seq_mask, axis=1) / jnp.clip(jnp.sum(valid_seq_mask, axis=1), min=1) + loss = jnp.mean(loss_per_example) + total_weights = jnp.sum(valid_seq_mask) + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + + if config.grpo_beta != 0.0: + avg_kl = jnp.mean((per_token_kl * valid_seq_mask) / jnp.clip(jnp.sum(valid_seq_mask, axis=1, keepdims=True), min=1)) + else: + avg_kl = None + avg_completion_length = jnp.mean(jnp.sum(data["ar_completions_segmentation"] != 0, axis=1)) + aux = LossAux( + total_loss=loss, 
+ avg_reward=jnp.mean(rewards), + avg_reward_std=jnp.mean(repeated_group_std), + avg_advantage=jnp.mean(advantages), + avg_kl=avg_kl, + completion_length=avg_completion_length, + moe_lb_loss=moe_lb_loss, + total_weights=total_weights, + ) + return loss, aux + + # ----------------------------------------------------------------------------- # Trainer and top level training functions # ----------------------------------------------------------------------------- -def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): - """Performs a single training step of the GRPO algorithm. +def _train_step_nnx(model_graphdef, config, state_mesh_shardings, state, data): + """GRPO train_step body for the NNX path. - This function computes the GRPO loss, calculates gradients, and updates the - model's parameters. It handles gradient accumulation and clipping as configured. - The reference model's parameters are held constant during the update. + Reconstructs `TrainStateNNX` from `(model_graphdef, state)`, splits out + the policy params for value_and_grad, applies gradients, and returns the + new state with `nnx.Intermediate` filtered out (transient sown values + must not persist across steps). + """ + del state_mesh_shardings # host-offload paths not yet wired up here - Args: - model: The transformer model to be trained. - config: The training configuration object. - state_mesh_shardings: Pytree of sharding specifications for the training state. - params_shardings: Pytree of sharding specifications for the model parameters. - This argument is not used and is kept to match the signature of other trainers. - state: The current training state, including parameters and optimizer state. - data: A batch of training data, including prompts and generated completions. - dropout_rng: JAX PRNG key for dropout. + if config.gradient_accumulation_steps > 1: + raise NotImplementedError( + "GRPO + pure_nnx + gradient_accumulation_steps>1 not supported yet. " + "Set gradient_accumulation_steps=1 or pure_nnx=False." + ) - Returns: - A tuple containing: - - new_state: The updated training state after applying gradients. - - metrics: A dictionary of metrics for logging, including loss, reward, - and gradient norms. + state = nnx.merge(model_graphdef, state) # reconstruct TrainStateNNX + policy_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(policy_graphdef, param, rest, copy=True) + loss, aux = grpo_loss_fn_nnx(local_model, config, data, None, None, state.reference_model, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) 
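+    # Thread the non-Param state out via aux so mutations made during the traced
+    # forward (e.g. sown intermediates) can be written back with nnx.update
+    # outside of value_and_grad.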
+ return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) + + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + else: + grads = raw_grads + state.apply_gradients(grads) + new_state = state + + scalar_metrics = { + "learning/loss": loss, + "learning/avg_reward": aux.avg_reward, + "learning/avg_reward_std": aux.avg_reward_std, + "learning/avg_advantage": aux.avg_advantage, + "learning/avg_kl": aux.avg_kl, + "learning/completion_length": aux.completion_length, + "learning/moe_lb_loss": aux.moe_lb_loss, + "learning/total_weights": aux.total_weights, + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + } + _, new_policy_params, _ = nnx.split(new_state.model, nnx.Param, ...) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_policy_params) + metrics = {"scalar": scalar_metrics, "scalars": {}} + + return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics + + +def _eval_step_nnx(model_graphdef, config, state, data): + """GRPO eval_step body for the NNX path. No state update.""" + state = nnx.merge(model_graphdef, state) + loss, aux = grpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + metrics = { + "scalar": { + "evaluation/loss": loss, + "evaluation/total_loss": aux.total_loss, + "evaluation/total_weights": aux.total_weights, + "evaluation/moe_lb_loss": aux.moe_lb_loss, + }, + } + return metrics + + +def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): + """Single GRPO training step. + + Computes the GRPO loss, gradients, and applies them to the policy. The + reference model is held constant. Linen and NNX paths split below; on + NNX, `model` is a GraphDef and `state` is a flat `nnx.State` of a + `TrainStateNNX` (with `model`, `optimizer`, and `reference_model`). + + Returns `(new_state, metrics)`. """ + if config.pure_nnx: + return _train_step_nnx(model, config, state_mesh_shardings, state, data) + state, reference_params = _split_grpo_state(state) state_mesh_shardings, reference_params_sharding = _split_grpo_state(state_mesh_shardings) extra_grpo_args = [reference_params] @@ -473,6 +633,8 @@ def eval_step(model, config, state, data, dropout_rng): Returns: A dictionary of evaluation metrics. """ + if config.pure_nnx: + return _eval_step_nnx(model, config, state, data) reference_params, extra_grpo_args, _loss_fn = [], [], grpo_loss_fn state, reference_params = _split_grpo_state(state) @@ -542,28 +704,50 @@ def setup_train_loop( - eval_data_iterator: The iterator for the evaluation dataset (or None). - state: The initialized training state. """ - # GRPO RL trainer is Linen-shaped end-to-end (state.params accesses below, - # state_mesh_shardings.params, and the inference path through MaxEngine which is - # Linen-only). Run on Linen path regardless of pure_nnx; warn the user since - # NNX-format checkpoints will mismatch at restore time. - if config.pure_nnx or config_inference.pure_nnx: - max_logging.log( - "WARNING: GRPO RL trainer does not yet support pure_nnx natively; " - "running on the Linen path. NNX-format checkpoints will not load correctly here." 
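+  # The policy's params are periodically resharded onto the inference engine, so
+  # both configs must use the same parameter format (both Linen or both NNX).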
+ if config.pure_nnx != config_inference.pure_nnx: + raise ValueError( + f"config.pure_nnx ({config.pure_nnx}) and config_inference.pure_nnx " f"({config_inference.pure_nnx}) must agree." ) with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - model = mt.from_config(config, devices=training_devices) + init_rng = jax.random.PRNGKey(config.init_weights_seed) + + if config.pure_nnx: + training_mesh = maxtext_utils.get_mesh_from_config(config, devices=training_devices) + training_rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=init_rng) + model = mt.from_config(config, devices=training_devices, mesh=training_mesh, rngs=training_rngs) + else: + model = mt.from_config(config, devices=training_devices) mesh = model.mesh + max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - inference_model = mt.from_config(config_inference, devices=inference_devices) + if config_inference.pure_nnx: + inference_mesh_obj = maxtext_utils.get_mesh_from_config(config_inference, devices=inference_devices) + inference_rngs = maxtext_utils_nnx.create_nnx_rngs(config_inference, rng_key=init_rng) + inference_model = mt.from_config( + config_inference, devices=inference_devices, mesh=inference_mesh_obj, rngs=inference_rngs + ) + else: + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh - init_rng = jax.random.PRNGKey(config.init_weights_seed) + learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + + if config.pure_nnx: + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh, devices=training_devices) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + # Reference uses the same init seed so it starts identical to the policy. 
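+      # The reference gets no optimizer and is never updated; the loss wraps its
+      # forward pass in stop_gradient.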
+ reference_model = _create_model_partial() + return train_state_nnx.TrainStateNNX(nnx_model, optimizer, reference_model=reference_model) + + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): @@ -572,16 +756,29 @@ def setup_train_loop( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) - # create inference_state_mesh_shardings from inference_mesh (Linen path; see warning above) - init_inference_state_fn = functools.partial( - maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng - ) + if config_inference.pure_nnx: + _create_inference_partial, _ = model_creation_utils.create_nnx_abstract_model( + config_inference, inference_mesh, devices=inference_devices + ) + + def init_inference_state_fn(): + inference_nnx_model = _create_inference_partial() + return train_state_nnx.TrainStateNNX(inference_nnx_model, None) + + else: + init_inference_state_fn = functools.partial( + maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng + ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( config_inference, inference_mesh, init_inference_state_fn, is_training=False )[2] if not config.using_pipeline_parallelism: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + if config.pure_nnx: + _, params_for_check, _ = nnx.split(state.model, nnx.Param, ...) + sharding.assert_params_sufficiently_sharded(params_for_check, mesh, config.sharding_tolerance) + else: + sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) return ( init_rng, @@ -694,10 +891,15 @@ def train_loop(config, config_inference, recorder, state=None): token=config.hf_access_token, ) - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_grpo_state(state, reference_params) - state_mesh_shardings = _merge_grpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if config.pure_nnx: + # `reference_model` is set up by init_state_fn as a sibling field — nothing to merge. + if not hasattr(state, "reference_model"): + raise RuntimeError("NNX GRPO state is missing reference_model; check setup_train_loop.") + else: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_grpo_state(state, reference_params) + state_mesh_shardings = _merge_grpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator @@ -705,6 +907,9 @@ def train_loop(config, config_inference, recorder, state=None): data_sharding = sharding.get_input_data_sharding(config, mesh) + # Lazy import: pulls in maxengine and jetstream stubs. 
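+  # Importing at module scope would make anyone who merely imports grpo_trainer
+  # (e.g. the unit tests) pull those in; defer until a real GRPO run needs it.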
+ from maxtext.inference import offline_engine # pylint: disable=import-outside-toplevel + inference_engine = offline_engine.OfflineEngine( config=config_inference, mesh=inference_mesh, @@ -719,7 +924,11 @@ def train_loop(config, config_inference, recorder, state=None): metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params["params"]) + if config.pure_nnx: + _, _params_for_metrics, _ = nnx.split(state.model, nnx.Param, ...) + metric_logger.write_setup_info_to_tensorboard(_params_for_metrics) + else: + metric_logger.write_setup_info_to_tensorboard(state.params["params"]) def generation_worker_fn( worker_inference_engine, @@ -843,21 +1052,32 @@ def generation_worker_fn( state, metrics = p_train_step(state, example_batch, train_rng) with jax.profiler.StepTraceAnnotation("transfer data", step_num=step): if step != 0 and step % config.inference_rollouts == 0: - grpo_utils.pathways_reshard( - config_inference, - inference_engine, - {"params": state.params["params"]}, - {"params": state_mesh_shardings.params["params"]}, - mesh, - {"params": inference_state_mesh_shardings.params["params"]}, - ) + if config.pure_nnx: + grpo_utils.pathways_reshard_nnx( + config_inference, + inference_engine, + state.model, + state_mesh_shardings.model, + inference_state_mesh_shardings.model, + ) + else: + grpo_utils.pathways_reshard( + config_inference, + inference_engine, + {"params": state.params["params"]}, + {"params": state_mesh_shardings.params["params"]}, + mesh, + {"params": inference_state_mesh_shardings.params["params"]}, + ) with data_buffer_lock: data_buffer.clear() step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() - state_to_save = _split_grpo_state(state)[0] + # Linen embeds reference in `state.params` and strips it for save; NNX + # holds it as a sibling field on TrainStateNNX so the whole state goes. + state_to_save = state if config.pure_nnx else _split_grpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) if config.dump_hlo and step == start_step: @@ -895,7 +1115,7 @@ def generation_worker_fn( metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = _split_grpo_state(state)[0] + state_to_save = state if config.pure_nnx else _split_grpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) elif checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/experimental/rl/grpo_utils.py b/src/maxtext/experimental/rl/grpo_utils.py index 352a2b3b8d..34d437867e 100644 --- a/src/maxtext/experimental/rl/grpo_utils.py +++ b/src/maxtext/experimental/rl/grpo_utils.py @@ -21,8 +21,9 @@ import jaxtyping from typing import Any, Callable +from flax import nnx + from maxtext.common.common_types import DecoderBlockType -from maxtext.inference.offline_engine import InputData from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -112,6 +113,48 @@ def compute_log_probs( return token_log_probs, intermediate_outputs +def compute_log_probs_nnx( + model, + inputs, + inputs_position, + inputs_segmentation, + completion_segmentation, + config, + is_train=False, +): + """`compute_log_probs` for the NNX path. 
+ + `model` is an `nnx.Module` (carries its own params + RNG state), so there's + no `params` arg. Intermediates are pulled off the model after the forward + via `nnx.state(model, nnx.Intermediate).to_pure_dict()`. + """ + logits = model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=(config.enable_dropout if is_train else False), + ) + intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() + logits = logits / config.decode_sampling_temperature + + targets = inputs[:, 1:] + shifted_completion_segmentation = jax.lax.dynamic_slice( + completion_segmentation, (0, 1), (completion_segmentation.shape[0], completion_segmentation.shape[1] - 1) + ) + shifted_completion_segmentation = jnp.pad( + shifted_completion_segmentation, ((0, 0), (0, 1)), mode="constant", constant_values=0 + ) + mask = shifted_completion_segmentation[..., None] + mask = jnp.broadcast_to(mask, logits.shape) + masked_logits = jnp.where(mask, logits, -jnp.inf) + log_probs = jax.nn.log_softmax(masked_logits, axis=-1) + log_probs = jnp.where(mask, log_probs, -0.0) + log_probs = log_probs[:, :-1, :] + token_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0] + token_log_probs = token_log_probs * shifted_completion_segmentation[:, :-1] + return token_log_probs, intermediate_outputs + + def generate_offline_completions(config, tokenizer_model, inference_engine, data): """Generates completions for a batch of prompts using an offline engine. @@ -125,6 +168,10 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data The input `data` dictionary updated with the generated completions, segmentations, positions, and log-probabilities. """ + # Lazy import: pulls in maxengine and jetstream stubs, which we only want to + # touch when this function is actually called (i.e. during a real GRPO run). + from maxtext.inference.offline_engine import InputData # pylint: disable=import-outside-toplevel + data[config.train_data_columns] = np.asarray( jnp.repeat(data[config.train_data_columns], config.num_generations, axis=0) ) @@ -175,6 +222,30 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data return data +def pathways_reshard_nnx( + config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model +): + """`pathways_reshard` for the NNX path. + + Reshard the policy params onto the inference mesh and push them into the + inference engine. Requires `scan_layers=True` (no NNX-aware unscan helper yet). + """ + if not config.scan_layers: + raise NotImplementedError( + "GRPO + pure_nnx + scan_layers=False not supported yet. " "Use scan_layers=True or pure_nnx=False." + ) + _, policy_params, _ = nnx.split(policy_state_model, nnx.Param, ...) + _, source_param_shardings, _ = nnx.split(source_shardings_model, nnx.Param, ...) + _, dest_param_shardings, _ = nnx.split(destination_shardings_model, nnx.Param, ...) + del source_param_shardings # already encoded on policy_params + with ( + jax.transfer_guard_device_to_host("disallow_explicit"), + jax.transfer_guard_host_to_device("disallow_explicit"), + ): + resharded_params = reshard_pytree(policy_params, dest_param_shardings) + inference_engine.update_params(resharded_params) + + def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings): """Reshards model parameters from training to inference sharding. 
diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 5bd220f4e1..d270a1e23f 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -438,12 +438,11 @@ def _load_params_nnx(self, params, rng): return params_state def load_single_adapter(self, adapter_path): + """Load a single LoRA adapter from `adapter_path`. + + Expects `adapter_config.json` plus adapter weights at `/0/items`. + The returned `params` shape matches `self.abstract_params` (NNX or Linen). """ - Load Single adapter from adapter_path. - Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`. - """ - if self.config.pure_nnx: - raise NotImplementedError("pure_nnx + LoRA not yet supported. Use pure_nnx=False.") adapter_config_path = os.path.join(adapter_path, "adapter_config.json") adapter_weights_path = os.path.join(adapter_path, "0", "items") @@ -466,14 +465,20 @@ def apply_adapter(self, base_params, adapter_config, adapter_params): lora_rank = int(adapter_config["r"]) lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank - lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor) + if self.config.pure_nnx: + lora_utils.apply_lora_on_base_params_nnx(base_params, adapter_params, lora_scale_factor) + else: + lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor) def unapply_adapter(self, base_params, adapter_config, adapter_params): """Unapply the adapter params from the merged params to get back the base params.""" lora_rank = int(adapter_config["r"]) lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank - lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor) + if self.config.pure_nnx: + lora_utils.unapply_lora_from_base_params_nnx(base_params, adapter_params, lora_scale_factor) + else: + lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor) def quantize_params(self, state, rng: PRNGKeyType | None = None): """Forward pass to quantize decode params.""" diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index 1efad6aa91..4193a613d4 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -14,7 +14,7 @@ """Common LoRA utils needed to support LoRA adapters.""" - +from collections.abc import Mapping from functools import partial import json import os @@ -38,6 +38,10 @@ from maxtext.utils import sharding from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR +# NNX-only imports (`flax.nnx`, `train_state_nnx`, `model_creation_utils`) are +# loaded lazily inside the NNX dispatch branches so the Linen-only flow doesn't +# pull them in. + def apply_lora_on_base_params(base_params, lora_params, lora_scale_factor=1.0): """ @@ -118,8 +122,10 @@ def unapply_lora_recursively(base_params, lora_params, module_name): def load_adapter(config, base_abstract_state_params, adapter_config_path, adapter_weights_path): - """ - Load the LoRA weights into a PyTree and return it. + """Load LoRA weights into a PyTree and return it. + + On the NNX path, `base_abstract_state_params` and the returned `lora_params` + are `nnx.State`-shaped (no outer `{"params": ...}` wrap). 
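+
+  Illustrative tree shapes (keys are examples only):
+    Linen: {"params": {"decoder": {"layers": {...}}}}
+    NNX:   {"decoder": {"layers": {...}}}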
""" # Load LoRA weights lora_params = None @@ -137,7 +143,10 @@ def load_adapter(config, base_abstract_state_params, adapter_config_path, adapte if not gcs_utils.gcs_path_exists(f"{adapter_weights_path}/commit_success.txt"): raise FileNotFoundError(f"Failed to read lora_weights from {adapter_weights_path}.") - lora_state, _ = get_lora_abstract_state(base_abstract_state_params, lora_config) + if config.pure_nnx: + lora_state, _ = get_lora_abstract_state_nnx(base_abstract_state_params, lora_config) + else: + lora_state, _ = get_lora_abstract_state(base_abstract_state_params, lora_config) with nn_partitioning.axis_rules(config.logical_axis_rules): lora_params = checkpointing.load_params_from_path( @@ -152,22 +161,12 @@ def load_adapter(config, base_abstract_state_params, adapter_config_path, adapte def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, lora_adapter_path): - """We initialize the model and optimizer state, and optionally load from a - checkpoint as necessary. + """Initialize the LoRA train state and optionally load weights from disk. - Args: - model: the flax model to initialize - tx: the optax.GradientTransformation - config: config object - rng: jax.prng key - mesh: jax.devices() mesh - checkpoint_manager: an Orbax checkpointing.CheckpointManager object - lora_adapter_path: Path of the LoRA adapter which is expected to have - `adapter_config.json` and adapter weights - - Returns: - state: the initialized train state - state_mesh_annotations: the mesh annotations for the train state + Returns `(lora_config, lora_state, lora_state_annotations)`. On the NNX path + `model` is unused (the NNX abstract state is built via + `model_creation_utils.create_nnx_abstract_model`) and `lora_state.params` + is `nnx.State`-shaped; on Linen it is the original `{"params": ...}` tree. """ lora_state = None @@ -176,17 +175,32 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - # LoRA adapters are Linen-format on disk (downstream `get_lora_abstract_state` expects - # `unboxed_abstract_state.params` Linen tree shape; `lora_state.replace(params=...)` - # uses Linen TrainState API). Use the Linen init path regardless of the pure_nnx flag. 
- init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + if config.pure_nnx: + # pylint: disable=import-outside-toplevel + from maxtext.layers import train_state_nnx + from maxtext.utils import model_creation_utils + + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh) + + def create_train_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + + init_state_fn = create_train_state_fn + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) lora_config_path = lora_adapter_path + "adapter_config.json" lora_config = gcs_utils.read_json_from_gcs(lora_config_path) - lora_state, lora_state_annotations = get_lora_abstract_state(unboxed_abstract_state.params, lora_config) + if config.pure_nnx: + base_abstract_params = _nnx_param_subtree(unboxed_abstract_state) + lora_state, lora_state_annotations = get_lora_abstract_state_nnx(base_abstract_params, lora_config) + else: + lora_state, lora_state_annotations = get_lora_abstract_state(unboxed_abstract_state.params, lora_config) lora_weights_path = f"{lora_adapter_path}/0/items" @@ -606,3 +620,165 @@ def _map_to_state(path, variable): nnx.update(trainer.model, abstract_lora_params) max_logging.log(f"LoRA restore complete from '{lora_restore_path}'.") return trainer + + +# NNX-shaped LoRA helpers. The Linen walkers above key on `isinstance(x, dict)` +# and bare leaves; NNX trees use `nnx.State` (Mapping but not dict) and +# Variable-wrapped leaves, so we need separate mirrors. The math (W += B @ A * s) +# is identical. + + +def _is_nnx_branch(x): + return isinstance(x, Mapping) + + +def _nnx_param_subtree(unboxed_abstract_state): + """Drop the outer TrainStateNNX wrapping and return the model substate.""" + return unboxed_abstract_state["model"] if "model" in unboxed_abstract_state else unboxed_abstract_state + + +def apply_lora_on_base_params_nnx(base_params, lora_params, lora_scale_factor=1.0): + """NNX variant of `apply_lora_on_base_params`. Mutates `base_params` in place.""" + + def lora_update_or_base(base_weight, lora_a, lora_b): + if lora_a is not None and lora_b is not None: + return base_weight + jnp.einsum("br,rnd->bnd", lora_b, lora_a) * lora_scale_factor + return base_weight + + def recurse(base_node, lora_node, path): + for name, lora_child in lora_node.items(): + if _is_nnx_branch(lora_child): + recurse(base_node[name], lora_child, f"{path}.{name}") + elif lora_child is not None: + if name not in ("lora_a.kernel", "lora_b.kernel"): + raise ValueError(f"Unexpected non-lora key ({path}.{name}) in lora_params") + lora_b = lora_node["lora_a.kernel"] + lora_a = lora_node["lora_b.kernel"] + base_node["kernel"] = lora_update_or_base(base_node["kernel"], lora_a, lora_b) + return + + recurse(base_params, lora_params, "") + + +def unapply_lora_from_base_params_nnx(base_params, lora_params, lora_scale_factor=1.0): + """NNX-shaped variant of `unapply_lora_from_base_params`. 
Mutates `base_params`.""" + + def lora_update_or_base(base_weight, lora_a, lora_b): + if lora_a is not None and lora_b is not None: + return base_weight - jnp.einsum("br,rnd->bnd", lora_b, lora_a) * lora_scale_factor + return base_weight + + def recurse(base_node, lora_node, path): + for name, lora_child in lora_node.items(): + if _is_nnx_branch(lora_child): + recurse(base_node[name], lora_child, f"{path}.{name}") + elif lora_child is not None: + if name not in ("lora_a.kernel", "lora_b.kernel"): + raise ValueError(f"Unexpected non-lora key ({path}.{name}) in lora_params") + lora_b = lora_node["lora_a.kernel"] + lora_a = lora_node["lora_b.kernel"] + base_node["kernel"] = lora_update_or_base(base_node["kernel"], lora_a, lora_b) + return + + recurse(base_params, lora_params, "") + + +def get_lora_abstract_state_nnx(base_abstract_params, lora_config): + """`get_lora_abstract_state` for the NNX path. + + Walks the abstract `state.model` substate and emits a parallel tree with + `lora_a.kernel` / `lora_b.kernel` leaves at target attention paths and + `None` elsewhere. + """ + other_lora_format_to_jax_format = { + "q_proj": "self_attention.query", + "k_proj": "self_attention.key", + "v_proj": "self_attention.value", + "o_proj": "self_attention.out", + } + + lora_target_modules = [other_lora_format_to_jax_format.get(s, s) for s in lora_config["target_modules"]] + lora_rank = int(lora_config["r"]) + + def get_lora_param_shape(base_array_shape, lora_module): + if len(base_array_shape) > 4: + raise ValueError(f"Unsupported base array shape {base_array_shape} (>4D)") + if lora_module in ("self_attention.query", "self_attention.key", "self_attention.value"): + lora_a_shape = base_array_shape[:-2] + (lora_rank,) + lora_b_shape = (lora_rank,) + base_array_shape[1:] + elif lora_module == "self_attention.out": + lora_a_shape = base_array_shape[:-1] + (lora_rank,) + if len(base_array_shape) == 4: + lora_b_shape = (lora_rank, base_array_shape[1], base_array_shape[-1]) + else: + lora_b_shape = (lora_rank, base_array_shape[-1]) + else: + raise ValueError(f"Unsupported lora_module={lora_module}") + return lora_a_shape, lora_b_shape + + def get_lora_param_sharding(base_param_sharding, lora_module): + if base_param_sharding is None: + return None, None + base_pspec = base_param_sharding.spec + if len(base_pspec) > 4: + raise ValueError("PartitionSpec size > 4 not supported") + if lora_module in ("self_attention.query", "self_attention.key", "self_attention.value"): + lora_a_pspec = jax.sharding.PartitionSpec(*(base_pspec[:-2] + ((),))) + lora_b_pspec = jax.sharding.PartitionSpec(*(((),) + base_pspec[1:])) + elif lora_module == "self_attention.out": + lora_a_pspec = jax.sharding.PartitionSpec(*(base_pspec[:-1] + ((),))) + if len(base_pspec) == 4: + lora_b_pspec = jax.sharding.PartitionSpec((), base_pspec[1], base_pspec[-1]) + else: + lora_b_pspec = jax.sharding.PartitionSpec((), base_pspec[-1]) + else: + raise ValueError(f"Unsupported lora_module={lora_module}") + mesh = base_param_sharding.mesh + mem_kind = base_param_sharding.memory_kind + return ( + jax.sharding.NamedSharding(mesh=mesh, spec=lora_a_pspec, memory_kind=mem_kind), + jax.sharding.NamedSharding(mesh=mesh, spec=lora_b_pspec, memory_kind=mem_kind), + ) + + def module_is_target(module_path): + for tgt in lora_target_modules: + if tgt in module_path: + return tgt + return None + + def add_lora(out_node, base_node, path): + for name, child in base_node.items(): + if _is_nnx_branch(child): + out_node[name] = {} + add_lora(out_node[name], child, 
f"{path}.{name}") + else: + if name not in ("kernel", "scale", "embedding"): + raise ValueError(f"Unexpected key={name} in base abstract params at {path}") + if not isinstance(child, jax.ShapeDtypeStruct): + raise ValueError(f"Unexpected leaf type {type(child).__name__} at {path}.{name}") + target_module = module_is_target(path) + if target_module is not None: + a_shape, b_shape = get_lora_param_shape(child.shape, target_module) + a_sharding, b_sharding = get_lora_param_sharding(child.sharding, target_module) + out_node["lora_a.kernel"] = jax.ShapeDtypeStruct(shape=a_shape, dtype=child.dtype, sharding=a_sharding) + out_node["lora_b.kernel"] = jax.ShapeDtypeStruct(shape=b_shape, dtype=child.dtype, sharding=b_sharding) + else: + out_node[name] = None + + lora_abstract_params = {} + add_lora(lora_abstract_params, base_abstract_params, "") + + unboxed_abstract_lora_state = train_state.TrainState( + step=0, apply_fn=None, params=lora_abstract_params, tx=None, opt_state={} # type: ignore + ) + lora_state_mesh_annotations = train_state.TrainState( + step=0, + apply_fn=None, + params=jax.tree_util.tree_map( + lambda x: x.sharding.spec if x.sharding is not None else None, + lora_abstract_params, + ), + tx=None, # type: ignore + opt_state={}, + ) + return unboxed_abstract_lora_state, lora_state_mesh_annotations diff --git a/tests/integration/maxengine_test.py b/tests/integration/maxengine_test.py index 42730c2d76..0f3329cf35 100644 --- a/tests/integration/maxengine_test.py +++ b/tests/integration/maxengine_test.py @@ -252,11 +252,15 @@ def test_quantize_raises_for_nnx(self): with self.assertRaises(NotImplementedError): engine.load_params(rng=self.rng) - def test_lora_raises_for_nnx(self): - """pure_nnx + LoRA raises NotImplementedError.""" + def test_lora_load_single_adapter_reaches_loader_on_nnx(self): + """pure_nnx + LoRA: load_single_adapter dispatches to the NNX loader. + + With a nonexistent path the loader raises FileNotFoundError (not + NotImplementedError, which would mean the dispatch never reached the loader). + """ cfg = self._init_nnx_pyconfig() engine = maxengine.MaxEngine(cfg, jax.devices()) - with self.assertRaises(NotImplementedError): + with self.assertRaises(FileNotFoundError): engine.load_single_adapter("/nonexistent/adapter/path") def _linen_params_to_nnx_state(self, linen_params, abstract_nnx_model): diff --git a/tests/unit/grpo_nnx_test.py b/tests/unit/grpo_nnx_test.py new file mode 100644 index 0000000000..6f72c43723 --- /dev/null +++ b/tests/unit/grpo_nnx_test.py @@ -0,0 +1,231 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for `grpo_loss_fn_nnx`, `compute_log_probs_nnx`, plus a small +Linen-path regression block (the repo's existing Linen GRPO integration test +is TPU-only).""" + +import types +import unittest + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx + +from maxtext.experimental.rl import grpo_trainer +from maxtext.experimental.rl import grpo_utils + + +class _MockTransformer(nnx.Module): + """Tiny NNX module that responds to the kwargs `compute_log_probs_nnx` uses.""" + + def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs) + self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions=None, + decoder_segment_ids=None, + enable_dropout=False, + **kwargs, + ): + del decoder_positions, decoder_segment_ids, enable_dropout, kwargs + return self.proj(self.embed(decoder_input_tokens)) + + +def _make_grpo_config(**overrides): + """Minimal config namespace covering every field `grpo_loss_fn_nnx` reads.""" + base = { + "train_data_columns": "prompt", + "num_generations": 2, + "grpo_epsilon": 0.2, + "grpo_beta": 0.1, + "num_experts": 1, + "decode_sampling_temperature": 1.0, + "enable_dropout": False, + "use_dpo": False, + } + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _make_grpo_batch(B=2, G=2, S=6): + """Minimal GRPO batch: `B` prompts, `G` generations each (total `B*G`), seq length `S`.""" + total = B * G + prompts = jnp.tile(jnp.arange(S, dtype=jnp.int32), (total, 1)) + return { + "prompt_completions": prompts, + "prompt_completions_position": prompts, + "prompt_completions_segmentation": jnp.ones((total, S), dtype=jnp.int32), + "ar_completions_segmentation": jnp.array([[0, 0, 1, 1, 1, 0]] * total, dtype=jnp.int32), + "completions_logprobs": None, # off-policy + } + + +class TestGrpoLossFnNnx(unittest.TestCase): + """Behavior of `grpo_loss_fn_nnx` on a synthetic policy / reference pair.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) # identical seed + self.config = _make_grpo_config() + self.data = _make_grpo_batch() + + def test_aux_structure_matches_linen(self): + """`grpo_loss_fn_nnx` returns the same `LossAux` dataclass shape as `grpo_loss_fn`.""" + loss, aux = grpo_trainer.grpo_loss_fn_nnx( + self.policy, self.config, self.data, None, None, self.reference, is_train=True + ) + self.assertIsInstance(aux, grpo_trainer.LossAux) + for field in ( + "total_loss", + "avg_reward", + "avg_reward_std", + "avg_advantage", + "completion_length", + "moe_lb_loss", + "total_weights", + ): + self.assertTrue(hasattr(aux, field), f"aux missing field {field}") + self.assertTrue(jnp.isfinite(loss)) + + def test_unused_dropout_rng_and_params_args_are_ignored(self): + """`dropout_rng` and `params` are positional placeholders only — values shouldn't matter.""" + a = grpo_trainer.grpo_loss_fn_nnx(self.policy, self.config, self.data, None, None, self.reference, is_train=True) + b = grpo_trainer.grpo_loss_fn_nnx( + self.policy, self.config, self.data, jax.random.key(99), {"junk": 1}, self.reference, is_train=True + ) + np.testing.assert_allclose(np.asarray(a[0]), np.asarray(b[0]), rtol=1e-6) + + def test_identical_policy_and_reference_zero_kl(self): + """Identical policy and reference → per-token KL is zero, so `aux.avg_kl ≈ 0`.""" + cfg = 
_make_grpo_config(grpo_beta=0.5) + _, aux = grpo_trainer.grpo_loss_fn_nnx(self.policy, cfg, self.data, None, None, self.reference, is_train=True) + self.assertIsNotNone(aux.avg_kl) + np.testing.assert_allclose(np.asarray(aux.avg_kl), 0.0, atol=1e-5) + + def test_grpo_beta_zero_avg_kl_is_none(self): + cfg = _make_grpo_config(grpo_beta=0.0) + _, aux = grpo_trainer.grpo_loss_fn_nnx(self.policy, cfg, self.data, None, None, self.reference, is_train=True) + self.assertIsNone(aux.avg_kl) + + def test_value_and_grad_flows_only_to_policy(self): + """`nnx.value_and_grad` over the policy yields finite grads; reference is left alone.""" + + def loss_only(policy_model): + loss, _ = grpo_trainer.grpo_loss_fn_nnx( + policy_model, self.config, self.data, None, None, self.reference, is_train=True + ) + return loss + + # nnx.value_and_grad returns (value, grad_state) where grad_state holds nnx.Param leaves. + _, grads = nnx.value_and_grad(loss_only, argnums=0)(self.policy) + leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(np.all(np.isfinite(np.asarray(leaf))), "policy grad has non-finite entries") + + +class TestComputeLogProbsNnx(unittest.TestCase): + """Shape contract of `compute_log_probs_nnx`.""" + + def test_returns_correct_shape(self): + config = _make_grpo_config() + data = _make_grpo_batch() + model = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + log_probs, _ = grpo_utils.compute_log_probs_nnx( + model, + data["prompt_completions"], + data["prompt_completions_position"], + data["prompt_completions_segmentation"], + data["ar_completions_segmentation"], + config, + is_train=False, + ) + # Inputs are [B, S] → log_probs are [B, S-1]. + self.assertEqual(log_probs.shape, (data["prompt_completions"].shape[0], data["prompt_completions"].shape[1] - 1)) + + +# --------------------------------------------------------------------------- +# Linen-path regression smoke tests +# --------------------------------------------------------------------------- + + +class _MockLinenTransformer(nn.Module): + """Tiny Linen module that responds to the same `model.apply(...)` shape Linen `compute_log_probs` uses.""" + + vocab_size: int + embed_dim: int + + @nn.compact + def __call__(self, inputs, positions, decoder_segment_ids=None, enable_dropout=False): + del positions, decoder_segment_ids, enable_dropout + embed = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim, name="embed")(inputs) + return nn.Dense(features=self.vocab_size, name="proj")(embed) + + +class TestLinenGrpoRegression(unittest.TestCase): + """Smoke test that the Linen `grpo_loss_fn` and `compute_log_probs` still run + end-to-end with `pure_nnx=False`-style inputs.""" + + def setUp(self): + self.config = _make_grpo_config() + self.config.pure_nnx = False # explicit Linen mode + self.config.gradient_accumulation_steps = 1 + self.data = _make_grpo_batch() + self.model = _MockLinenTransformer(vocab_size=8, embed_dim=4) + rng = jax.random.key(0) + inputs = self.data["prompt_completions"] + self.params = self.model.init(rng, inputs, inputs, decoder_segment_ids=jnp.ones_like(inputs), enable_dropout=False) + self.reference_params = jax.tree_util.tree_map(jnp.copy, self.params) + + def test_linen_grpo_loss_fn_still_runs(self): + """Linen `grpo_loss_fn` returns a finite loss + a `LossAux`.""" + loss, aux = grpo_trainer.grpo_loss_fn( + self.model, + self.config, + self.data, + jax.random.key(1), + self.params, + self.reference_params["params"], # Linen 
reference_params is the inner subtree + is_train=True, + ) + self.assertTrue(jnp.isfinite(loss)) + self.assertTrue(hasattr(aux, "total_loss")) + self.assertTrue(hasattr(aux, "moe_lb_loss")) + self.assertTrue(hasattr(aux, "total_weights")) + + def test_linen_compute_log_probs_still_runs(self): + """Linen `compute_log_probs` produces shape `[B, S-1]`.""" + log_probs, _ = grpo_utils.compute_log_probs( + self.model, + self.params, + self.data["prompt_completions"], + self.data["prompt_completions_position"], + self.data["prompt_completions_segmentation"], + self.data["ar_completions_segmentation"], + self.config, + is_train=False, + rngs={"dropout": jax.random.key(2), "params": jax.random.key(3)}, + ) + S = self.data["prompt_completions"].shape[1] + self.assertEqual(log_probs.shape, (self.data["prompt_completions"].shape[0], S - 1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/lora_utils_nnx_test.py b/tests/unit/lora_utils_nnx_test.py new file mode 100644 index 0000000000..e0e8cbb529 --- /dev/null +++ b/tests/unit/lora_utils_nnx_test.py @@ -0,0 +1,293 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-shaped LoRA helpers in `lora_utils`, plus a small +Linen regression block.""" + +import unittest + +import jax +import jax.numpy as jnp +import numpy as np + +from maxtext.utils.lora_utils import ( + apply_lora_on_base_params, + apply_lora_on_base_params_nnx, + get_lora_abstract_state_nnx, + unapply_lora_from_base_params, + unapply_lora_from_base_params_nnx, +) + + +# --------------------------------------------------------------------------- +# Fake abstract state builders (mirror the NNX vs. 
Linen tree shapes) +# --------------------------------------------------------------------------- + + +def _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4, dtype=jnp.float32): + """Tiny NNX-shaped abstract state for one attention block.""" + + def _sds(shape): + return jax.ShapeDtypeStruct(shape=shape, dtype=dtype, sharding=None) + + return { + "decoder": { + "layers": { + "self_attention": { + "query": {"kernel": _sds((emb, num_heads, head_dim))}, + "key": {"kernel": _sds((emb, num_heads, head_dim))}, + "value": {"kernel": _sds((emb, num_heads, head_dim))}, + "out": {"kernel": _sds((emb, num_heads, head_dim))}, + }, + "mlp": {"wi": {"kernel": _sds((emb, 4 * emb))}}, + }, + "shared_embedding": {"embedding": _sds((100, emb))}, + }, + } + + +def _make_linen_attention_abstract(emb=8, num_heads=2, head_dim=4, dtype=jnp.float32): + """Linen-shaped equivalent (with the `{"params": ...}` outer wrap).""" + return {"params": _make_nnx_attention_abstract(emb, num_heads, head_dim, dtype)} + + +def _lora_config(rank=4, alpha=8.0, target_modules=("q_proj", "v_proj")): + return { + "r": rank, + "lora_alpha": alpha, + "target_modules": list(target_modules), + } + + +# --------------------------------------------------------------------------- +# get_lora_abstract_state_nnx +# --------------------------------------------------------------------------- + + +class TestGetLoraAbstractStateNnx(unittest.TestCase): + """`get_lora_abstract_state_nnx` shape, sharding, and error-path coverage.""" + + def test_lora_shapes_for_query_and_value(self): + abs_params = _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4) + state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(rank=4)) + attn = state.params["decoder"]["layers"]["self_attention"] + + a = attn["query"]["lora_a.kernel"] + b = attn["query"]["lora_b.kernel"] + self.assertEqual(a.shape, (8, 4)) + self.assertEqual(b.shape, (4, 2, 4)) + self.assertEqual(a.dtype, jnp.float32) + self.assertEqual(b.dtype, jnp.float32) + + a = attn["value"]["lora_a.kernel"] + b = attn["value"]["lora_b.kernel"] + self.assertEqual(a.shape, (8, 4)) + self.assertEqual(b.shape, (4, 2, 4)) + + def test_non_target_modules_emit_none_leaves(self): + abs_params = _make_nnx_attention_abstract() + state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(target_modules=("q_proj",))) + attn = state.params["decoder"]["layers"]["self_attention"] + self.assertIn("lora_a.kernel", attn["query"]) + self.assertIsNone(attn["key"]["kernel"]) + self.assertIsNone(attn["value"]["kernel"]) + self.assertIsNone(attn["out"]["kernel"]) + self.assertIsNone(state.params["decoder"]["layers"]["mlp"]["wi"]["kernel"]) + self.assertIsNone(state.params["decoder"]["shared_embedding"]["embedding"]) + + def test_o_proj_has_distinct_shape(self): + abs_params = _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4) + state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(rank=3, target_modules=("o_proj",))) + out = state.params["decoder"]["layers"]["self_attention"]["out"] + a = out["lora_a.kernel"] + b = out["lora_b.kernel"] + # 3D base (emb, num_heads, head_dim) → lora_a.shape = (..., r), lora_b = (r, last) + self.assertEqual(a.shape, (8, 2, 3)) + self.assertEqual(b.shape, (3, 4)) + + def test_unsupported_leaf_type_raises(self): + bad = {"decoder": {"layers": {"self_attention": {"query": {"kernel": jnp.zeros((4, 2, 2))}}}}} + with self.assertRaises(ValueError): + get_lora_abstract_state_nnx(bad, _lora_config()) + + def test_unexpected_leaf_name_raises(self): + bad = 
{"decoder": {"layers": {"self_attention": {"query": {"weight": jax.ShapeDtypeStruct((4, 2), jnp.float32)}}}}} + with self.assertRaises(ValueError): + get_lora_abstract_state_nnx(bad, _lora_config()) + + # Linen-vs-NNX numerical parity is covered by TestApplyLoraNnx.test_numerical_parity_with_linen_apply. + + +# --------------------------------------------------------------------------- +# apply / unapply on NNX-shape pure dicts +# --------------------------------------------------------------------------- + + +def _concrete_base(rng=None, emb=4, num_heads=2, head_dim=3): + """Concrete arrays mirroring the abstract structure used above (NNX-shape).""" + if rng is None: + rng = jax.random.key(0) + k1, k2, k3, k4, k5, k6 = jax.random.split(rng, 6) + shape_attn = (emb, num_heads, head_dim) + return { + "decoder": { + "layers": { + "self_attention": { + "query": {"kernel": jax.random.normal(k1, shape_attn)}, + "key": {"kernel": jax.random.normal(k2, shape_attn)}, + "value": {"kernel": jax.random.normal(k3, shape_attn)}, + "out": {"kernel": jax.random.normal(k4, shape_attn)}, + }, + "mlp": {"wi": {"kernel": jax.random.normal(k5, (emb, 4 * emb))}}, + }, + "shared_embedding": {"embedding": jax.random.normal(k6, (100, emb))}, + }, + } + + +def _build_lora_params(base, lora_config_dict, rng): + """Build a concrete LoRA tree (random arrays) matching `base`.""" + abs_tree = jax.tree_util.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=None), base) + lora_state, _ = get_lora_abstract_state_nnx(abs_tree, lora_config_dict) + + def _to_concrete(leaf, rng_key): + if leaf is None: + return None + return jax.random.normal(rng_key, leaf.shape, leaf.dtype) + + leaves, tree = jax.tree_util.tree_flatten(lora_state.params, is_leaf=lambda x: x is None) + rngs = jax.random.split(rng, max(1, len(leaves))) + out_leaves = [_to_concrete(l, r) for l, r in zip(leaves, rngs)] + return jax.tree_util.tree_unflatten(tree, out_leaves) + + +class TestApplyLoraNnx(unittest.TestCase): + """`apply_lora_on_base_params_nnx` round-trip and Linen-vs-NNX parity.""" + + def test_apply_then_unapply_is_identity(self): + rng = jax.random.key(42) + base_orig = _concrete_base(rng) + base = jax.tree_util.tree_map(jnp.copy, base_orig) + lora = _build_lora_params(base, _lora_config(rank=2, target_modules=("q_proj", "v_proj")), jax.random.key(7)) + apply_lora_on_base_params_nnx(base, lora, lora_scale_factor=0.5) + # query/value kernels were modified + self.assertFalse( + jnp.allclose( + base["decoder"]["layers"]["self_attention"]["query"]["kernel"], + base_orig["decoder"]["layers"]["self_attention"]["query"]["kernel"], + ) + ) + # key/out are untouched + np.testing.assert_array_equal( + np.asarray(base["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + ) + np.testing.assert_array_equal( + np.asarray(base["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + ) + unapply_lora_from_base_params_nnx(base, lora, lora_scale_factor=0.5) + np.testing.assert_allclose( + np.asarray(base["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + np.testing.assert_allclose( + np.asarray(base["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + rtol=1e-5, + 
atol=1e-6, + ) + + def test_numerical_parity_with_linen_apply(self): + """Same base+lora numbers → same kernel after apply, on either tree shape.""" + rng = jax.random.key(123) + base_nnx = _concrete_base(rng) + base_linen = {"params": jax.tree_util.tree_map(jnp.copy, base_nnx)} + lora = _build_lora_params(base_nnx, _lora_config(rank=2, target_modules=("q_proj",)), jax.random.key(5)) + apply_lora_on_base_params_nnx(base_nnx, lora, lora_scale_factor=0.7) + apply_lora_on_base_params(base_linen, {"params": lora}, lora_scale_factor=0.7) + np.testing.assert_allclose( + np.asarray(base_nnx["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + np.asarray(base_linen["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + rtol=1e-6, + ) + + def test_apply_with_unexpected_lora_key_raises(self): + base = _concrete_base() + bad = {"decoder": {"layers": {"self_attention": {"query": {"unexpected": jnp.zeros((4, 2))}}}}} + with self.assertRaises(ValueError): + apply_lora_on_base_params_nnx(base, bad) + + +class TestLinenLoraRegression(unittest.TestCase): + """Smoke tests for the Linen apply / unapply helpers (no other unit test exercises them).""" + + def _linen_pair(self, rng=None): + """Build a Linen-shape (with `{"params": ...}` outer wrapper) base + lora pair.""" + if rng is None: + rng = jax.random.key(99) + base_inner = _concrete_base(rng) + base = {"params": jax.tree_util.tree_map(jnp.copy, base_inner)} + lora_inner = _build_lora_params( + base_inner, + _lora_config(rank=2, target_modules=("q_proj", "v_proj")), + jax.random.key(7), + ) + lora = {"params": lora_inner} + return base, lora + + def test_linen_apply_then_unapply_is_identity(self): + base, lora = self._linen_pair() + base_orig = jax.tree_util.tree_map(jnp.copy, base) + apply_lora_on_base_params(base, lora, lora_scale_factor=0.5) + unapply_lora_from_base_params(base, lora, lora_scale_factor=0.5) + np.testing.assert_allclose( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + np.testing.assert_allclose( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + + def test_linen_apply_only_modifies_target_modules(self): + base, lora = self._linen_pair() + base_orig = jax.tree_util.tree_map(jnp.copy, base) + apply_lora_on_base_params(base, lora, lora_scale_factor=1.0) + # query and value are targets — must change. + self.assertFalse( + jnp.allclose( + base["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"], + base_orig["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"], + ) + ) + # key and out are non-target — must be untouched. + np.testing.assert_array_equal( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + ) + np.testing.assert_array_equal( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + ) + + +if __name__ == "__main__": + unittest.main()
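
Note (illustration only, not part of the patch): the round-trip and Linen-vs-NNX parity tests above implicitly assume one update rule — the LoRA delta is the rank-axis contraction of lora_a with lora_b, scaled by lora_scale_factor, added to (and later subtracted from) the base kernel in place. A minimal sketch of that assumed rule, using hypothetical helper names and only the shapes from the fixtures ((emb, r) x (r, heads, head_dim) for q_proj/v_proj, (emb, heads, r) x (r, head_dim) for o_proj):

    import jax.numpy as jnp

    def _lora_delta(lora_a, lora_b, scale):
      # Contract over the shared rank axis: lora_a (..., r) with lora_b (r, ...).
      return scale * jnp.tensordot(lora_a, lora_b, axes=([-1], [0]))

    def _apply(kernel, lora_a, lora_b, scale):
      # q_proj example: kernel (emb, heads, head_dim), lora_a (emb, r), lora_b (r, heads, head_dim).
      return kernel + _lora_delta(lora_a, lora_b, scale)

    def _unapply(kernel, lora_a, lora_b, scale):
      # Subtracting the same delta restores the base kernel up to float tolerance,
      # which is what the apply-then-unapply identity tests check with rtol=1e-5 / atol=1e-6.
      return kernel - _lora_delta(lora_a, lora_b, scale)

Whether the real helpers add the delta in place or return a new tree is an implementation detail of lora_utils; the tests only rely on the same contraction being used for both the Linen-shaped and NNX-shaped trees, which is why the parity test can compare the two kernels directly.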