LoRA supports dynamic loading, unloading, and switching at inference time.

### API Reference

```python
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig

# 1. Load model with LoRA structure and weights
lora_cfg = LoRAConfig(
    enable_lm=True,
    enable_dit=True,
    # ...
    target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
    target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
)
model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",              # or local path
    load_denoiser=False,                          # Optional: disable denoiser for faster loading
    optimize=True,                                # Enable torch.compile acceleration
    lora_config=lora_cfg,
    lora_weights_path="/path/to/lora_checkpoint",
)

# 2. Generate audio
audio = model.generate(
    text="Hello, this is a LoRA fine-tuned result.",
    prompt_wav_path="/path/to/reference.wav",  # Optional: for voice cloning
    prompt_text="Reference audio transcript",  # Optional: for voice cloning
)

# 3. Disable LoRA (use base model only)
model.set_lora_enabled(False)

# 4. Re-enable LoRA
model.set_lora_enabled(True)

# 5. Unload LoRA (reset weights to zero)
model.unload_lora()

# 6. Hot-swap to another LoRA
loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")

# 7. Get current LoRA weights
lora_state = model.get_lora_state_dict()
```
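
The `generate` call above returns the synthesized waveform. As a minimal sketch, assuming the return value is a 1-D NumPy array (and assuming a 16 kHz output rate — check your model's actual sample rate), it can be saved with the `soundfile` package:

```python
import soundfile as sf

# `audio` is the waveform returned by model.generate(...) above.
# The 16 kHz sample rate is an assumption; use the model's actual output rate.
sf.write("output.wav", audio, 16000)
```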

### Simplified Usage (Auto LoRA Config)

If you only have LoRA weights and don't need a custom config, just provide the path:

```python
from voxcpm.core import VoxCPM

# A default LoRAConfig is created automatically when only lora_weights_path is provided
model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    lora_weights_path="/path/to/lora_checkpoint",
)
```
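
The toggling methods from the full example still apply after this simplified load. As a quick sketch using only the calls documented above, you can A/B the adapter against the base model:

```python
text = "Quick check of the adapted voice."

wav_lora = model.generate(text=text)   # LoRA active (loaded above)

model.set_lora_enabled(False)          # temporarily fall back to the base model
wav_base = model.generate(text=text)
model.set_lora_enabled(True)           # restore the adapter
```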

### Method Reference

| Method | Description | torch.compile Compatible |
|--------|-------------|--------------------------|
| `load_lora(path)` | Load LoRA weights from file | ✅ |
| `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
| `get_lora_state_dict()` | Get current LoRA weights | ✅ |
| `lora_enabled` | Property: check if LoRA is configured | ✅ |
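
The `lora_enabled` property is handy for guarding adapter-specific code paths. A minimal sketch (the wrapper function is illustrative, not part of the API):

```python
def synthesize(model, text: str):
    # lora_enabled reports whether a LoRA is configured on this model
    if not model.lora_enabled:
        print("No LoRA configured; generating with the base model")
    return model.generate(text=text)
```
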
---

### 4. LoRA Not Taking Effect at Inference

- Ensure the inference `LoRAConfig` matches the LoRA parameters used in training
- Check the `load_lora()` return value: `skipped_keys` should be empty (see the sketch below)
- Verify `set_lora_enabled(True)` has been called
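
A small diagnostic sketch combining these checks (the path and printout are illustrative):

```python
loaded, skipped = model.load_lora("/path/to/lora_checkpoint")
if skipped:
    # Non-empty skipped keys usually mean the inference LoRAConfig
    # (e.g., target modules) differs from the training configuration.
    print(f"Skipped {len(skipped)} keys, e.g. {list(skipped)[:3]}")
model.set_lora_enabled(True)  # LoRA must be enabled to take effect
```
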
### 5. Checkpoint Loading Errors