
docs: add LoRA tutorial and layer customization guide #3682

Draft
chiajunglien wants to merge 8 commits into AI-Hypercomputer:jackyf/feat/lora-nnx from CIeNET-International:emma/lora-tutorial-final-v2

Conversation

@chiajunglien

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a GitHub issue, please include a link, e.g.:
FIXES: b/123456
FIXES: #123456

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@chiajunglien chiajunglien force-pushed the emma/lora-tutorial-final-v2 branch from 5388bf9 to 0d49068 Compare April 16, 2026 03:43
@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch from 6c12235 to 0dfeb76 Compare April 16, 2026 09:25
@chiajunglien chiajunglien force-pushed the emma/lora-tutorial-final-v2 branch 4 times, most recently from 058bd82 to 5ff1ff1 Compare April 16, 2026 10:12
@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch from 0dfeb76 to 2f91ad8 Compare April 16, 2026 10:32
@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch from 2f91ad8 to f5736a1 Compare April 16, 2026 10:54
@chiajunglien chiajunglien force-pushed the emma/lora-tutorial-final-v2 branch from 5ff1ff1 to 4e15311 Compare April 17, 2026 01:04
@chiajunglien chiajunglien force-pushed the emma/lora-tutorial-final-v2 branch from b008a24 to f5736a1 Compare April 17, 2026 01:21
@chiajunglien chiajunglien reopened this Apr 17, 2026
@chiajunglien chiajunglien force-pushed the emma/lora-tutorial-final-v2 branch from 99f5268 to f5736a1 Compare April 17, 2026 02:30
@chiajunglien chiajunglien reopened this Apr 17, 2026
@chiajunglien chiajunglien force-pushed the emma/lora-tutorial-final-v2 branch from 58c7a09 to 800530d Compare April 17, 2026 08:54
export HF_AUTH_TOKEN="hf_YOUR_TOKEN"
# Mode 3: Convert Base Model only
python src/maxtext/checkpoint_conversion/to_huggingface.py \
  src/maxtext/configs/base.yml \
Collaborator

Could we remove base.yml?

Comment on lines +258 to +259
has_base = len(str(base_path).strip()) > 0
has_lora = len(str(lora_path).strip()) > 0
Collaborator

Why do we need these? Could we just branch on base_path and lora_path directly?
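If the two paths are plain strings (or None), the reviewer's suggestion could reduce to a truthiness check; a minimal sketch, where resolve_mode and the mode names are hypothetical and not from the PR:

```python
def resolve_mode(base_path, lora_path):
  """Hypothetical helper: pick the conversion mode from the given paths.

  Assumes base_path / lora_path are plain strings or None, so empty or
  whitespace-only values are treated as "not provided".
  """
  has_base = bool(base_path and str(base_path).strip())
  has_lora = bool(lora_path and str(lora_path).strip())
  if has_base and has_lora:
    return "merge"
  if has_lora:
    return "lora_only"
  if has_base:
    return "base_only"
  raise ValueError("Need at least one of base_path or lora_path")
```

This keeps the `str().strip()` guard for whitespace-only inputs while letting `None` and `""` fall through the truthiness check.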

Comment on lines +273 to +285
detected_r = 16
if has_lora:
  max_logging.log(f"Loading LoRA: {lora_path}")
  lora_checkpoint_raw = load_orbax_checkpoint(config, custom_path=lora_path)
  lora_params = detect_and_extract_checkpoint(lora_checkpoint_raw)

  # Auto-detect Rank (r) from lora_a shape
  for k, v in lora_params.items():
    if k.endswith("lora_a"):
      val = np.array(v["value"]) if isinstance(v, dict) else np.array(v)
      detected_r = val.shape[-1]
      max_logging.log(f"Auto-detected LoRA Rank (r): {detected_r}")
      break
Collaborator

Could we always use auto-detection for r instead of hardcoding it or reading it from the config?

raise ValueError(f"HF Tokenizer ID not found for model key: {model_key}")
hf_token = config.hf_access_token
hf_tokenizer_id = HF_IDS[model_key]
tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_id, token=hf_token)
Collaborator

Keep the original pattern, which declares hf_token before using it.

Comment on lines +334 to +353
for mt_key, content in lora_params.items():
  if mt_key.endswith("-kernel_lora_a"):
    lora_b_key = mt_key.replace("-kernel_lora_a", "-kernel_lora_b")
    lookup_key = mt_key.replace("-kernel_lora_a", "-kernel")
    if lookup_key in param_map:
      wa = np.array(content["value"]) if isinstance(content, dict) else np.array(content)
      wb = np.array(lora_params[lora_b_key]["value"]) if isinstance(lora_params[lora_b_key], dict) else np.array(lora_params[lora_b_key])
      hf_paths = param_map[lookup_key]
      if not isinstance(hf_paths, list):
        hf_paths = [hf_paths]

      for i in range(min(wa.shape[1] if wa.ndim > 1 else 1, len(hf_paths))):
        name = hf_paths[i].replace(".weight", "")
        found_hf_modules.add(hf_paths[i].split(".")[-2])
        if wa.ndim == 3:
          transformed_hf_weights[f"base_model.model.{name}.lora_A.weight"] = wa[:, i, :].T
          transformed_hf_weights[f"base_model.model.{name}.lora_B.weight"] = wb[:, i, :].T
        else:
          transformed_hf_weights[f"base_model.model.{name}.lora_A.weight"] = wa.T
          transformed_hf_weights[f"base_model.model.{name}.lora_B.weight"] = wb.T
      max_logging.log(f"✅ Mapped LoRA: {lookup_key}")
Collaborator

Make this a helper function, and see whether we can reuse the code here with merge_logic above.

Comment on lines +289 to +314
def merge_logic(base_dict, lora_dict):
  for lora_key, a_val in lora_dict.items():
    if lora_key.endswith("-kernel_lora_a"):
      base_key = lora_key.replace("-kernel_lora_a", "-kernel")
      lora_b_key = lora_key.replace("-kernel_lora_a", "-kernel_lora_b")
      if base_key in base_dict:
        wa = np.array(a_val["value"]) if isinstance(a_val, dict) else np.array(a_val)
        wb = np.array(lora_dict[lora_b_key]["value"]) if isinstance(lora_dict[lora_b_key], dict) else np.array(lora_dict[lora_b_key])
        base_node = base_dict[base_key]
        base_w = np.array(base_node["value"]) if isinstance(base_node, dict) else np.array(base_node)

        scaling = getattr(config, "lora_alpha", detected_r * 2) / detected_r
        wa_f32, wb_f32 = wa.astype(np.float32), wb.astype(np.float32)
        if wa.ndim == 3:
          delta_w = np.einsum('ihr,rho->ioh', wa_f32, wb_f32)
          # Handle Scanned Layer shape permutations
          if base_w.shape != delta_w.shape:
            if base_w.shape == (wa.shape[1], wa.shape[0], wb.shape[2]):
              delta_w = delta_w.transpose(2, 0, 1)
            elif base_w.shape == (wa.shape[0], wa.shape[1], wb.shape[2]):
              delta_w = delta_w.transpose(0, 2, 1)
            else:
              delta_w = delta_w.reshape(base_w.shape)
        else:
          delta_w = wa_f32 @ wb_f32

        final_w = (base_w.astype(np.float32) + (scaling * delta_w)).astype(wa.dtype)
        if isinstance(base_dict[base_key], dict):
          base_dict[base_key]["value"] = final_w
        else:
          base_dict[base_key] = final_w
        max_logging.log(f"✅ Merged: {base_key}")
Collaborator

Make this a helper function and see whether it can be reused in the other modes.
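For reference, merge_logic above implements the standard LoRA merge W' = W + (alpha/r)·(A·B). A self-contained numeric sanity check of the 2-D branch, with shapes invented here and the alpha = 2r default taken from the snippet:

```python
import numpy as np

rng = np.random.default_rng(0)
r, d_in, d_out = 4, 8, 8
detected_r = r
alpha = detected_r * 2          # the snippet's default when lora_alpha is unset
scaling = alpha / detected_r    # -> 2.0

base_w = rng.standard_normal((d_in, d_out)).astype(np.float32)
wa = rng.standard_normal((d_in, r)).astype(np.float32)   # lora_a
wb = rng.standard_normal((r, d_out)).astype(np.float32)  # lora_b

# 2-D branch: delta_w = wa @ wb, accumulated in float32 then cast back
merged = (base_w + scaling * (wa @ wb)).astype(base_w.dtype)

# Merging the weights is equivalent to applying the base kernel and the
# scaled LoRA path separately at inference time.
x = rng.standard_normal((3, d_in)).astype(np.float32)
assert np.allclose(x @ merged, x @ base_w + scaling * ((x @ wa) @ wb), rtol=1e-4, atol=1e-4)
```

The equivalence follows from distributivity of matrix multiplication, which is also why a factored-out merge helper could safely back both the merged-export and LoRA-only modes.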



def load_orbax_checkpoint(config) -> dict:
def load_orbax_checkpoint(config, custom_path=None) -> dict:
Collaborator

Why do we need custom_path? epath should work for both local and gs:// paths, consistent with load_parameters_path.

Author

I've updated load_orbax_checkpoint to accept a direct checkpoint_path string. This makes the function more generic and leverages epath to handle both local and GCS paths seamlessly. In the main logic, I now pass config.load_parameters_path or config.lora_input_adapters_path explicitly based on the mode.
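Based on the author's description, the mode-based call site might look like the following sketch (pick_checkpoint_path and the mode strings are hypothetical names, not from the PR):

```python
from types import SimpleNamespace

def pick_checkpoint_path(config, mode):
  """Hypothetical call-site helper: choose which config path is passed
  as the explicit checkpoint_path to load_orbax_checkpoint for each
  conversion mode."""
  if mode == "lora_only":
    return config.lora_input_adapters_path
  return config.load_parameters_path

# Example with a stand-in config object:
cfg = SimpleNamespace(
    load_parameters_path="gs://base/ckpt",
    lora_input_adapters_path="gs://lora/adapter",
)
```

Passing the path explicitly keeps load_orbax_checkpoint generic and lets epath handle local and GCS locations uniformly.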

Comment on lines +373 to +378
# HF LoRA keys: base_model.model.layers.{layer}.{module}.lora_A/B.weight

# Clean up LoRA suffixes to get the base module path
# e.g. ...q_proj.lora_A.weight -> ...q_proj
hf_param_key = hf_key.replace(".lora_A.weight", "").replace(".lora_B.weight", "")
hf_param_key = hf_param_key.replace(".lora_A", "").replace(".lora_B", "")
Collaborator

Please standardize the LoRA parameter names across all scripts to use either lora_A or lora_a. Personally, I suggest always using lowercase.

Author

I've standardized all internal logic and variables to use lowercase lora_a/lora_b across the scripts to keep it consistent as you suggested.

However, for the final output in to_huggingface.py (specifically for LORA_ONLY mode), I kept the keys as uppercase lora_A/lora_B. I found that downstream inference engines like vLLM have hard-coded checks for these keys and will throw a ValueError if they are lowercase.

(EngineCore pid=967793)   File "/mnt/disks/emma_data/vllm/vllm/lora/utils.py", line 194, in parse_fine_tuned_lora_name
(EngineCore pid=967793)     raise ValueError(f"{name} is unsupported LoRA weight")
(EngineCore pid=967793) ValueError: base_model.model.model.layers.0.mlp.down_proj.lora_a.weight is unsupported LoRA weight
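Given that constraint, keeping lowercase names internally and renaming only at export time could be a single substitution; a sketch (the helper name is hypothetical, and vLLM's parse_fine_tuned_lora_name requirement is taken from the traceback above):

```python
def to_hf_lora_key(internal_key: str) -> str:
  """Map internal lowercase lora_a/lora_b key segments to the uppercase
  lora_A/lora_B casing that downstream loaders such as vLLM's
  parse_fine_tuned_lora_name expect in exported adapters."""
  return internal_key.replace(".lora_a.", ".lora_A.").replace(".lora_b.", ".lora_B.")
```

Applying this once when writing the LORA_ONLY output keeps every other script free to use lowercase consistently.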

Comment on lines +331 to +351
if os.path.isdir(adapter_path):
  # Local directory
  adapter_dir = epath.Path(adapter_path)
  adapter_files = list(adapter_dir.glob("*.safetensors"))
  if not adapter_files:
    adapter_files = list(adapter_dir.glob("*.bin"))
  if not adapter_files:
    raise ValueError(f"No LoRA adapter files found in {adapter_path}")
  adapter_file = adapter_files[0]
else:
  # Assume it's a HF Hub repo ID
  try:
    files = list_repo_files(adapter_path, token=hf_access_token)
    safetensor_files = [f for f in files if f.endswith(".safetensors")]
    if not safetensor_files:
      bin_files = [f for f in files if f.endswith(".bin")]
      if not bin_files:
        raise ValueError(f"No LoRA adapter files found in {adapter_path}")
      adapter_file = bin_files[0]
    else:
      adapter_file = safetensor_files[0]
Collaborator

Could we remove the logic for .bin?
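Dropping the .bin fallback would collapse the local branch to a single glob; a sketch using pathlib for self-containment (the real code uses epath.Path, which exposes a compatible glob):

```python
from pathlib import Path

def find_adapter_file(adapter_path: str) -> Path:
  """Hypothetical safetensors-only version of the local-directory branch:
  pick the first .safetensors adapter, with no .bin fallback."""
  files = sorted(Path(adapter_path).glob("*.safetensors"))
  if not files:
    raise ValueError(f"No LoRA adapter files found in {adapter_path}")
  return files[0]
```

The Hub branch would simplify the same way, filtering list_repo_files for ".safetensors" only.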

@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch 3 times, most recently from 48f0064 to 669b501 Compare April 21, 2026 08:45
@RexBearIU RexBearIU force-pushed the jackyf/feat/lora-nnx branch from 669b501 to 9634d50 Compare April 21, 2026 08:48