feat(lora): save/restore LoRA config in checkpoint metadata #4269
RexBearIU wants to merge 1 commit into
shralex left a comment
Thanks Jackie! One significant thing missing from this PR is using the metadata file on the checkpoint restore path.
187905b to cd17578
Hi @shralex, thank you for the feedback! I have fully addressed your comments with the following changes:
Please let me know if you would like any other enhancements!
max_logging.log(f"Elapse for transform and save: {(time.time() - start) / 60:.2f} min")

def sync_lora_metadata(config) -> None:
Can we import and reuse this function from lora_utils?
Hi @shralex, in our latest iteration we actually removed sync_lora_metadata from lora_utils.py entirely! This was done to keep SFT training/fine-tuning paths strict and 'fail-fast' on configuration mismatches (letting runs crash immediately on mismatched checkpoint configs). Since the synchronization function is no longer part of lora_utils.py, we keep it isolated exclusively inside to_huggingface.py for conversion only.
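The distinction drawn in this reply, strict validation on training paths versus metadata-driven override during conversion, can be illustrated with a minimal sketch. The `LoraConfig` dataclass and both helper names are hypothetical stand-ins, not MaxText's actual types or functions, and the `saved` dict is assumed to hold whatever was stored with the checkpoint.

```python
from dataclasses import dataclass


@dataclass
class LoraConfig:
  """Hypothetical stand-in for the LoRA block of a MaxText config."""
  lora_rank: int
  lora_alpha: float


def check_lora_matches(config: LoraConfig, saved: dict) -> None:
  """Fail-fast path: raise if the active config disagrees with the checkpoint."""
  for key in ("lora_rank", "lora_alpha"):
    if key in saved and getattr(config, key) != saved[key]:
      raise ValueError(
          f"{key} mismatch: config has {getattr(config, key)}, "
          f"checkpoint metadata has {saved[key]}"
      )


def override_lora_from_saved(config: LoraConfig, saved: dict) -> LoraConfig:
  """Conversion path: prefer the values stored with the checkpoint."""
  return LoraConfig(
      lora_rank=saved.get("lora_rank", config.lora_rank),
      lora_alpha=saved.get("lora_alpha", config.lora_alpha),
  )
```

Under this sketch, a training run with a mismatched config stops immediately, while a conversion run silently adopts the checkpoint's values.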
Moved back to lora_utils for reuse.
cd17578 to 1b15640
1b15640 to ae44adc
69c78a7 to a701719
a701719 to 07c5e19
Added the logic to reuse the metadata on checkpoint restore.
43370d8 to 5940e65
5940e65 to 9bc253e
Description
This PR implements native serialization of LoRA configuration parameters (`lora_rank`, `lora_alpha`) in standard Orbax `_CHECKPOINT_METADATA` files, and automatically restores them during checkpoint-to-Hugging Face conversion.

Why is this change being made?
Previously, users had to manually supply matching `lora.lora_rank` and `lora.lora_alpha` parameters when converting MaxText checkpoints to Hugging Face format. Storing them in Orbax metadata removes that manual step and the risk of a mismatch (resolves @igorts-git's request in #3970).

Key Implementation Details
- In `save_checkpoint` (`checkpointing.py`), we save the active `config.lora` block under the `"lora"` key in Orbax's `custom_metadata` when a LoRA rank is specified.
- In `main` (`to_huggingface.py`), `sync_lora_metadata` reads the custom metadata from `lora_restore_path` via `ocp.StandardCheckpointer` and overrides the active config parameters during conversion.
- `hf_checkpoint_conversion_test.py` now uses global top-level imports in place of dynamically loaded inline imports, and the `json` import is removed since the JSON string is written directly.

BUGS: #3970
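The save and restore halves described above can be sketched without Orbax, treating `custom_metadata` as a plain dict. This is an illustration only: `build_custom_metadata` and `apply_lora_metadata` are hypothetical names, the config is modeled with `SimpleNamespace`, and the actual PR passes the dict through Orbax and reads it back with `ocp.StandardCheckpointer` rather than calling helpers like these.

```python
from types import SimpleNamespace


def build_custom_metadata(config) -> dict:
  """Return the custom-metadata payload; include "lora" only when a rank is set."""
  lora = getattr(config, "lora", None)
  if lora is None or not getattr(lora, "lora_rank", 0):
    return {}
  # Copy the whole LoRA block so later config mutations do not leak in.
  return {"lora": dict(vars(lora))}


def apply_lora_metadata(config, custom_metadata: dict) -> None:
  """Override the active config's LoRA parameters in place from stored metadata."""
  for key, value in (custom_metadata or {}).get("lora", {}).items():
    setattr(config.lora, key, value)


# Training side: the rank is set, so the block is recorded.
train_cfg = SimpleNamespace(lora=SimpleNamespace(lora_rank=16, lora_alpha=32.0))
metadata = build_custom_metadata(train_cfg)

# Conversion side: the user supplied no LoRA values; metadata fills them in.
convert_cfg = SimpleNamespace(lora=SimpleNamespace(lora_rank=0, lora_alpha=0.0))
apply_lora_metadata(convert_cfg, metadata)
```

Returning an empty dict when no rank is specified keeps non-LoRA checkpoints free of a `"lora"` entry, which matches the condition stated in the first bullet.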
Tests
We have verified the implementation with suite-level and individual unit tests:
- `SyncLoRAMetadataTest` in `tests/unit/hf_checkpoint_conversion_test.py` verifies the auto-resolving mechanism during Hugging Face conversion.

`python tests/unit/hf_checkpoint_conversion_test.py`

All tests pass successfully.
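For readers who want a feel for this kind of test, here is a rough, self-contained sketch that writes a JSON string directly to a temporary metadata file and checks that the stored values win. It is not the PR's `SyncLoRAMetadataTest`: `read_lora_metadata` and `SyncSketchTest` are hypothetical, and the sketch assumes the metadata file is JSON with a top-level `custom_metadata` key, whereas the PR reads it through `ocp.StandardCheckpointer`.

```python
import json
import tempfile
import unittest
from pathlib import Path
from types import SimpleNamespace


def read_lora_metadata(checkpoint_dir: str) -> dict:
  """Hypothetical reader: parse the metadata file and return its "lora" entry."""
  path = Path(checkpoint_dir) / "_CHECKPOINT_METADATA"
  if not path.exists():
    return {}
  custom = json.loads(path.read_text()).get("custom_metadata") or {}
  return custom.get("lora", {})


class SyncSketchTest(unittest.TestCase):

  def test_metadata_overrides_config(self):
    config = SimpleNamespace(lora_rank=4, lora_alpha=8.0)
    with tempfile.TemporaryDirectory() as tmp:
      # Write the JSON string directly instead of building it with json.dumps.
      (Path(tmp) / "_CHECKPOINT_METADATA").write_text(
          '{"custom_metadata": {"lora": {"lora_rank": 16, "lora_alpha": 32.0}}}'
      )
      for key, value in read_lora_metadata(tmp).items():
        setattr(config, key, value)
    self.assertEqual(config.lora_rank, 16)
    self.assertEqual(config.lora_alpha, 32.0)

  def test_missing_metadata_leaves_config_unchanged(self):
    config = SimpleNamespace(lora_rank=4, lora_alpha=8.0)
    with tempfile.TemporaryDirectory() as tmp:
      for key, value in read_lora_metadata(tmp).items():
        setattr(config, key, value)
    self.assertEqual(config.lora_rank, 4)
```

Saved to a file, the sketch can be run with `python -m unittest <file>`.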
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.