Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_jupyter_notebooks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ jobs:
# Run Hugging Face authentication
hf auth login --token "$HF_TOKEN"

for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do
for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl,lora}*.ipynb; do
filename=$(basename "$notebook")
# TODO: Update runnner to v6e-8 as RL with LLama3.1-8b doesn't fit on v6e-4
if [[ "$filename" == "sft_llama3_demo_gpu.ipynb" || "$filename" == "maxtext_with_gepa.ipynb" ]]; then
Expand Down
8 changes: 8 additions & 0 deletions docs/guides.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ Interactive development guides for running MaxText on Google Colab or local Jupy
A step-by-step guide for the community to help expand MaxText's model library.
:::

:::{grid-item-card} 🎗️ LoRA Model Bringup
:link: guides/lora_model_bringup
:link-type: doc

Learn how to integrate Low-Rank Adaptation (LoRA) support for a new model architecture.
:::

:::{grid-item-card} 🎓 Distillation
:link: guides/distillation
:link-type: doc
Expand All @@ -89,6 +96,7 @@ guides/checkpointing_solutions.md
guides/monitoring_and_debugging.md
guides/run_python_notebook.md
guides/model_bringup.md
guides/lora_model_bringup.md
guides/distillation.md
guides/eval_framework.md
```
113 changes: 113 additions & 0 deletions docs/guides/lora_model_bringup.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
<!--
Copyright 2026 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Adding a New Model for LoRA Fine-Tuning

This guide explains how to add Low-Rank Adaptation (LoRA) support for a new model architecture in MaxText.

MaxText leverages [Tunix](https://github.com/google/tunix) and [Qwix](https://github.com/google/qwix) to support Parameter-Efficient Fine-Tuning (PEFT) on JAX/NNX model definitions. Since the architecture uses modular APIs, adding LoRA support for a new model is highly streamlined.

______________________________________________________________________

## 1. Step-by-Step Bring-up Guide for NNX LoRA

To enable LoRA support for a new model, follow these two simple steps:

### Step 1.1: Verify Base Model Support

The target model architecture must already be implemented and supported as a base model in MaxText.

- The JAX/NNX model definition should be located under `src/maxtext/models/` (e.g., \[gemma3.py\](../../src/maxtext/models/gemma3.py)).
- The model configurations must be registered and runnable for baseline pre-training or full fine-tuning.

### Step 1.2: Define Trainable LoRA Target Modules

Add a recommended target pattern for your model architecture prefix in \[src/maxtext/configs/post_train/lora_module_path.yml\](../../src/maxtext/configs/post_train/lora_module_path.yml):

```yaml
your_model_prefix: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
```

> [!NOTE]
> MaxText's `_get_lora_module_path` in `lora_utils.py` automatically handles both **scanned** (e.g., `layers/0/self_attention/...`) and **unscanned** (e.g., `layers/self_attention/...`) layer formats by injecting an optional layer index regex. You only need to define standard, unscanned paths.

If no prefix matches your model name, MaxText falls back to the `default` pattern:

```yaml
default: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
```

______________________________________________________________________

## 2. Integrating Custom Weight Mappings (When is it needed?)
Comment thread
igorts-git marked this conversation as resolved.

Determining whether you need to implement custom weight mappings depends entirely on your downstream workflow:

### Scenario A: SFT Training & Conversion to PEFT (No Mapping Needed)

If you only need to run SFT fine-tuning with LoRA and then export the adapter back to Hugging Face format using `to_huggingface.py`, **you do not need to write any custom weight mappings.**

- The conversion utility automatically maps, scales, and formats the LoRA adapter parameters back into standard Hugging Face PEFT format based on the base model's existing weight mapping.

### Scenario B: Decoding with the MaxText vLLM Adapter (Mapping is Required)

If you want to perform decoding or run high-performance serving on your adapted model using the **MaxText vLLM adapter** (e.g., via `vllm_decode`), **you must define and register a custom weight mapping.** This allows the vLLM JAX wrapper to dynamically map and feed weights to the vLLM engine.

To add weight mapping for vLLM decode:

1. **Create a Weight Mapping Config**:
Create a new file in \[src/maxtext/integration/tunix/weight_mapping/\](../../src/maxtext/integration/tunix/weight_mapping/) (e.g., `your_model.py`) defining a mapping dataclass. You can refer to \[gemma3.py\](../../src/maxtext/integration/tunix/weight_mapping/gemma3.py) or \[llama3.py\](../../src/maxtext/integration/tunix/weight_mapping/llama3.py) as templates.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our documentation is usually surfaced via https://maxtext.readthedocs.io/en/latest/index.html I am not sure that the hyperlinks to code in GitHub would work here. Please check how code links are implemented in other docs. We are also trying to keep docs consistent w.r.t. MaxText release versions. Such that if someone is reading docs for version 0.2.3, all hyperlinks also point to the same version.

@melissawm do you know how to correctly link to code in GitHub?

@melissawm melissawm Jun 25, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, these will not work. The readthedocs site can only see relative links to documents under the docs/ folder. To link to files under src/ or other folders, the best way is to use the github link (in this case, https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/integration/tunix/weight_mapping/llama3.py)


Your class should specify:

- `to_hf_mapping()`: Maps MaxText base parameters to Hugging Face parameters and specifies their sharding axes.
- `to_hf_hook_fns()`: Custom hook functions for complex weight transformations (e.g., RoPE reordering or query scaling).
- `lora_to_hf_mappings()`: Custom mapping for LoRA weights if they require different handling.

2. **Register the Mapping**:
Register your new class in \[src/maxtext/integration/tunix/weight_mapping/__init__.py\](../../src/maxtext/integration/tunix/weight_mapping/__init__.py) inside the `StandaloneVllmWeightMapping` class:

```python
# Inside StandaloneVllmWeightMapping
if name.startswith("your_model_name"):
return YOUR_MODEL_VLLM_MAPPING
```

______________________________________________________________________

## 3. Verifying Your Custom LoRA Targets

If you are developing or bringing up your model architecture interactively (e.g., in a Python interpreter or Jupyter notebook), you can verify which layers are wrapped with LoRA adapters by inspecting the model's module graph:

```python
import re
from flax import nnx
from maxtext.utils import model_creation_utils, lora_utils

