Modify LoRA inference API

刘鑫
2025-12-05 22:22:13 +08:00
parent b1f7593ae0
commit 400f47a516
5 changed files with 265 additions and 139 deletions

View File

@@ -271,10 +271,10 @@ LoRA supports dynamic loading, unloading, and switching at inference time withou
 ### API Reference
 ```python
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
-# 1. Load model with LoRA structure
+# 1. Load model with LoRA structure and weights
 lora_cfg = LoRAConfig(
     enable_lm=True,
     enable_dit=True,
@@ -283,15 +283,20 @@ lora_cfg = LoRAConfig(
     target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
     target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
 )
-model = VoxCPMModel.from_local(
-    pretrained_path,
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",  # or local path
+    load_denoiser=False,  # Optional: disable denoiser for faster loading
     optimize=True,  # Enable torch.compile acceleration
-    lora_config=lora_cfg
+    lora_config=lora_cfg,
+    lora_weights_path="/path/to/lora_checkpoint",
 )
-# 2. Load LoRA weights (works after torch.compile)
-loaded, skipped = model.load_lora_weights("/path/to/lora_checkpoint")
-print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
+# 2. Generate audio
+audio = model.generate(
+    text="Hello, this is LoRA fine-tuned result.",
+    prompt_wav_path="/path/to/reference.wav",  # Optional: for voice cloning
+    prompt_text="Reference audio transcript",  # Optional: for voice cloning
+)
 # 3. Disable LoRA (use base model only)
 model.set_lora_enabled(False)
@@ -300,23 +305,39 @@ model.set_lora_enabled(False)
 model.set_lora_enabled(True)
 # 5. Unload LoRA (reset weights to zero)
-model.reset_lora_weights()
+model.unload_lora()
 # 6. Hot-swap to another LoRA
-model.load_lora_weights("/path/to/another_lora_checkpoint")
+loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
+print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
 # 7. Get current LoRA weights
 lora_state = model.get_lora_state_dict()
 ```
+### Simplified Usage (Auto LoRA Config)
+If you only have LoRA weights and don't need a custom config, just provide the path:
+```python
+from voxcpm.core import VoxCPM
+# Auto-create default LoRAConfig when only lora_weights_path is provided
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",
+    lora_weights_path="/path/to/lora_checkpoint",  # Will auto-create LoRAConfig
+)
+```
 ### Method Reference
 | Method | Description | torch.compile Compatible |
 |--------|-------------|--------------------------|
-| `load_lora_weights(path)` | Load LoRA weights from file | ✅ |
+| `load_lora(path)` | Load LoRA weights from file | ✅ |
 | `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
-| `reset_lora_weights()` | Reset LoRA weights to initial values | ✅ |
+| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
 | `get_lora_state_dict()` | Get current LoRA weights | ✅ |
+| `lora_enabled` | Property: check if LoRA is configured | ✅ |
 ---
@@ -346,7 +367,7 @@ lora_state = model.get_lora_state_dict()
 ### 4. LoRA Not Taking Effect at Inference
 - Ensure inference config matches training config LoRA parameters
-- Check `load_lora_weights` return value - `skipped_keys` should be empty
+- Check `load_lora()` return value - `skipped_keys` should be empty
 - Verify `set_lora_enabled(True)` is called
 ### 5. Checkpoint Loading Errors
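A quick sanity check for item 4 above, as a minimal sketch assembled from the renamed API in this file (model id and checkpoint paths are placeholders):

```python
from voxcpm.core import VoxCPM

# Hedged sketch: verify a LoRA hot-load actually took effect.
model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    lora_weights_path="/path/to/lora_checkpoint",  # placeholder path
)
loaded, skipped = model.load_lora("/path/to/lora_checkpoint")  # hot-reload
print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
assert not skipped, f"Possible config mismatch; first skipped keys: {skipped[:5]}"
model.set_lora_enabled(True)  # LoRA must be enabled to affect generation
```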

View File

@@ -3,7 +3,7 @@
 Full finetune inference script (no LoRA).
 Checkpoint directory contains complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.),
-can be loaded directly via VoxCPMModel.from_local().
+can be loaded directly via VoxCPM.
 Usage:
@@ -26,9 +26,8 @@ import argparse
 from pathlib import Path
 import soundfile as sf
-import torch
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 def parse_args():
@@ -81,49 +80,52 @@ def parse_args():
         default=600,
         help="Max generation steps",
     )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
     return parser.parse_args()
 def main():
     args = parse_args()
-    # Load model from checkpoint directory
+    # Load model from checkpoint directory (no denoiser)
     print(f"[FT Inference] Loading model: {args.ckpt_dir}")
-    model = VoxCPMModel.from_local(args.ckpt_dir, optimize=True, training=False)
+    model = VoxCPM.from_pretrained(
+        hf_model_id=args.ckpt_dir,
+        load_denoiser=False,
+        optimize=True,
+    )
     # Run inference
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
     print(f"[FT Inference] Synthesizing: text='{args.text}'")
     if prompt_wav_path:
         print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
         print(f"[FT Inference] Reference text: {prompt_text}")
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
-    # Squeeze and save audio
-    if isinstance(audio, torch.Tensor):
-        audio_np = audio.squeeze(0).cpu().numpy()
-    else:
-        raise TypeError(f"Unexpected return type from model.generate: {type(audio)}")
+    # Save audio
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
-    sf.write(str(out_path), audio_np, model.sample_rate)
-    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
+    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 if __name__ == "__main__":
     main()
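For readers migrating their own scripts, a minimal end-to-end sketch of the new call path in this file (checkpoint path and text are placeholders; the numpy return type is inferred from the removed squeeze logic):

```python
import soundfile as sf
from voxcpm.core import VoxCPM

# New API: VoxCPM.from_pretrained replaces VoxCPMModel.from_local, and
# generate(text=..., max_length=...) replaces generate(target_text=..., max_len=...).
model = VoxCPM.from_pretrained(
    hf_model_id="/path/to/ft_checkpoint",  # a local checkpoint dir is also accepted
    load_denoiser=False,
    optimize=True,
)
# generate() now returns a numpy array, so the torch.inference_mode()
# wrapper and manual tensor squeezing are gone from the call site.
audio_np = model.generate(
    text="Hello from the finetuned model.",
    max_length=600,
    denoise=False,
)
sf.write("out.wav", audio_np, model.tts_model.sample_rate)
```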

View File

@@ -25,9 +25,8 @@ import argparse
 from pathlib import Path
 import soundfile as sf
-import torch
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
 from voxcpm.training.config import load_yaml_config
@@ -88,6 +87,11 @@ def parse_args():
         default=600,
         help="Max generation steps",
     )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
     return parser.parse_args()
@@ -100,123 +104,114 @@ def main():
     lora_cfg_dict = cfg.get("lora", {}) or {}
     lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
-    # 2. Load base model (with LoRA structure and torch.compile)
-    print(f"[1/3] Loading base model: {pretrained_path}")
-    model = VoxCPMModel.from_local(
-        pretrained_path,
-        optimize=True,  # compile first, load_lora_weights uses named_parameters for compatibility
-        training=False,
-        lora_config=lora_cfg,
-    )
-    # Debug: check DiT param paths after compile
-    dit_params = [n for n, _ in model.named_parameters() if 'feat_decoder' in n and 'lora' in n]
-    print(f"[DEBUG] DiT LoRA param paths after compile (first 3): {dit_params[:3]}")
-    # 3. Load LoRA weights (works after compile)
-    ckpt_dir = Path(args.lora_ckpt)
-    if not ckpt_dir.exists():
+    # 2. Check LoRA checkpoint
+    ckpt_dir = args.lora_ckpt
+    if not Path(ckpt_dir).exists():
         raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
-    print(f"[2/3] Loading LoRA weights: {ckpt_dir}")
-    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
-    print(f"    Loaded {len(loaded)} parameters")
-    if skipped:
-        print(f"[WARNING] Skipped {len(skipped)} parameters")
-        print(f"    Skipped keys (first 5): {skipped[:5]}")
+    # 3. Load model with LoRA (no denoiser)
+    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"    LoRA weights: {ckpt_dir}")
+    model = VoxCPM.from_pretrained(
+        hf_model_id=pretrained_path,
+        load_denoiser=False,
+        optimize=True,
+        lora_config=lora_cfg,
+        lora_weights_path=ckpt_dir,
+    )
     # 4. Synthesize audio
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
-    print(f"\n[3/3] Starting synthesis tests...")
+    print(f"\n[2/2] Starting synthesis tests...")
     # === Test 1: With LoRA ===
     print(f"\n  [Test 1] Synthesize with LoRA...")
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     lora_output = out_path.with_stem(out_path.stem + "_with_lora")
-    sf.write(str(lora_output), audio_np, model.sample_rate)
-    print(f"    Saved: {lora_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
     # === Test 2: Disable LoRA (via set_lora_enabled) ===
     print(f"\n  [Test 2] Disable LoRA (set_lora_enabled=False)...")
     model.set_lora_enabled(False)
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
-    sf.write(str(disabled_output), audio_np, model.sample_rate)
-    print(f"    Saved: {disabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
     # === Test 3: Re-enable LoRA ===
     print(f"\n  [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
     model.set_lora_enabled(True)
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
-    sf.write(str(reenabled_output), audio_np, model.sample_rate)
-    print(f"    Saved: {reenabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
     # === Test 4: Unload LoRA (reset_lora_weights) ===
-    print(f"\n  [Test 4] Unload LoRA (reset_lora_weights)...")
-    model.reset_lora_weights()
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    print(f"\n  [Test 4] Unload LoRA (unload_lora)...")
+    model.unload_lora()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
-    sf.write(str(reset_output), audio_np, model.sample_rate)
-    print(f"    Saved: {reset_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
-    # === Test 5: Hot-reload LoRA (load_lora_weights) ===
-    print(f"\n  [Test 5] Hot-reload LoRA (load_lora_weights)...")
-    loaded, _ = model.load_lora_weights(str(ckpt_dir))
+    # === Test 5: Hot-reload LoRA (load_lora) ===
+    print(f"\n  [Test 5] Hot-reload LoRA (load_lora)...")
+    loaded, skipped = model.load_lora(str(ckpt_dir))
     print(f"    Reloaded {len(loaded)} parameters")
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
-    sf.write(str(reload_output), audio_np, model.sample_rate)
-    print(f"    Saved: {reload_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
     print(f"\n[Done] All tests completed!")
     print(f"    - with_lora: {lora_output}")
@@ -228,5 +223,3 @@ def main():
 if __name__ == "__main__":
     main()
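The LoRA configuration is rebuilt from the training YAML before loading; a minimal sketch of that pattern, using the imports from this file (the YAML path is a placeholder):

```python
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training.config import load_yaml_config

cfg = load_yaml_config("/path/to/train_config.yaml")  # placeholder path
lora_cfg_dict = cfg.get("lora", {}) or {}
# None means the model is built without any LoRA structure,
# matching the lora_cfg handling at the top of main().
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
```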

View File

@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
"ZIPENHANCER_MODEL_PATH", None "ZIPENHANCER_MODEL_PATH", None
) )
# Build LoRA config if lora_path is provided
lora_config = None
lora_weights_path = getattr(args, "lora_path", None)
if lora_weights_path:
from voxcpm.model.voxcpm import LoRAConfig
lora_config = LoRAConfig(
enable_lm=getattr(args, "lora_enable_lm", True),
enable_dit=getattr(args, "lora_enable_dit", True),
enable_proj=getattr(args, "lora_enable_proj", False),
r=getattr(args, "lora_r", 32),
alpha=getattr(args, "lora_alpha", 16),
dropout=getattr(args, "lora_dropout", 0.0),
)
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
# Load from local path if provided # Load from local path if provided
if getattr(args, "model_path", None): if getattr(args, "model_path", None):
try: try:
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
                 voxcpm_model_path=args.model_path,
                 zipenhancer_model_path=zipenhancer_path,
                 enable_denoiser=not getattr(args, "no_denoiser", False),
+                lora_config=lora_config,
+                lora_weights_path=lora_weights_path,
             )
             print("Model loaded (local).")
             return model
@@ -74,6 +92,8 @@ def load_model(args) -> VoxCPM:
         zipenhancer_model_id=zipenhancer_path,
         cache_dir=getattr(args, "cache_dir", None),
         local_files_only=getattr(args, "local_files_only", False),
+        lora_config=lora_config,
+        lora_weights_path=lora_weights_path,
     )
     print("Model loaded (from_pretrained).")
     return model
@@ -256,6 +276,15 @@ Examples:
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading") parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)") parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
# LoRA parameters
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
return parser return parser
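For reference, a sketch of the programmatic equivalent of these CLI flags, using the defaults declared above (the model id is a placeholder):

```python
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig

# Mirrors --lora-path plus the flag defaults above: r=32, alpha=16,
# dropout=0.0, LoRA on LM and DiT layers, projection layers off.
lora_config = LoRAConfig(
    enable_lm=True,
    enable_dit=True,
    enable_proj=False,
    r=32,
    alpha=16,
    dropout=0.0,
)
model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",            # placeholder model id
    lora_config=lora_config,
    lora_weights_path="/path/to/lora_weights",  # mirrors --lora-path
)
```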

View File

@@ -2,9 +2,9 @@ import os
 import re
 import tempfile
 import numpy as np
-from typing import Generator
+from typing import Generator, Optional
 from huggingface_hub import snapshot_download
-from .model.voxcpm import VoxCPMModel
+from .model.voxcpm import VoxCPMModel, LoRAConfig
 class VoxCPM:
     def __init__(self,
@@ -12,6 +12,8 @@ class VoxCPM:
                  zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
                  enable_denoiser : bool = True,
                  optimize: bool = True,
+                 lora_config: Optional[LoRAConfig] = None,
+                 lora_weights_path: Optional[str] = None,
                  ):
         """Initialize VoxCPM TTS pipeline.
@@ -23,9 +25,30 @@ class VoxCPM:
                 id or local path. If None, denoiser will not be initialized.
             enable_denoiser: Whether to initialize the denoiser pipeline.
             optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
+            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
+                provided without lora_config, a default config will be created.
+            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
+                containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
         """
         print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
-        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
+        # If lora_weights_path is provided but no lora_config, create a default one
+        if lora_weights_path is not None and lora_config is None:
+            lora_config = LoRAConfig(
+                enable_lm=True,
+                enable_dit=True,
+                enable_proj=False,
+            )
+            print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
+        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
+        # Load LoRA weights if path is provided
+        if lora_weights_path is not None:
+            print(f"Loading LoRA weights from: {lora_weights_path}")
+            loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
+            print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
         self.text_normalizer = None
         if enable_denoiser and zipenhancer_model_path is not None:
             from .zipenhancer import ZipEnhancer
@@ -46,6 +69,8 @@ class VoxCPM:
                         cache_dir: str = None,
                         local_files_only: bool = False,
                         optimize: bool = True,
+                        lora_config: Optional[LoRAConfig] = None,
+                        lora_weights_path: Optional[str] = None,
                         **kwargs,
                         ):
         """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -59,6 +84,12 @@ class VoxCPM:
             cache_dir: Custom cache directory for the snapshot.
             local_files_only: If True, only use local files and do not attempt
                 to download.
+            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
+                provided without lora_config, a default config will be created with
+                enable_lm=True and enable_dit=True.
+            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
+                containing lora_weights.ckpt). If provided, LoRA weights will be loaded
+                after model initialization.
         Kwargs:
             Additional keyword arguments passed to the ``VoxCPM`` constructor.
@@ -90,6 +121,8 @@ class VoxCPM:
             zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
             enable_denoiser=load_denoiser,
             optimize=optimize,
+            lora_config=lora_config,
+            lora_weights_path=lora_weights_path,
             **kwargs,
         )
@@ -197,3 +230,51 @@ class VoxCPM:
                     os.unlink(temp_prompt_wav_path)
                 except OSError:
                     pass
+    # ------------------------------------------------------------------ #
+    # LoRA Interface (delegated to VoxCPMModel)
+    # ------------------------------------------------------------------ #
+    def load_lora(self, lora_weights_path: str) -> tuple:
+        """Load LoRA weights from a checkpoint file.
+        Args:
+            lora_weights_path: Path to LoRA weights (.pth file or directory
+                containing lora_weights.ckpt).
+        Returns:
+            tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
+        Raises:
+            RuntimeError: If model was not initialized with LoRA config.
+        """
+        if self.tts_model.lora_config is None:
+            raise RuntimeError(
+                "Cannot load LoRA weights: model was not initialized with LoRA config. "
+                "Please reinitialize with lora_config or lora_weights_path parameter."
+            )
+        return self.tts_model.load_lora_weights(lora_weights_path)
+    def unload_lora(self):
+        """Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
+        self.tts_model.reset_lora_weights()
+    def set_lora_enabled(self, enabled: bool):
+        """Enable or disable LoRA layers without unloading weights.
+        Args:
+            enabled: If True, LoRA layers are active; if False, only base model is used.
+        """
+        self.tts_model.set_lora_enabled(enabled)
+    def get_lora_state_dict(self) -> dict:
+        """Get current LoRA parameters state dict.
+        Returns:
+            dict: State dict containing all LoRA parameters (lora_A, lora_B).
+        """
+        return self.tts_model.get_lora_state_dict()
+    @property
+    def lora_enabled(self) -> bool:
+        """Check if LoRA is currently configured."""
+        return self.tts_model.lora_config is not None
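A minimal sketch exercising the delegated interface end to end (model id and checkpoint paths are placeholders; behavior follows the method bodies above):

```python
from voxcpm.core import VoxCPM

# Without a LoRA config, load_lora() is expected to raise RuntimeError.
base = VoxCPM.from_pretrained(hf_model_id="openbmb/VoxCPM1.5")
assert not base.lora_enabled
try:
    base.load_lora("/path/to/lora_checkpoint")
except RuntimeError as err:
    print(f"Expected failure: {err}")

# With lora_weights_path, a default LoRAConfig is auto-created at init.
model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    lora_weights_path="/path/to/lora_checkpoint",
)
assert model.lora_enabled
model.set_lora_enabled(False)  # base model only; weights stay in memory
model.set_lora_enabled(True)   # LoRA active again
model.unload_lora()            # reset LoRA weights to their initial state
loaded, skipped = model.load_lora("/path/to/lora_checkpoint")  # hot-reload
print(f"Reloaded {len(loaded)} params, skipped {len(skipped)}")
```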