docs: add LoRA tutorial and layer customization guide#3682
docs: add LoRA tutorial and layer customization guide#3682chiajunglien wants to merge 8 commits intoAI-Hypercomputer:jackyf/feat/lora-nnxfrom
Conversation
5388bf9 to
0d49068
Compare
6c12235 to
0dfeb76
Compare
058bd82 to
5ff1ff1
Compare
0dfeb76 to
2f91ad8
Compare
2f91ad8 to
f5736a1
Compare
5ff1ff1 to
4e15311
Compare
b008a24 to
f5736a1
Compare
99f5268 to
f5736a1
Compare
58c7a09 to
800530d
Compare
| export HF_AUTH_TOKEN="hf_YOUR_TOKEN" | ||
| # Mode 3: Convert Base Model only | ||
| python src/maxtext/checkpoint_conversion/to_huggingface.py \ | ||
| src/maxtext/configs/base.yml \ |
There was a problem hiding this comment.
Could we remove base.yml?
| has_base = len(str(base_path).strip()) > 0 | ||
| has_lora = len(str(lora_path).strip()) > 0 |
There was a problem hiding this comment.
Why do we need these? Could we decide directly from base_path and lora_path?
| detected_r = 16 | ||
| if has_lora: | ||
| max_logging.log(f"Loading LoRA: {lora_path}") | ||
| lora_checkpoint_raw = load_orbax_checkpoint(config, custom_path=lora_path) | ||
| lora_params = detect_and_extract_checkpoint(lora_checkpoint_raw) | ||
|
|
||
| # Auto-detect Rank (r) from lora_a shape | ||
| for k, v in lora_params.items(): | ||
| if k.endswith("lora_a"): | ||
| val = np.array(v["value"]) if isinstance(v, dict) else np.array(v) | ||
| detected_r = val.shape[-1] | ||
| max_logging.log(f"Auto-detected LoRA Rank (r): {detected_r}") | ||
| break |
There was a problem hiding this comment.
Could we always use auto-detection for r instead of hardcoding it or reading it from the config?
| raise ValueError(f"HF Tokenizer ID not found for model key: {model_key}") | ||
| hf_token = config.hf_access_token | ||
| hf_tokenizer_id = HF_IDS[model_key] | ||
| tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_id, token=hf_token) |
There was a problem hiding this comment.
Keep the original pattern, which declares hf_token first and then uses it — that reads better.
| for mt_key, content in lora_params.items(): | ||
| if mt_key.endswith("-kernel_lora_a"): | ||
| lora_b_key = mt_key.replace("-kernel_lora_a", "-kernel_lora_b") | ||
| lookup_key = mt_key.replace("-kernel_lora_a", "-kernel") | ||
| if lookup_key in param_map: | ||
| wa = np.array(content["value"]) if isinstance(content, dict) else np.array(content) | ||
| wb = np.array(lora_params[lora_b_key]["value"]) if isinstance(lora_params[lora_b_key], dict) else np.array(lora_params[lora_b_key]) | ||
| hf_paths = param_map[lookup_key] | ||
| if not isinstance(hf_paths, list): hf_paths = [hf_paths] | ||
|
|
||
| for i in range(min(wa.shape[1] if wa.ndim > 1 else 1, len(hf_paths))): | ||
| name = hf_paths[i].replace(".weight", "") | ||
| found_hf_modules.add(hf_paths[i].split(".")[-2]) | ||
| if wa.ndim == 3: | ||
| transformed_hf_weights[f"base_model.model.{name}.lora_A.weight"] = wa[:, i, :].T | ||
| transformed_hf_weights[f"base_model.model.{name}.lora_B.weight"] = wb[:, i, :].T | ||
| else: | ||
| transformed_hf_weights[f"base_model.model.{name}.lora_A.weight"] = wa.T | ||
| transformed_hf_weights[f"base_model.model.{name}.lora_B.weight"] = wb.T | ||
| max_logging.log(f"✅ Mapped LoRA: {lookup_key}") |
There was a problem hiding this comment.
Make this a helper function, and see whether the code here can be reused with merge_logic above.
| def merge_logic(base_dict, lora_dict): | ||
| for lora_key, a_val in lora_dict.items(): | ||
| if lora_key.endswith("-kernel_lora_a"): | ||
| base_key = lora_key.replace("-kernel_lora_a", "-kernel") | ||
| lora_b_key = lora_key.replace("-kernel_lora_a", "-kernel_lora_b") | ||
| if base_key in base_dict: | ||
| wa = np.array(a_val["value"]) if isinstance(a_val, dict) else np.array(a_val) | ||
| wb = np.array(lora_dict[lora_b_key]["value"]) if isinstance(lora_dict[lora_b_key], dict) else np.array(lora_dict[lora_b_key]) | ||
| base_node = base_dict[base_key] | ||
| base_w = np.array(base_node["value"]) if isinstance(base_node, dict) else np.array(base_node) | ||
|
|
||
| scaling = getattr(config, "lora_alpha", detected_r * 2) / detected_r | ||
| wa_f32, wb_f32 = wa.astype(np.float32), wb.astype(np.float32) | ||
| if wa.ndim == 3: | ||
| delta_w = np.einsum('ihr,rho->ioh', wa_f32, wb_f32) | ||
| # Handle Scanned Layer shape permutations | ||
| if base_w.shape != delta_w.shape: | ||
| if base_w.shape == (wa.shape[1], wa.shape[0], wb.shape[2]): delta_w = delta_w.transpose(2, 0, 1) | ||
| elif base_w.shape == (wa.shape[0], wa.shape[1], wb.shape[2]): delta_w = delta_w.transpose(0, 2, 1) | ||
| else: delta_w = delta_w.reshape(base_w.shape) | ||
| else: delta_w = wa_f32 @ wb_f32 | ||
|
|
||
| final_w = (base_w.astype(np.float32) + (scaling * delta_w)).astype(wa.dtype) | ||
| if isinstance(base_dict[base_key], dict): base_dict[base_key]["value"] = final_w | ||
| else: base_dict[base_key] = final_w | ||
| max_logging.log(f"✅ Merged: {base_key}") |
There was a problem hiding this comment.
Make this a helper function and see whether it can be reused in the other modes.
|
|
||
|
|
||
| def load_orbax_checkpoint(config) -> dict: | ||
| def load_orbax_checkpoint(config, custom_path=None) -> dict: |
There was a problem hiding this comment.
Why do we need custom_path? epath should work for both local and gs:// paths, which aligns with load_parameters_path.
There was a problem hiding this comment.
I've updated load_orbax_checkpoint to accept a direct checkpoint_path string. This makes the function more generic and leverages epath to handle both local and GCS paths seamlessly. In the main logic, I now pass config.load_parameters_path or config.lora_input_adapters_path explicitly based on the mode.
| # HF LoRA keys: base_model.model.layers.{layer}.{module}.lora_A/B.weight | ||
|
|
||
| # Clean up LoRA suffixes to get the base module path | ||
| # e.g. ...q_proj.lora_A.weight -> ...q_proj | ||
| hf_param_key = hf_key.replace(".lora_A.weight", "").replace(".lora_B.weight", "") | ||
| hf_param_key = hf_param_key.replace(".lora_A", "").replace(".lora_B", "") |
There was a problem hiding this comment.
Please standardize the LoRA parameter names across all scripts to use either lora_A or lora_a consistently. Personally, I suggest always using lowercase.
There was a problem hiding this comment.
I've standardized all internal logic and variables to use lowercase lora_a/lora_b across the scripts, to keep things consistent as you suggested.
However, for the final output in to_huggingface.py (specifically for LORA_ONLY mode), I kept the keys as uppercase lora_A/lora_B. I found that downstream inference engines such as vLLM have hard-coded checks for these keys and will raise a ValueError if they are lowercase:
(EngineCore pid=967793) File "/mnt/disks/emma_data/vllm/vllm/lora/utils.py", line 194, in parse_fine_tuned_lora_name
(EngineCore pid=967793) raise ValueError(f"{name} is unsupported LoRA weight")
(EngineCore pid=967793) ValueError: base_model.model.model.layers.0.mlp.down_proj.lora_a.weight is unsupported LoRA weight
| if os.path.isdir(adapter_path): | ||
| # Local directory | ||
| adapter_dir = epath.Path(adapter_path) | ||
| adapter_files = list(adapter_dir.glob("*.safetensors")) | ||
| if not adapter_files: | ||
| adapter_files = list(adapter_dir.glob("*.bin")) | ||
| if not adapter_files: | ||
| raise ValueError(f"No LoRA adapter files found in {adapter_path}") | ||
| adapter_file = adapter_files[0] | ||
| else: | ||
| # Assume it's a HF Hub repo ID | ||
| try: | ||
| files = list_repo_files(adapter_path, token=hf_access_token) | ||
| safetensor_files = [f for f in files if f.endswith(".safetensors")] | ||
| if not safetensor_files: | ||
| bin_files = [f for f in files if f.endswith(".bin")] | ||
| if not bin_files: | ||
| raise ValueError(f"No LoRA adapter files found in {adapter_path}") | ||
| adapter_file = bin_files[0] | ||
| else: | ||
| adapter_file = safetensor_files[0] |
There was a problem hiding this comment.
Could we remove the logic for .bin files?
48f0064 to
669b501
Compare
669b501 to
9634d50
Compare
Description
Start with a short description of what the PR does and how this is a change from
the past.
The rest of the description includes relevant details and context, examples:
If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/123456
FIXES: #123456
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.
Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.