Modify LoRA inference API

This commit is contained in:
刘鑫
2025-12-05 22:22:13 +08:00
parent b1f7593ae0
commit 400f47a516
5 changed files with 265 additions and 139 deletions

View File

@@ -271,10 +271,10 @@ LoRA supports dynamic loading, unloading, and switching at inference time withou
### API Reference
```python
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
-# 1. Load model with LoRA structure
+# 1. Load model with LoRA structure and weights
lora_cfg = LoRAConfig(
enable_lm=True,
enable_dit=True,
@@ -283,15 +283,20 @@ lora_cfg = LoRAConfig(
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
)
-model = VoxCPMModel.from_local(
-    pretrained_path,
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",  # or local path
+    load_denoiser=False,  # Optional: disable denoiser for faster loading
+    optimize=True,  # Enable torch.compile acceleration
-    lora_config=lora_cfg
+    lora_config=lora_cfg,
+    lora_weights_path="/path/to/lora_checkpoint",
)

-# 2. Load LoRA weights (works after torch.compile)
-loaded, skipped = model.load_lora_weights("/path/to/lora_checkpoint")
-print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
+# 2. Generate audio
+audio = model.generate(
+    text="Hello, this is LoRA fine-tuned result.",
+    prompt_wav_path="/path/to/reference.wav",  # Optional: for voice cloning
+    prompt_text="Reference audio transcript",  # Optional: for voice cloning
+)
# 3. Disable LoRA (use base model only)
model.set_lora_enabled(False)
@@ -300,23 +305,39 @@ model.set_lora_enabled(False)
model.set_lora_enabled(True)
# 5. Unload LoRA (reset weights to zero)
-model.reset_lora_weights()
+model.unload_lora()
# 6. Hot-swap to another LoRA
-model.load_lora_weights("/path/to/another_lora_checkpoint")
+loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
+print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
# 7. Get current LoRA weights
lora_state = model.get_lora_state_dict()
```
+### Simplified Usage (Auto LoRA Config)
+If you only have LoRA weights and don't need a custom config, just provide the path:
+```python
+from voxcpm.core import VoxCPM
+
+# A default LoRAConfig is auto-created when only lora_weights_path is provided
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",
+    lora_weights_path="/path/to/lora_checkpoint",  # Will auto-create LoRAConfig
+)
+```
### Method Reference
| Method | Description | torch.compile Compatible |
|--------|-------------|--------------------------|
-| `load_lora_weights(path)` | Load LoRA weights from file | ✅ |
+| `load_lora(path)` | Load LoRA weights from file | ✅ |
| `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
-| `reset_lora_weights()` | Reset LoRA weights to initial values | ✅ |
+| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
| `get_lora_state_dict()` | Get current LoRA weights | ✅ |
| `lora_enabled` | Property: check if LoRA is configured | ✅ |
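
The last two rows pair naturally: `get_lora_state_dict` plus `torch.save` lets you snapshot the active adapter before a hot-swap. A minimal sketch, continuing from the `model` built in the API reference above (the snapshot path is a placeholder):

```python
import torch

# Snapshot the adapter currently held in the LoRA layers
torch.save(model.get_lora_state_dict(), "/path/to/lora_snapshot.pth")

# Hot-swap to another adapter; skipped keys signal a config mismatch
loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
assert not skipped, f"adapter/config mismatch, e.g. {skipped[:5]}"
```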
---
@@ -346,7 +367,7 @@ lora_state = model.get_lora_state_dict()
### 4. LoRA Not Taking Effect at Inference
- Ensure inference config matches training config LoRA parameters
-- Check `load_lora_weights` return value - `skipped_keys` should be empty
+- Check `load_lora()` return value - `skipped_keys` should be empty
- Verify `set_lora_enabled(True)` is called (see the sketch below)
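
A compact way to run these checks together (a sketch; assumes `model` was built with a matching `lora_config` as in the API reference above):

```python
loaded, skipped = model.load_lora("/path/to/lora_checkpoint")
print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
if skipped:
    # Typically a rank or target-module mismatch between training and inference configs
    print(f"Skipped keys (first 5): {skipped[:5]}")
model.set_lora_enabled(True)
assert model.lora_enabled  # LoRA structure is present and configured
```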
### 5. Checkpoint Loading Errors

View File

@@ -3,7 +3,7 @@
Full finetune inference script (no LoRA).
Checkpoint directory contains complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.),
-can be loaded directly via VoxCPMModel.from_local().
+can be loaded directly via VoxCPM.from_pretrained().
Usage:
@@ -26,9 +26,8 @@ import argparse
from pathlib import Path
import soundfile as sf
import torch
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
def parse_args():
@@ -81,49 +80,52 @@ def parse_args():
default=600,
help="Max generation steps",
)
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
return parser.parse_args()
def main():
args = parse_args()
-    # Load model from checkpoint directory
+    # Load model from checkpoint directory (no denoiser)
print(f"[FT Inference] Loading model: {args.ckpt_dir}")
-    model = VoxCPMModel.from_local(args.ckpt_dir, optimize=True, training=False)
+    model = VoxCPM.from_pretrained(
+        hf_model_id=args.ckpt_dir,
+        load_denoiser=False,
+        optimize=True,
+    )
# Run inference
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
print(f"[FT Inference] Synthesizing: text='{args.text}'")
if prompt_wav_path:
print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
print(f"[FT Inference] Reference text: {prompt_text}")
with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
+        audio_np = model.generate(
+            text=args.text,
            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
+            prompt_text=prompt_text,
            cfg_value=args.cfg_value,
+            inference_timesteps=args.inference_timesteps,
+            max_length=args.max_len,
+            normalize=args.normalize,
+            denoise=False,
)
-    # Squeeze and save audio
-    if isinstance(audio, torch.Tensor):
-        audio_np = audio.squeeze(0).cpu().numpy()
-    else:
-        raise TypeError(f"Unexpected return type from model.generate: {type(audio)}")
+    # Save audio
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
-    sf.write(str(out_path), audio_np, model.sample_rate)
+    sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
-    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
if __name__ == "__main__":
main()
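
Stripped of argument parsing, the revised script reduces to the following minimal sketch (checkpoint and output paths are placeholders):

```python
import soundfile as sf

from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(
    hf_model_id="/path/to/ft_checkpoint",  # local full-finetune checkpoint directory
    load_denoiser=False,
    optimize=True,
)
audio_np = model.generate(text="Hello world.", denoise=False)
sf.write("output.wav", audio_np, model.tts_model.sample_rate)
```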

View File

@@ -25,9 +25,8 @@ import argparse
from pathlib import Path
import soundfile as sf
import torch
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training.config import load_yaml_config
@@ -88,6 +87,11 @@ def parse_args():
default=600,
help="Max generation steps",
)
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
return parser.parse_args()
@@ -100,123 +104,114 @@ def main():
lora_cfg_dict = cfg.get("lora", {}) or {}
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
-    # 2. Load base model (with LoRA structure and torch.compile)
-    print(f"[1/3] Loading base model: {pretrained_path}")
-    model = VoxCPMModel.from_local(
-        pretrained_path,
-        optimize=True,  # compile first, load_lora_weights uses named_parameters for compatibility
-        training=False,
-        lora_config=lora_cfg,
-    )
-    # Debug: check DiT param paths after compile
-    dit_params = [n for n, _ in model.named_parameters() if 'feat_decoder' in n and 'lora' in n]
-    print(f"[DEBUG] DiT LoRA param paths after compile (first 3): {dit_params[:3]}")
-    # 3. Load LoRA weights (works after compile)
-    ckpt_dir = Path(args.lora_ckpt)
-    if not ckpt_dir.exists():
+    # 2. Check LoRA checkpoint
+    ckpt_dir = args.lora_ckpt
+    if not Path(ckpt_dir).exists():
        raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
-    print(f"[2/3] Loading LoRA weights: {ckpt_dir}")
-    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
-    print(f"    Loaded {len(loaded)} parameters")
-    if skipped:
-        print(f"[WARNING] Skipped {len(skipped)} parameters")
-        print(f"    Skipped keys (first 5): {skipped[:5]}")
+    # 3. Load model with LoRA (no denoiser)
+    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"    LoRA weights: {ckpt_dir}")
+    model = VoxCPM.from_pretrained(
+        hf_model_id=pretrained_path,
+        load_denoiser=False,
+        optimize=True,
+        lora_config=lora_cfg,
+        lora_weights_path=ckpt_dir,
+    )
# 4. Synthesize audio
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[3/3] Starting synthesis tests...")
print(f"\n[2/2] Starting synthesis tests...")
# === Test 1: With LoRA ===
print(f"\n [Test 1] Synthesize with LoRA...")
with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
+        audio_np = model.generate(
+            text=args.text,
            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
+            prompt_text=prompt_text,
            cfg_value=args.cfg_value,
+            inference_timesteps=args.inference_timesteps,
+            max_length=args.max_len,
+            normalize=args.normalize,
+            denoise=False,
)
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
-    sf.write(str(lora_output), audio_np, model.sample_rate)
-    print(f"    Saved: {lora_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 2: Disable LoRA (via set_lora_enabled) ===
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
model.set_lora_enabled(False)
with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
+        audio_np = model.generate(
+            text=args.text,
            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
+            prompt_text=prompt_text,
            cfg_value=args.cfg_value,
+            inference_timesteps=args.inference_timesteps,
+            max_length=args.max_len,
+            normalize=args.normalize,
+            denoise=False,
)
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
-    sf.write(str(disabled_output), audio_np, model.sample_rate)
-    print(f"    Saved: {disabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 3: Re-enable LoRA ===
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
model.set_lora_enabled(True)
with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
+        audio_np = model.generate(
+            text=args.text,
            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
+            prompt_text=prompt_text,
            cfg_value=args.cfg_value,
+            inference_timesteps=args.inference_timesteps,
+            max_length=args.max_len,
+            normalize=args.normalize,
+            denoise=False,
)
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
-    sf.write(str(reenabled_output), audio_np, model.sample_rate)
-    print(f"    Saved: {reenabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 4: Unload LoRA (reset_lora_weights) ===
print(f"\n [Test 4] Unload LoRA (reset_lora_weights)...")
model.reset_lora_weights()
with torch.inference_mode():
audio = model.generate(
target_text=args.text,
prompt_text=prompt_text,
print(f"\n [Test 4] Unload LoRA (unload_lora)...")
model.unload_lora()
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
max_len=args.max_len,
inference_timesteps=args.inference_timesteps,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_length=args.max_len,
normalize=args.normalize,
denoise=False,
)
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
-    sf.write(str(reset_output), audio_np, model.sample_rate)
-    print(f"    Saved: {reset_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
-    # === Test 5: Hot-reload LoRA (load_lora_weights) ===
-    print(f"\n  [Test 5] Hot-reload LoRA (load_lora_weights)...")
-    loaded, _ = model.load_lora_weights(str(ckpt_dir))
+    # === Test 5: Hot-reload LoRA (load_lora) ===
+    print(f"\n  [Test 5] Hot-reload LoRA (load_lora)...")
+    loaded, skipped = model.load_lora(str(ckpt_dir))
print(f" Reloaded {len(loaded)} parameters")
with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
+        audio_np = model.generate(
+            text=args.text,
            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
+            prompt_text=prompt_text,
            cfg_value=args.cfg_value,
+            inference_timesteps=args.inference_timesteps,
+            max_length=args.max_len,
+            normalize=args.normalize,
+            denoise=False,
)
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
-    sf.write(str(reload_output), audio_np, model.sample_rate)
-    print(f"    Saved: {reload_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f"\n[Done] All tests completed!")
print(f" - with_lora: {lora_output}")
@@ -228,5 +223,3 @@ def main():
if __name__ == "__main__":
main()
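
The five tests all wrap the same `generate` call around a different LoRA state change; condensed, the loop looks like this (a sketch reusing `model`, `ckpt_dir`, `out_path`, and `args` from the script; prompt arguments omitted for brevity):

```python
variants = {
    "with_lora": lambda: None,  # adapter active as loaded
    "lora_disabled": lambda: model.set_lora_enabled(False),
    "lora_reenabled": lambda: model.set_lora_enabled(True),
    "lora_reset": lambda: model.unload_lora(),
    "lora_reloaded": lambda: model.load_lora(str(ckpt_dir)),
}
for suffix, switch in variants.items():
    switch()
    # (the script wraps each call in torch.inference_mode())
    audio_np = model.generate(text=args.text, denoise=False)
    out_file = out_path.with_stem(f"{out_path.stem}_{suffix}")
    sf.write(str(out_file), audio_np, model.tts_model.sample_rate)
```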

View File

@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
"ZIPENHANCER_MODEL_PATH", None
)
+    # Build LoRA config if lora_path is provided
+    lora_config = None
+    lora_weights_path = getattr(args, "lora_path", None)
+    if lora_weights_path:
+        from voxcpm.model.voxcpm import LoRAConfig
+        lora_config = LoRAConfig(
+            enable_lm=getattr(args, "lora_enable_lm", True),
+            enable_dit=getattr(args, "lora_enable_dit", True),
+            enable_proj=getattr(args, "lora_enable_proj", False),
+            r=getattr(args, "lora_r", 32),
+            alpha=getattr(args, "lora_alpha", 16),
+            dropout=getattr(args, "lora_dropout", 0.0),
+        )
+        print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
+              f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
# Load from local path if provided
if getattr(args, "model_path", None):
try:
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
voxcpm_model_path=args.model_path,
zipenhancer_model_path=zipenhancer_path,
enable_denoiser=not getattr(args, "no_denoiser", False),
+                lora_config=lora_config,
+                lora_weights_path=lora_weights_path,
)
print("Model loaded (local).")
return model
@@ -74,6 +92,8 @@ def load_model(args) -> VoxCPM:
zipenhancer_model_id=zipenhancer_path,
cache_dir=getattr(args, "cache_dir", None),
local_files_only=getattr(args, "local_files_only", False),
+                lora_config=lora_config,
+                lora_weights_path=lora_weights_path,
)
print("Model loaded (from_pretrained).")
return model
@@ -256,6 +276,15 @@ Examples:
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
+    # LoRA parameters
+    parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
+    parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
+    parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
+    parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
+    parser.add_argument("--lora-enable-lm", action=argparse.BooleanOptionalAction, default=True, help="Apply LoRA to LM layers (default: True)")
+    parser.add_argument("--lora-enable-dit", action=argparse.BooleanOptionalAction, default=True, help="Apply LoRA to DiT layers (default: True)")
+    parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
return parser

View File

@@ -2,9 +2,9 @@ import os
import re
import tempfile
import numpy as np
-from typing import Generator
+from typing import Generator, Optional
from huggingface_hub import snapshot_download
-from .model.voxcpm import VoxCPMModel
+from .model.voxcpm import VoxCPMModel, LoRAConfig
class VoxCPM:
def __init__(self,
@@ -12,6 +12,8 @@ class VoxCPM:
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser : bool = True,
optimize: bool = True,
+        lora_config: Optional[LoRAConfig] = None,
+        lora_weights_path: Optional[str] = None,
):
"""Initialize VoxCPM TTS pipeline.
@@ -23,9 +25,30 @@ class VoxCPM:
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
+            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
+                provided without lora_config, a default config will be created.
+            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
+                containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
"""
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
-        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
+        # If lora_weights_path is provided but no lora_config, create a default one
+        if lora_weights_path is not None and lora_config is None:
+            lora_config = LoRAConfig(
+                enable_lm=True,
+                enable_dit=True,
+                enable_proj=False,
+            )
+            print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
+        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
+
+        # Load LoRA weights if path is provided
+        if lora_weights_path is not None:
+            print(f"Loading LoRA weights from: {lora_weights_path}")
+            loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
+            print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
self.text_normalizer = None
if enable_denoiser and zipenhancer_model_path is not None:
from .zipenhancer import ZipEnhancer
@@ -46,6 +69,8 @@ class VoxCPM:
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
+        lora_config: Optional[LoRAConfig] = None,
+        lora_weights_path: Optional[str] = None,
**kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -59,6 +84,12 @@ class VoxCPM:
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
+            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
+                provided without lora_config, a default config will be created with
+                enable_lm=True and enable_dit=True.
+            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
+                containing lora_weights.ckpt). If provided, LoRA weights will be loaded
+                after model initialization.
Kwargs:
Additional keyword arguments passed to the ``VoxCPM`` constructor.
@@ -90,6 +121,8 @@ class VoxCPM:
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
optimize=optimize,
+            lora_config=lora_config,
+            lora_weights_path=lora_weights_path,
**kwargs,
)
@@ -197,3 +230,51 @@ class VoxCPM:
os.unlink(temp_prompt_wav_path)
except OSError:
pass
+    # ------------------------------------------------------------------ #
+    # LoRA Interface (delegated to VoxCPMModel)
+    # ------------------------------------------------------------------ #
+    def load_lora(self, lora_weights_path: str) -> tuple:
+        """Load LoRA weights from a checkpoint file.
+
+        Args:
+            lora_weights_path: Path to LoRA weights (.pth file or directory
+                containing lora_weights.ckpt).
+
+        Returns:
+            tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
+
+        Raises:
+            RuntimeError: If model was not initialized with LoRA config.
+        """
+        if self.tts_model.lora_config is None:
+            raise RuntimeError(
+                "Cannot load LoRA weights: model was not initialized with LoRA config. "
+                "Please reinitialize with lora_config or lora_weights_path parameter."
+            )
+        return self.tts_model.load_lora_weights(lora_weights_path)
+
+    def unload_lora(self):
+        """Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
+        self.tts_model.reset_lora_weights()
+
+    def set_lora_enabled(self, enabled: bool):
+        """Enable or disable LoRA layers without unloading weights.
+
+        Args:
+            enabled: If True, LoRA layers are active; if False, only base model is used.
+        """
+        self.tts_model.set_lora_enabled(enabled)
+
+    def get_lora_state_dict(self) -> dict:
+        """Get current LoRA parameters state dict.
+
+        Returns:
+            dict: State dict containing all LoRA parameters (lora_A, lora_B).
+        """
+        return self.tts_model.get_lora_state_dict()
+
+    @property
+    def lora_enabled(self) -> bool:
+        """Check if LoRA is currently configured."""
+        return self.tts_model.lora_config is not None
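
One behavior of this delegation layer worth noting: `load_lora` refuses to run on a model built without LoRA structure instead of silently loading nothing. A usage sketch:

```python
from voxcpm.core import VoxCPM

base = VoxCPM.from_pretrained(hf_model_id="openbmb/VoxCPM1.5")  # no lora_config
try:
    base.load_lora("/path/to/lora_checkpoint")
except RuntimeError as err:
    print(err)  # model was not initialized with LoRA config
```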