Modify LoRA inference API

刘鑫
2025-12-05 22:22:13 +08:00
parent b1f7593ae0
commit 400f47a516
5 changed files with 265 additions and 139 deletions


@@ -271,10 +271,10 @@ LoRA supports dynamic loading, unloading, and switching at inference time without
 ### API Reference
 
 ```python
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
 
-# 1. Load model with LoRA structure
+# 1. Load model with LoRA structure and weights
 lora_cfg = LoRAConfig(
     enable_lm=True,
     enable_dit=True,
@@ -283,15 +283,20 @@ lora_cfg = LoRAConfig(
     target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
     target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
 )
-model = VoxCPMModel.from_local(
-    pretrained_path,
-    optimize=True,  # Enable torch.compile acceleration
-    lora_config=lora_cfg
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",  # or a local path
+    load_denoiser=False,              # Optional: disable the denoiser for faster loading
+    optimize=True,                    # Enable torch.compile acceleration
+    lora_config=lora_cfg,
+    lora_weights_path="/path/to/lora_checkpoint",
 )
 
-# 2. Load LoRA weights (works after torch.compile)
-loaded, skipped = model.load_lora_weights("/path/to/lora_checkpoint")
-print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
+# 2. Generate audio
+audio = model.generate(
+    text="Hello, this is the LoRA fine-tuned result.",
+    prompt_wav_path="/path/to/reference.wav",  # Optional: for voice cloning
+    prompt_text="Reference audio transcript",  # Optional: for voice cloning
+)
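+
+# Optional sketch: save the output. Assumes `generate` returns a NumPy
+# waveform, as in the upstream VoxCPM README; the 16 kHz rate is an
+# assumption - adjust it to your model's output rate.
+import soundfile as sf
+sf.write("voxcpm_lora_output.wav", audio, 16000)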
 
 # 3. Disable LoRA (use base model only)
 model.set_lora_enabled(False)
@@ -300,23 +305,39 @@ model.set_lora_enabled(False)
 # 4. Re-enable LoRA
 model.set_lora_enabled(True)
 
 # 5. Unload LoRA (reset weights to zero)
-model.reset_lora_weights()
+model.unload_lora()
 
 # 6. Hot-swap to another LoRA
-model.load_lora_weights("/path/to/another_lora_checkpoint")
+loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
+print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
 
 # 7. Get current LoRA weights
 lora_state = model.get_lora_state_dict()
 ```
+
+### Simplified Usage (Auto LoRA Config)
+
+If you only have LoRA weights and don't need a custom config, just provide the path:
+
+```python
+from voxcpm.core import VoxCPM
+
+# Auto-create a default LoRAConfig when only lora_weights_path is provided
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",
+    lora_weights_path="/path/to/lora_checkpoint",  # Will auto-create LoRAConfig
+)
+```
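+
+To confirm the adapter actually loaded, a quick sanity check (a sketch built on
+the methods listed in the reference below; the names match the table):
+
+```python
+assert model.lora_enabled  # LoRA structure is configured
+print(f"{len(model.get_lora_state_dict())} LoRA tensors loaded")
+```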
 
 ### Method Reference
 
 | Method | Description | torch.compile Compatible |
 |--------|-------------|--------------------------|
-| `load_lora_weights(path)` | Load LoRA weights from file | ✅ |
+| `load_lora(path)` | Load LoRA weights from file | ✅ |
 | `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
-| `reset_lora_weights()` | Reset LoRA weights to initial values | ✅ |
+| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
 | `get_lora_state_dict()` | Get current LoRA weights | ✅ |
 | `lora_enabled` | Property: check if LoRA is configured | ✅ |
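+
+These methods compose, e.g. into an A/B comparison between the base model and
+two adapters (a sketch; the checkpoint paths are placeholders):
+
+```python
+sample = "Same sentence for comparison."
+with_lora = model.generate(text=sample)
+model.set_lora_enabled(False)   # temporarily fall back to the base weights
+base_only = model.generate(text=sample)
+model.set_lora_enabled(True)
+model.load_lora("/path/to/another_lora_checkpoint")  # hot-swap adapters
+other_lora = model.generate(text=sample)
+```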
 
 ---
@@ -346,7 +367,7 @@ lora_state = model.get_lora_state_dict()
 ### 4. LoRA Not Taking Effect at Inference
 
 - Ensure the inference LoRA config matches the LoRA parameters used in training
-- Check `load_lora_weights` return value - `skipped_keys` should be empty
+- Check the `load_lora()` return value - `skipped_keys` should be empty (see the snippet below)
 - Verify `set_lora_enabled(True)` is called
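+
+A minimal check covering the last two points (a sketch; the path is a placeholder):
+
+```python
+loaded, skipped = model.load_lora("/path/to/lora_checkpoint")
+assert not skipped, f"Unmatched LoRA keys: {skipped}"
+model.set_lora_enabled(True)
+```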
 
 ### 5. Checkpoint Loading Errors