mirror of
https://github.com/OpenBMB/VoxCPM
synced 2025-12-12 03:48:12 +00:00
Modify lora inference api
This commit is contained in:
@@ -271,10 +271,10 @@ LoRA supports dynamic loading, unloading, and switching at inference time withou
|
||||
### API Reference
|
||||
|
||||
```python
|
||||
from voxcpm.model import VoxCPMModel
|
||||
from voxcpm.core import VoxCPM
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
|
||||
# 1. Load model with LoRA structure
|
||||
# 1. Load model with LoRA structure and weights
|
||||
lora_cfg = LoRAConfig(
|
||||
enable_lm=True,
|
||||
enable_dit=True,
|
||||
@@ -283,15 +283,20 @@ lora_cfg = LoRAConfig(
|
||||
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
)
|
||||
model = VoxCPMModel.from_local(
|
||||
pretrained_path,
|
||||
optimize=True, # Enable torch.compile acceleration
|
||||
lora_config=lora_cfg
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id="openbmb/VoxCPM1.5", # or local path
|
||||
load_denoiser=False, # Optional: disable denoiser for faster loading
|
||||
optimize=True, # Enable torch.compile acceleration
|
||||
lora_config=lora_cfg,
|
||||
lora_weights_path="/path/to/lora_checkpoint",
|
||||
)
|
||||
|
||||
# 2. Load LoRA weights (works after torch.compile)
|
||||
loaded, skipped = model.load_lora_weights("/path/to/lora_checkpoint")
|
||||
print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
|
||||
# 2. Generate audio
|
||||
audio = model.generate(
|
||||
text="Hello, this is LoRA fine-tuned result.",
|
||||
prompt_wav_path="/path/to/reference.wav", # Optional: for voice cloning
|
||||
prompt_text="Reference audio transcript", # Optional: for voice cloning
|
||||
)
|
||||
|
||||
# 3. Disable LoRA (use base model only)
|
||||
model.set_lora_enabled(False)
|
||||
@@ -300,23 +305,39 @@ model.set_lora_enabled(False)
|
||||
model.set_lora_enabled(True)
|
||||
|
||||
# 5. Unload LoRA (reset weights to zero)
|
||||
model.reset_lora_weights()
|
||||
model.unload_lora()
|
||||
|
||||
# 6. Hot-swap to another LoRA
|
||||
model.load_lora_weights("/path/to/another_lora_checkpoint")
|
||||
loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
|
||||
print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
|
||||
|
||||
# 7. Get current LoRA weights
|
||||
lora_state = model.get_lora_state_dict()
|
||||
```
|
||||
|
||||
### Simplified Usage (Auto LoRA Config)
|
||||
|
||||
If you only have LoRA weights and don't need custom config, just provide the path:
|
||||
|
||||
```python
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
# Auto-create default LoRAConfig when only lora_weights_path is provided
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id="openbmb/VoxCPM1.5",
|
||||
lora_weights_path="/path/to/lora_checkpoint", # Will auto-create LoRAConfig
|
||||
)
|
||||
```
|
||||
|
||||
### Method Reference
|
||||
|
||||
| Method | Description | torch.compile Compatible |
|
||||
|--------|-------------|--------------------------|
|
||||
| `load_lora_weights(path)` | Load LoRA weights from file | ✅ |
|
||||
| `load_lora(path)` | Load LoRA weights from file | ✅ |
|
||||
| `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
|
||||
| `reset_lora_weights()` | Reset LoRA weights to initial values | ✅ |
|
||||
| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
|
||||
| `get_lora_state_dict()` | Get current LoRA weights | ✅ |
|
||||
| `lora_enabled` | Property: check if LoRA is configured | ✅ |
|
||||
|
||||
---
|
||||
|
||||
@@ -346,7 +367,7 @@ lora_state = model.get_lora_state_dict()
|
||||
### 4. LoRA Not Taking Effect at Inference
|
||||
|
||||
- Ensure inference config matches training config LoRA parameters
|
||||
- Check `load_lora_weights` return value - `skipped_keys` should be empty
|
||||
- Check `load_lora()` return value - `skipped_keys` should be empty
|
||||
- Verify `set_lora_enabled(True)` is called
|
||||
|
||||
### 5. Checkpoint Loading Errors
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Full finetune inference script (no LoRA).
|
||||
|
||||
Checkpoint directory contains complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.),
|
||||
can be loaded directly via VoxCPMModel.from_local().
|
||||
can be loaded directly via VoxCPM.
|
||||
|
||||
Usage:
|
||||
|
||||
@@ -26,9 +26,8 @@ import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
from voxcpm.model import VoxCPMModel
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -81,49 +80,52 @@ def parse_args():
|
||||
default=600,
|
||||
help="Max generation steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Enable text normalization",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Load model from checkpoint directory
|
||||
# Load model from checkpoint directory (no denoiser)
|
||||
print(f"[FT Inference] Loading model: {args.ckpt_dir}")
|
||||
model = VoxCPMModel.from_local(args.ckpt_dir, optimize=True, training=False)
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id=args.ckpt_dir,
|
||||
load_denoiser=False,
|
||||
optimize=True,
|
||||
)
|
||||
|
||||
# Run inference
|
||||
prompt_wav_path = args.prompt_audio or ""
|
||||
prompt_text = args.prompt_text or ""
|
||||
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
|
||||
prompt_text = args.prompt_text if args.prompt_text else None
|
||||
|
||||
print(f"[FT Inference] Synthesizing: text='{args.text}'")
|
||||
if prompt_wav_path:
|
||||
print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
|
||||
print(f"[FT Inference] Reference text: {prompt_text}")
|
||||
|
||||
with torch.inference_mode():
|
||||
audio = model.generate(
|
||||
target_text=args.text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
max_len=args.max_len,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
cfg_value=args.cfg_value,
|
||||
)
|
||||
|
||||
# Squeeze and save audio
|
||||
if isinstance(audio, torch.Tensor):
|
||||
audio_np = audio.squeeze(0).cpu().numpy()
|
||||
else:
|
||||
raise TypeError(f"Unexpected return type from model.generate: {type(audio)}")
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
|
||||
# Save audio
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
sf.write(str(out_path), audio_np, model.sample_rate)
|
||||
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
|
||||
|
||||
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.sample_rate:.2f}s")
|
||||
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
@@ -25,9 +25,8 @@ import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
from voxcpm.model import VoxCPMModel
|
||||
from voxcpm.core import VoxCPM
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
from voxcpm.training.config import load_yaml_config
|
||||
|
||||
@@ -88,6 +87,11 @@ def parse_args():
|
||||
default=600,
|
||||
help="Max generation steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Enable text normalization",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -100,123 +104,114 @@ def main():
|
||||
lora_cfg_dict = cfg.get("lora", {}) or {}
|
||||
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
|
||||
|
||||
# 2. Load base model (with LoRA structure and torch.compile)
|
||||
print(f"[1/3] Loading base model: {pretrained_path}")
|
||||
model = VoxCPMModel.from_local(
|
||||
pretrained_path,
|
||||
optimize=True, # compile first, load_lora_weights uses named_parameters for compatibility
|
||||
training=False,
|
||||
lora_config=lora_cfg,
|
||||
)
|
||||
|
||||
# Debug: check DiT param paths after compile
|
||||
dit_params = [n for n, _ in model.named_parameters() if 'feat_decoder' in n and 'lora' in n]
|
||||
print(f"[DEBUG] DiT LoRA param paths after compile (first 3): {dit_params[:3]}")
|
||||
|
||||
# 3. Load LoRA weights (works after compile)
|
||||
ckpt_dir = Path(args.lora_ckpt)
|
||||
if not ckpt_dir.exists():
|
||||
# 2. Check LoRA checkpoint
|
||||
ckpt_dir = args.lora_ckpt
|
||||
if not Path(ckpt_dir).exists():
|
||||
raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
|
||||
|
||||
print(f"[2/3] Loading LoRA weights: {ckpt_dir}")
|
||||
loaded, skipped = model.load_lora_weights(str(ckpt_dir))
|
||||
print(f" Loaded {len(loaded)} parameters")
|
||||
if skipped:
|
||||
print(f"[WARNING] Skipped {len(skipped)} parameters")
|
||||
print(f" Skipped keys (first 5): {skipped[:5]}")
|
||||
# 3. Load model with LoRA (no denoiser)
|
||||
print(f"[1/2] Loading model with LoRA: {pretrained_path}")
|
||||
print(f" LoRA weights: {ckpt_dir}")
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id=pretrained_path,
|
||||
load_denoiser=False,
|
||||
optimize=True,
|
||||
lora_config=lora_cfg,
|
||||
lora_weights_path=ckpt_dir,
|
||||
)
|
||||
|
||||
# 4. Synthesize audio
|
||||
prompt_wav_path = args.prompt_audio or ""
|
||||
prompt_text = args.prompt_text or ""
|
||||
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
|
||||
prompt_text = args.prompt_text if args.prompt_text else None
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n[3/3] Starting synthesis tests...")
|
||||
print(f"\n[2/2] Starting synthesis tests...")
|
||||
|
||||
# === Test 1: With LoRA ===
|
||||
print(f"\n [Test 1] Synthesize with LoRA...")
|
||||
with torch.inference_mode():
|
||||
audio = model.generate(
|
||||
target_text=args.text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
max_len=args.max_len,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
cfg_value=args.cfg_value,
|
||||
)
|
||||
audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
||||
sf.write(str(lora_output), audio_np, model.sample_rate)
|
||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
|
||||
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
||||
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
|
||||
model.set_lora_enabled(False)
|
||||
with torch.inference_mode():
|
||||
audio = model.generate(
|
||||
target_text=args.text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
max_len=args.max_len,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
cfg_value=args.cfg_value,
|
||||
)
|
||||
audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
||||
sf.write(str(disabled_output), audio_np, model.sample_rate)
|
||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
|
||||
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
# === Test 3: Re-enable LoRA ===
|
||||
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
|
||||
model.set_lora_enabled(True)
|
||||
with torch.inference_mode():
|
||||
audio = model.generate(
|
||||
target_text=args.text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
max_len=args.max_len,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
cfg_value=args.cfg_value,
|
||||
)
|
||||
audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
||||
sf.write(str(reenabled_output), audio_np, model.sample_rate)
|
||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
|
||||
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
# === Test 4: Unload LoRA (reset_lora_weights) ===
|
||||
print(f"\n [Test 4] Unload LoRA (reset_lora_weights)...")
|
||||
model.reset_lora_weights()
|
||||
with torch.inference_mode():
|
||||
audio = model.generate(
|
||||
target_text=args.text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
max_len=args.max_len,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
cfg_value=args.cfg_value,
|
||||
)
|
||||
audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
|
||||
print(f"\n [Test 4] Unload LoRA (unload_lora)...")
|
||||
model.unload_lora()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
||||
sf.write(str(reset_output), audio_np, model.sample_rate)
|
||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
|
||||
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
# === Test 5: Hot-reload LoRA (load_lora_weights) ===
|
||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora_weights)...")
|
||||
loaded, _ = model.load_lora_weights(str(ckpt_dir))
|
||||
# === Test 5: Hot-reload LoRA (load_lora) ===
|
||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
|
||||
loaded, skipped = model.load_lora(str(ckpt_dir))
|
||||
print(f" Reloaded {len(loaded)} parameters")
|
||||
with torch.inference_mode():
|
||||
audio = model.generate(
|
||||
target_text=args.text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
max_len=args.max_len,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
cfg_value=args.cfg_value,
|
||||
)
|
||||
audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
||||
sf.write(str(reload_output), audio_np, model.sample_rate)
|
||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
|
||||
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
print(f"\n[Done] All tests completed!")
|
||||
print(f" - with_lora: {lora_output}")
|
||||
@@ -228,5 +223,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
|
||||
"ZIPENHANCER_MODEL_PATH", None
|
||||
)
|
||||
|
||||
# Build LoRA config if lora_path is provided
|
||||
lora_config = None
|
||||
lora_weights_path = getattr(args, "lora_path", None)
|
||||
if lora_weights_path:
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
lora_config = LoRAConfig(
|
||||
enable_lm=getattr(args, "lora_enable_lm", True),
|
||||
enable_dit=getattr(args, "lora_enable_dit", True),
|
||||
enable_proj=getattr(args, "lora_enable_proj", False),
|
||||
r=getattr(args, "lora_r", 32),
|
||||
alpha=getattr(args, "lora_alpha", 16),
|
||||
dropout=getattr(args, "lora_dropout", 0.0),
|
||||
)
|
||||
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
||||
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
|
||||
|
||||
# Load from local path if provided
|
||||
if getattr(args, "model_path", None):
|
||||
try:
|
||||
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
|
||||
voxcpm_model_path=args.model_path,
|
||||
zipenhancer_model_path=zipenhancer_path,
|
||||
enable_denoiser=not getattr(args, "no_denoiser", False),
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
print("Model loaded (local).")
|
||||
return model
|
||||
@@ -74,6 +92,8 @@ def load_model(args) -> VoxCPM:
|
||||
zipenhancer_model_id=zipenhancer_path,
|
||||
cache_dir=getattr(args, "cache_dir", None),
|
||||
local_files_only=getattr(args, "local_files_only", False),
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
print("Model loaded (from_pretrained).")
|
||||
return model
|
||||
@@ -256,6 +276,15 @@ Examples:
|
||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
|
||||
|
||||
# LoRA parameters
|
||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
|
||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
|
||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
|
||||
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
|
||||
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
|
||||
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from typing import Generator
|
||||
from typing import Generator, Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel
|
||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||
|
||||
class VoxCPM:
|
||||
def __init__(self,
|
||||
@@ -12,6 +12,8 @@ class VoxCPM:
|
||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser : bool = True,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
"""Initialize VoxCPM TTS pipeline.
|
||||
|
||||
@@ -23,9 +25,30 @@ class VoxCPM:
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||
"""
|
||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
|
||||
|
||||
# If lora_weights_path is provided but no lora_config, create a default one
|
||||
if lora_weights_path is not None and lora_config is None:
|
||||
lora_config = LoRAConfig(
|
||||
enable_lm=True,
|
||||
enable_dit=True,
|
||||
enable_proj=False,
|
||||
)
|
||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
|
||||
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
|
||||
# Load LoRA weights if path is provided
|
||||
if lora_weights_path is not None:
|
||||
print(f"Loading LoRA weights from: {lora_weights_path}")
|
||||
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
|
||||
|
||||
self.text_normalizer = None
|
||||
if enable_denoiser and zipenhancer_model_path is not None:
|
||||
from .zipenhancer import ZipEnhancer
|
||||
@@ -46,6 +69,8 @@ class VoxCPM:
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
||||
@@ -59,6 +84,12 @@ class VoxCPM:
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created with
|
||||
enable_lm=True and enable_dit=True.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded
|
||||
after model initialization.
|
||||
Kwargs:
|
||||
Additional keyword arguments passed to the ``VoxCPM`` constructor.
|
||||
|
||||
@@ -90,6 +121,8 @@ class VoxCPM:
|
||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||
enable_denoiser=load_denoiser,
|
||||
optimize=optimize,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -197,3 +230,51 @@ class VoxCPM:
|
||||
os.unlink(temp_prompt_wav_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# LoRA Interface (delegated to VoxCPMModel)
|
||||
# ------------------------------------------------------------------ #
|
||||
def load_lora(self, lora_weights_path: str) -> tuple:
|
||||
"""Load LoRA weights from a checkpoint file.
|
||||
|
||||
Args:
|
||||
lora_weights_path: Path to LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt).
|
||||
|
||||
Returns:
|
||||
tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model was not initialized with LoRA config.
|
||||
"""
|
||||
if self.tts_model.lora_config is None:
|
||||
raise RuntimeError(
|
||||
"Cannot load LoRA weights: model was not initialized with LoRA config. "
|
||||
"Please reinitialize with lora_config or lora_weights_path parameter."
|
||||
)
|
||||
return self.tts_model.load_lora_weights(lora_weights_path)
|
||||
|
||||
def unload_lora(self):
|
||||
"""Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
|
||||
self.tts_model.reset_lora_weights()
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
|
||||
"""Enable or disable LoRA layers without unloading weights.
|
||||
|
||||
Args:
|
||||
enabled: If True, LoRA layers are active; if False, only base model is used.
|
||||
"""
|
||||
self.tts_model.set_lora_enabled(enabled)
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
|
||||
"""Get current LoRA parameters state dict.
|
||||
|
||||
Returns:
|
||||
dict: State dict containing all LoRA parameters (lora_A, lora_B).
|
||||
"""
|
||||
return self.tts_model.get_lora_state_dict()
|
||||
|
||||
@property
|
||||
def lora_enabled(self) -> bool:
|
||||
"""Check if LoRA is currently configured."""
|
||||
return self.tts_model.lora_config is not None
|
||||
Reference in New Issue
Block a user