add lora finetune webUI; optimize lora save and load logic

刘鑫
2025-12-09 21:34:39 +08:00
parent 0779a93697
commit a266c0a88d
9 changed files with 1575 additions and 48 deletions

scripts/test_voxcpm_lora_infer.py

@@ -5,7 +5,6 @@ LoRA inference test script.

Usage:
    python scripts/test_voxcpm_lora_infer.py \
-        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --lora_ckpt checkpoints/step_0002000 \
        --text "Hello, this is LoRA finetuned result." \
        --output lora_test.wav
@@ -13,37 +12,39 @@ Usage:

With voice cloning:
    python scripts/test_voxcpm_lora_infer.py \
-        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --lora_ckpt checkpoints/step_0002000 \
        --text "This is voice cloning result." \
        --prompt_audio path/to/ref.wav \
        --prompt_text "Reference audio transcript" \
        --output lora_clone.wav

+Note: The script reads base_model path and lora_config from lora_config.json
+in the checkpoint directory (saved automatically during training).
"""
import argparse
+import json
from pathlib import Path

import soundfile as sf

from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training.config import load_yaml_config


def parse_args():
    parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
-    parser.add_argument(
-        "--config_path",
-        type=str,
-        required=True,
-        help="Training YAML config path (contains pretrained_path and lora config)",
-    )
    parser.add_argument(
        "--lora_ckpt",
        type=str,
        required=True,
-        help="LoRA checkpoint directory (contains lora_weights.ckpt with lora_A/lora_B only)",
+        help="LoRA checkpoint directory (contains lora_weights.safetensors and lora_config.json)",
    )
+    parser.add_argument(
+        "--base_model",
+        type=str,
+        default="",
+        help="Optional: override base model path (default: read from lora_config.json)",
+    )
    parser.add_argument(
        "--text",
@@ -98,26 +99,44 @@ def parse_args():
def main():
    args = parse_args()

-    # 1. Load YAML config
-    cfg = load_yaml_config(args.config_path)
-    pretrained_path = cfg["pretrained_path"]
-    lora_cfg_dict = cfg.get("lora", {}) or {}
-    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
-
-    # 2. Check LoRA checkpoint
-    ckpt_dir = args.lora_ckpt
-    if not Path(ckpt_dir).exists():
+    # 1. Check LoRA checkpoint directory
+    ckpt_dir = Path(args.lora_ckpt)
+    if not ckpt_dir.exists():
        raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")

+    # 2. Load lora_config.json from checkpoint
+    lora_config_path = ckpt_dir / "lora_config.json"
+    if not lora_config_path.exists():
+        raise FileNotFoundError(
+            f"lora_config.json not found in {ckpt_dir}. "
+            "Make sure the checkpoint was saved with the updated training script."
+        )
+    with open(lora_config_path, "r", encoding="utf-8") as f:
+        lora_info = json.load(f)
+
+    # Get base model path (command line arg overrides config)
+    pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
+    if not pretrained_path:
+        raise ValueError("base_model not found in lora_config.json and --base_model not provided")
+
+    # Get LoRA config
+    lora_cfg_dict = lora_info.get("lora_config", {})
+    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
+
+    print(f"Loaded config from: {lora_config_path}")
+    print(f" Base model: {pretrained_path}")
+    print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None")

    # 3. Load model with LoRA (no denoiser)
-    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f" LoRA weights: {ckpt_dir}")
    model = VoxCPM.from_pretrained(
        hf_model_id=pretrained_path,
        load_denoiser=False,
        optimize=True,
        lora_config=lora_cfg,
-        lora_weights_path=ckpt_dir,
+        lora_weights_path=str(ckpt_dir),
    )

    # 4. Synthesize audio
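
Condensed, the new loading flow amounts to the following. This is a sketch using only the calls visible in this diff; the checkpoint path is a placeholder and the error handling shown in the hunk above is omitted.

```python
# Minimal sketch of main()'s loading flow (calls as shown in the diff above;
# the checkpoint path is a placeholder, error handling omitted).
import json
from pathlib import Path

from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig

ckpt_dir = Path("checkpoints/step_0002000")
lora_info = json.loads((ckpt_dir / "lora_config.json").read_text(encoding="utf-8"))

model = VoxCPM.from_pretrained(
    hf_model_id=lora_info["base_model"],                 # --base_model would override this
    load_denoiser=False,
    optimize=True,
    lora_config=LoRAConfig(**lora_info["lora_config"]),
    lora_weights_path=str(ckpt_dir),                     # from_pretrained takes a str path
)
```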
@@ -197,7 +216,7 @@ def main():

    # === Test 5: Hot-reload LoRA (load_lora) ===
    print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
-    loaded, skipped = model.load_lora(str(ckpt_dir))
+    loaded, skipped = model.load_lora(ckpt_dir)
    print(f" Reloaded {len(loaded)} parameters")
    audio_np = model.generate(
        text=args.text,
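
Because load_lora() returns the loaded and skipped parameter names, adapters can be swapped between generations without reloading the base model. A minimal sketch assuming the model object from above; the checkpoint paths are hypothetical, and generate() may take more arguments than the text kwarg confirmed by this diff.

```python
# Hedged sketch: hot-swapping LoRA adapters via load_lora(), as exercised
# in Test 5 above. Checkpoint paths are hypothetical; generate() is called
# with only the text kwarg that this diff confirms.
from pathlib import Path

for ckpt in ["checkpoints/step_0001000", "checkpoints/step_0002000"]:
    loaded, skipped = model.load_lora(Path(ckpt))
    print(f"{ckpt}: loaded {len(loaded)} params, skipped {len(skipped)}")
    audio_np = model.generate(text="Quick A/B comparison of two adapters.")
```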