# 1. Create model, mesh, and load config
model, mesh = model_creation_utils.from_pretrained(mt_config)

# 2. Extract lora path regex and compile
compiled_module_path = re.compile(lora_utils._get_lora_module_path(mt_config))

# 3. Iterate over modules to see exactly which ones matched your pattern
for path, _ in nnx.iter_modules(model):
module_path = "/".join(str(p) for p in path)
if compiled_module_path.search(module_path):
print(f"Matched and wrapped with LoRA: {module_path}")
```

This programmatic verification allows you to inspect and traverse parameters interactively during development.
4 changes: 4 additions & 0 deletions docs/guides/model_bringup.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ Please ensure all items on the following checklist are completed before finalizi

- [ ] Create a user guide and post an announcement in the MaxText repo.

5. (Optional) LoRA Support

- [ ] Integrate LoRA support for the newly onboarded model by following the [LoRA Model Bringup Guide](lora_model_bringup.md).

## Community Q&A (FAQ)

**Q: How do I debug code inside a JAX JIT function?**
Expand Down
2 changes: 2 additions & 0 deletions docs/tutorials/post_training_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ MaxText was co-designed with key Google led innovations to provide a unified pos
- [SFT on Multi-Host TPUs](./posttraining/sft_on_multi_host.md)
- **LoRA (Low-Rank Adaptation)**
- [LoRA on Single-Host TPUs](./posttraining/lora.md)
- [LoRA on Multi-Host TPUs](./posttraining/lora_on_multi_host.md)
- **DPO (Direct Preference Optimization) and ORPO (Odds-Ratio Policy Optimization)**
- [DPO/ORPO on Single-Host TPUs](./posttraining/dpo.md)
- **Multimodal SFT**
Expand Down Expand Up @@ -76,6 +77,7 @@ posttraining/rl_on_multi_host.md
posttraining/rl_qwen3_30b.md
posttraining/knowledge_distillation.md
posttraining/lora.md
posttraining/lora_on_multi_host.md
posttraining/multimodal.md
posttraining/full_finetuning.md
posttraining/gepa_optimization.md
Expand Down
3 changes: 0 additions & 3 deletions docs/tutorials/posttraining/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ export DATASET_NAME=<DATASET_NAME> # e.g., openai/gsm8k
export TRAIN_SPLIT=<TRAIN_SPLIT> # e.g., train
export HF_DATA_DIR=<DATASET_PATH> # e.g., main
export TRAIN_DATA_COLUMNS=<DATA_COLUMNS> # e.g., ['question','answer']
export CHAT_TEMPLATE_PATH=<TEMPLATE_PATH> # e.g., maxtext/examples/chat_templates/math_qa.json

# -- LoRA Conversion configuration (Optional) --
export HF_LORA_ADAPTER_PATH=<HF_LORA_ADAPTER_PATH> # e.g., 'username/adapter-name'
Expand Down Expand Up @@ -118,7 +117,6 @@ python3 -m maxtext.trainers.post_train.sft.train_sft \
per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \
max_target_length="${MAX_TARGET_LENGTH?}" \
learning_rate="${LEARNING_RATE?}" \
chat_template_path="${CHAT_TEMPLATE_PATH?}" \
enable_nnx=True \
pure_nnx_decoder=True \
lora.enable_lora=True \
Expand Down Expand Up @@ -176,7 +174,6 @@ python3 -m maxtext.trainers.post_train.sft.train_sft \
per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \
max_target_length="${MAX_TARGET_LENGTH?}" \
learning_rate="${LEARNING_RATE?}" \
chat_template_path="${CHAT_TEMPLATE_PATH?}" \
enable_nnx=True \
pure_nnx_decoder=True \
lora.enable_lora=True \
Expand Down
Loading
Loading