mirror of https://github.com/OpenBMB/VoxCPM (synced 2025-12-12 11:58:11 +00:00)
add lora finetune webUI; optimize lora save and load logic
scripts/test_voxcpm_lora_infer.py
@@ -5,7 +5,6 @@ LoRA inference test script.
 Usage:
 
     python scripts/test_voxcpm_lora_infer.py \
-        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --lora_ckpt checkpoints/step_0002000 \
         --text "Hello, this is LoRA finetuned result." \
         --output lora_test.wav
@@ -13,37 +12,39 @@ Usage:
 With voice cloning:
 
     python scripts/test_voxcpm_lora_infer.py \
-        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
         --lora_ckpt checkpoints/step_0002000 \
         --text "This is voice cloning result." \
         --prompt_audio path/to/ref.wav \
         --prompt_text "Reference audio transcript" \
         --output lora_clone.wav
+
+Note: The script reads base_model path and lora_config from lora_config.json
+in the checkpoint directory (saved automatically during training).
 """
 
 import argparse
+import json
 from pathlib import Path
 
 import soundfile as sf
 
 from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
-from voxcpm.training.config import load_yaml_config
 
 
 def parse_args():
     parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
-    parser.add_argument(
-        "--config_path",
-        type=str,
-        required=True,
-        help="Training YAML config path (contains pretrained_path and lora config)",
-    )
     parser.add_argument(
         "--lora_ckpt",
         type=str,
         required=True,
-        help="LoRA checkpoint directory (contains lora_weights.ckpt with lora_A/lora_B only)",
+        help="LoRA checkpoint directory (contains lora_weights.safetensors and lora_config.json)",
     )
+    parser.add_argument(
+        "--base_model",
+        type=str,
+        default="",
+        help="Optional: override base model path (default: read from lora_config.json)",
+    )
     parser.add_argument(
         "--text",
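For orientation, this is the lora_config.json contract implied by the Note above and consumed in main() below. The top-level keys ("base_model", "lora_config") and the r/alpha fields come from this diff; the concrete values are illustrative assumptions only.

# Hypothetical contents of <lora_ckpt>/lora_config.json, shown as the dict
# that json.load() returns in main(); all values are made-up examples.
example_lora_info = {
    "base_model": "path/or/hf-id/of-base-model",  # used unless --base_model overrides it
    "lora_config": {"r": 16, "alpha": 32},        # unpacked via LoRAConfig(**lora_config)
}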
@@ -98,26 +99,44 @@ def parse_args():
 def main():
     args = parse_args()
 
-    # 1. Load YAML config
-    cfg = load_yaml_config(args.config_path)
-    pretrained_path = cfg["pretrained_path"]
-    lora_cfg_dict = cfg.get("lora", {}) or {}
-    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
-
-    # 2. Check LoRA checkpoint
-    ckpt_dir = args.lora_ckpt
-    if not Path(ckpt_dir).exists():
+    # 1. Check LoRA checkpoint directory
+    ckpt_dir = Path(args.lora_ckpt)
+    if not ckpt_dir.exists():
         raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
 
+    # 2. Load lora_config.json from checkpoint
+    lora_config_path = ckpt_dir / "lora_config.json"
+    if not lora_config_path.exists():
+        raise FileNotFoundError(
+            f"lora_config.json not found in {ckpt_dir}. "
+            "Make sure the checkpoint was saved with the updated training script."
+        )
+
+    with open(lora_config_path, "r", encoding="utf-8") as f:
+        lora_info = json.load(f)
+
+    # Get base model path (command line arg overrides config)
+    pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
+    if not pretrained_path:
+        raise ValueError("base_model not found in lora_config.json and --base_model not provided")
+
+    # Get LoRA config
+    lora_cfg_dict = lora_info.get("lora_config", {})
+    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
+
+    print(f"Loaded config from: {lora_config_path}")
+    print(f"  Base model: {pretrained_path}")
+    print(f"  LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else "  LoRA config: None")
+
     # 3. Load model with LoRA (no denoiser)
-    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"  LoRA weights: {ckpt_dir}")
     model = VoxCPM.from_pretrained(
         hf_model_id=pretrained_path,
         load_denoiser=False,
         optimize=True,
         lora_config=lora_cfg,
-        lora_weights_path=ckpt_dir,
+        lora_weights_path=str(ckpt_dir),
     )
 
     # 4. Synthesize audio
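The synthesis step itself falls outside this hunk. A minimal sketch of what step 4 plausibly does, assuming generate() returns a NumPy array and a 16 kHz output rate; only the generate(text=...) call, the soundfile import, and the --output argument are attested by this diff:

# Sketch of "# 4. Synthesize audio" under the assumptions above.
audio_np = model.generate(text=args.text)
sf.write(args.output, audio_np, 16000)  # sample rate assumed, not from the diff
print(f"Saved audio to: {args.output}")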
@@ -197,7 +216,7 @@ def main():
 
     # === Test 5: Hot-reload LoRA (load_lora) ===
     print(f"\n  [Test 5] Hot-reload LoRA (load_lora)...")
-    loaded, skipped = model.load_lora(str(ckpt_dir))
+    loaded, skipped = model.load_lora(ckpt_dir)
     print(f"  Reloaded {len(loaded)} parameters")
     audio_np = model.generate(
         text=args.text,
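Because load_lora hot-reloads adapter weights into an already-constructed model, the same call supports quick A/B comparison across checkpoints. A minimal sketch, assuming both checkpoint directories exist; only the load_lora and generate calls are attested by this diff:

# Hypothetical checkpoint paths. load_lora returns (loaded, skipped);
# len(loaded) is used by the script above, len(skipped) is assumed to work.
for step_dir in ("checkpoints/step_0001000", "checkpoints/step_0002000"):
    loaded, skipped = model.load_lora(step_dir)
    print(f"{step_dir}: reloaded {len(loaded)}, skipped {len(skipped)}")
    audio = model.generate(text="Quick A/B check of this checkpoint.")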