mirror of
https://github.com/OpenBMB/VoxCPM
synced 2025-12-13 04:08:11 +00:00
Modify lora inference api
This commit is contained in:
@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
|
||||
"ZIPENHANCER_MODEL_PATH", None
|
||||
)
|
||||
|
||||
# Build LoRA config if lora_path is provided
|
||||
lora_config = None
|
||||
lora_weights_path = getattr(args, "lora_path", None)
|
||||
if lora_weights_path:
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
lora_config = LoRAConfig(
|
||||
enable_lm=getattr(args, "lora_enable_lm", True),
|
||||
enable_dit=getattr(args, "lora_enable_dit", True),
|
||||
enable_proj=getattr(args, "lora_enable_proj", False),
|
||||
r=getattr(args, "lora_r", 32),
|
||||
alpha=getattr(args, "lora_alpha", 16),
|
||||
dropout=getattr(args, "lora_dropout", 0.0),
|
||||
)
|
||||
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
||||
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
|
||||
|
||||
# Load from local path if provided
|
||||
if getattr(args, "model_path", None):
|
||||
try:
|
||||
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
|
||||
voxcpm_model_path=args.model_path,
|
||||
zipenhancer_model_path=zipenhancer_path,
|
||||
enable_denoiser=not getattr(args, "no_denoiser", False),
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
print("Model loaded (local).")
|
||||
return model
|
||||
@@ -74,6 +92,8 @@ def load_model(args) -> VoxCPM:
|
||||
zipenhancer_model_id=zipenhancer_path,
|
||||
cache_dir=getattr(args, "cache_dir", None),
|
||||
local_files_only=getattr(args, "local_files_only", False),
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
print("Model loaded (from_pretrained).")
|
||||
return model
|
||||
@@ -256,6 +276,15 @@ Examples:
|
||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
|
||||
|
||||
# LoRA parameters
|
||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
|
||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
|
||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
|
||||
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
|
||||
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
|
||||
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from typing import Generator
|
||||
from typing import Generator, Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel
|
||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||
|
||||
class VoxCPM:
|
||||
def __init__(self,
|
||||
@@ -12,6 +12,8 @@ class VoxCPM:
|
||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser : bool = True,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
"""Initialize VoxCPM TTS pipeline.
|
||||
|
||||
@@ -23,9 +25,30 @@ class VoxCPM:
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||
"""
|
||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
|
||||
|
||||
# If lora_weights_path is provided but no lora_config, create a default one
|
||||
if lora_weights_path is not None and lora_config is None:
|
||||
lora_config = LoRAConfig(
|
||||
enable_lm=True,
|
||||
enable_dit=True,
|
||||
enable_proj=False,
|
||||
)
|
||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
|
||||
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
|
||||
# Load LoRA weights if path is provided
|
||||
if lora_weights_path is not None:
|
||||
print(f"Loading LoRA weights from: {lora_weights_path}")
|
||||
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
|
||||
|
||||
self.text_normalizer = None
|
||||
if enable_denoiser and zipenhancer_model_path is not None:
|
||||
from .zipenhancer import ZipEnhancer
|
||||
@@ -46,6 +69,8 @@ class VoxCPM:
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
||||
@@ -59,6 +84,12 @@ class VoxCPM:
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created with
|
||||
enable_lm=True and enable_dit=True.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded
|
||||
after model initialization.
|
||||
Kwargs:
|
||||
Additional keyword arguments passed to the ``VoxCPM`` constructor.
|
||||
|
||||
@@ -90,6 +121,8 @@ class VoxCPM:
|
||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||
enable_denoiser=load_denoiser,
|
||||
optimize=optimize,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -196,4 +229,52 @@ class VoxCPM:
|
||||
try:
|
||||
os.unlink(temp_prompt_wav_path)
|
||||
except OSError:
|
||||
pass
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# LoRA Interface (delegated to VoxCPMModel)
|
||||
# ------------------------------------------------------------------ #
|
||||
def load_lora(self, lora_weights_path: str) -> tuple:
    """Load LoRA weights into the underlying TTS model.

    The call is forwarded to ``self.tts_model.load_lora_weights`` once it
    has been verified that the model carries a LoRA configuration.

    Args:
        lora_weights_path: Path to LoRA weights (.pth file or directory
            containing lora_weights.ckpt).

    Returns:
        tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped
        parameter names, exactly as returned by the model.

    Raises:
        RuntimeError: If ``self.tts_model.lora_config`` is None, i.e. the
            model was not initialized with a LoRA config.
    """
    model = self.tts_model
    # Happy path first: a configured model simply delegates the load.
    if model.lora_config is not None:
        return model.load_lora_weights(lora_weights_path)

    raise RuntimeError(
        "Cannot load LoRA weights: model was not initialized with LoRA config. "
        "Please reinitialize with lora_config or lora_weights_path parameter."
    )
|
||||
|
||||
def unload_lora(self):
    """Unload LoRA by asking the TTS model to reset its LoRA weights.

    Delegates to ``self.tts_model.reset_lora_weights()``; resetting the
    weights to their initial state effectively disables LoRA.
    """
    model = self.tts_model
    model.reset_lora_weights()
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
    """Toggle the LoRA layers on or off while keeping their weights loaded.

    Delegates to ``self.tts_model.set_lora_enabled``.

    Args:
        enabled: If True, LoRA layers are active; if False, only base model is used.
    """
    model = self.tts_model
    model.set_lora_enabled(enabled)
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
    """Return the LoRA parameter state dict of the underlying TTS model.

    Delegates to ``self.tts_model.get_lora_state_dict()`` and hands its
    result back unchanged.

    Returns:
        dict: State dict containing all LoRA parameters (lora_A, lora_B).
    """
    lora_state = self.tts_model.get_lora_state_dict()
    return lora_state
|
||||
|
||||
@property
def lora_enabled(self) -> bool:
    """Report whether the TTS model carries a LoRA configuration.

    This reflects only whether ``self.tts_model.lora_config`` is set; it
    does not consult the toggle applied through ``set_lora_enabled``.
    """
    config = self.tts_model.lora_config
    return not (config is None)
|
||||
Reference in New Issue
Block a user