diff --git a/README.md b/README.md index 11ecbd6..5b5046c 100644 --- a/README.md +++ b/README.md @@ -210,6 +210,8 @@ We're excited to see the VoxCPM community growing! Here are some amazing project - **[VoxCPM-NanoVLLM](https://github.com/a710128/nanovllm-voxcpm)** NanoVLLM integration for VoxCPM for faster, high-throughput inference on GPU. - **[VoxCPM-ONNX](https://github.com/bluryar/VoxCPM-ONNX)** ONNX export for VoxCPM supports faster CPU inference. - **[VoxCPMANE](https://github.com/0seba/VoxCPMANE)** VoxCPM TTS with Apple Neural Engine backend server. +- **[PR: LoRA finetune web UI (by Ayin1412)](https://github.com/OpenBMB/VoxCPM/pull/100)** +- **[voxcpm_rs](https://github.com/madushan1000/voxcpm_rs)** A re-implementation of VoxCPM-0.5B in Rust. *Note: The projects are not officially maintained by OpenBMB.* diff --git a/app.py b/app.py index 96c92a5..a26a0df 100644 --- a/app.py +++ b/app.py @@ -267,7 +267,7 @@ def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error demo = VoxCPMDemo() interface = create_demo_interface(demo) # Recommended to enable queue on Spaces for better throughput - interface.queue(max_size=10).launch(server_name=server_name, server_port=server_port, show_error=show_error) + interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error) if __name__ == "__main__": diff --git a/conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml b/conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml index 7f1187d..2b0ba54 100644 --- a/conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml +++ b/conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml @@ -19,6 +19,8 @@ tensorboard: /path/to/logs/finetune_lora lambdas: loss/diff: 1.0 loss/stop: 1.0 + +# LoRA configuration lora: enable_lm: true enable_dit: true @@ -26,3 +28,9 @@ lora: r: 32 alpha: 16 dropout: 0.0 + +# Distribution options (optional) +# - If distribute=false (default): save pretrained_path as base_model in lora_config.json +# - If distribute=true: save hf_model_id as base_model (hf_model_id is required) +# hf_model_id: "openbmb/VoxCPM1.5" +# distribute: true diff --git a/conf/voxcpm_v1/voxcpm_finetune_lora.yaml b/conf/voxcpm_v1/voxcpm_finetune_lora.yaml index 49e2b7d..642f115 100644 --- a/conf/voxcpm_v1/voxcpm_finetune_lora.yaml +++ b/conf/voxcpm_v1/voxcpm_finetune_lora.yaml @@ -19,10 +19,18 @@ tensorboard: /path/to/logs/finetune_lora lambdas: loss/diff: 1.0 loss/stop: 1.0 + +# LoRA configuration lora: enable_lm: true enable_dit: true enable_proj: false r: 32 alpha: 16 - dropout: 0.0 \ No newline at end of file + dropout: 0.0 + +# Distribution options (optional) +# - If distribute=false (default): save pretrained_path as base_model in lora_config.json +# - If distribute=true: save hf_model_id as base_model (hf_model_id is required) +# hf_model_id: "openbmb/VoxCPM-0.5B" +# distribute: true \ No newline at end of file diff --git a/docs/finetune.md b/docs/finetune.md index 70a24a3..398d94b 100644 --- a/docs/finetune.md +++ b/docs/finetune.md @@ -19,6 +19,7 @@ LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that: ## Table of Contents +- [Quick Start: WebUI](#quick-start-webui) - [Data Preparation](#data-preparation) - [Full Fine-tuning](#full-fine-tuning) - [LoRA Fine-tuning](#lora-fine-tuning) @@ -28,6 +29,31 @@ LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that: --- +## Quick Start: WebUI + +For users who prefer a graphical interface, we provide `lora_ft_webui.py` - a comprehensive WebUI for training and inference: + 
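+The WebUI is built on Gradio and uses FunASR (SenseVoiceSmall) for automatic reference-text recognition. If these packages are not already in your environment, install them first. A minimal sketch assuming a pip-based setup (package names taken from the script's imports; an existing VoxCPM install should already provide torch, soundfile, etc.):
+
+```bash
+pip install gradio funasr
+```
+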
+### Launch WebUI + +```bash +python lora_ft_webui.py +``` + +Then open `http://localhost:7860` in your browser. + +### Features + +- **🚀 Training Tab**: Configure and start LoRA training with an intuitive interface + - Set training parameters (learning rate, batch size, LoRA rank, etc.) + - Monitor training progress in real-time + - Resume training from existing checkpoints + +- **🎵 Inference Tab**: Generate audio with trained models + - Automatic base model loading from LoRA checkpoint config + - Voice cloning with automatic ASR (reference text recognition) + - Hot-swap between multiple LoRA models + - Zero-shot TTS without reference audio + ## Data Preparation Training data should be prepared as a JSONL manifest file, with one sample per line: @@ -177,6 +203,10 @@ lora: # Target modules target_modules_lm: ["q_proj", "v_proj", "k_proj", "o_proj"] target_modules_dit: ["q_proj", "v_proj", "k_proj", "o_proj"] + +# Distribution options (optional) +# hf_model_id: "openbmb/VoxCPM1.5" # HuggingFace ID +# distribute: true # If true, save hf_model_id in lora_config.json ``` ### LoRA Parameters @@ -189,6 +219,15 @@ lora: | `alpha` | Scaling factor, `scaling = alpha / r` | Usually `r/2` or `r` | | `target_modules_*` | Layer names to add LoRA | attention layers | +### Distribution Options (Optional) + +| Parameter | Description | Default | +|-----------|-------------|---------| +| `hf_model_id` | HuggingFace model ID (e.g., `openbmb/VoxCPM1.5`) | `""` | +| `distribute` | If `true`, save `hf_model_id` as `base_model` in checkpoint; otherwise save local `pretrained_path` | `false` | + +> **Note**: If `distribute: true`, `hf_model_id` is required. + ### Training ```bash @@ -202,16 +241,37 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \ ### Checkpoint Structure -LoRA training saves only LoRA parameters: +LoRA training saves LoRA parameters and configuration: ``` checkpoints/finetune_lora/ └── step_0002000/ ├── lora_weights.safetensors # Only lora_A, lora_B parameters + ├── lora_config.json # LoRA config + base model path ├── optimizer.pth └── scheduler.pth ``` +The `lora_config.json` contains: +```json +{ + "base_model": "/path/to/VoxCPM1.5/", + "lora_config": { + "enable_lm": true, + "enable_dit": true, + "r": 32, + "alpha": 16, + ... + } +} +``` + +The `base_model` field contains: +- Local path (default): when `distribute: false` or not set +- HuggingFace ID: when `distribute: true` (e.g., `"openbmb/VoxCPM1.5"`) + +This allows loading LoRA checkpoints without the original training config file. + --- ## Inference @@ -240,11 +300,10 @@ python scripts/test_voxcpm_ft_infer.py \ ### LoRA Inference -LoRA inference requires the training config (for LoRA structure) and LoRA checkpoint: +LoRA inference only requires the checkpoint directory (base model path and LoRA config are read from `lora_config.json`): ```bash python scripts/test_voxcpm_lora_infer.py \ - --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \ --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \ --text "Hello, this is LoRA fine-tuned result." \ --output lora_output.wav @@ -254,7 +313,6 @@ With voice cloning: ```bash python scripts/test_voxcpm_lora_infer.py \ - --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \ --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \ --text "This is voice cloning with LoRA." 
\ --prompt_audio /path/to/reference.wav \ @@ -262,6 +320,16 @@ python scripts/test_voxcpm_lora_infer.py \ --output cloned_output.wav ``` +Override base model path (optional): + +```bash +python scripts/test_voxcpm_lora_infer.py \ + --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \ + --base_model /path/to/another/VoxCPM1.5 \ + --text "Use different base model." \ + --output output.wav +``` + --- ## LoRA Hot-swapping @@ -315,20 +383,39 @@ print(f"Loaded {len(loaded)} params, skipped {len(skipped)}") lora_state = model.get_lora_state_dict() ``` -### Simplified Usage (Auto LoRA Config) +### Simplified Usage (Load from lora_config.json) -If you only have LoRA weights and don't need custom config, just provide the path: +If your checkpoint contains `lora_config.json` (saved by the training script), you can load everything automatically: ```python +import json from voxcpm.core import VoxCPM +from voxcpm.model.voxcpm import LoRAConfig -# Auto-create default LoRAConfig when only lora_weights_path is provided +# Load config from checkpoint +lora_ckpt_dir = "/path/to/checkpoints/finetune_lora/step_0002000" +with open(f"{lora_ckpt_dir}/lora_config.json") as f: + lora_info = json.load(f) + +base_model = lora_info["base_model"] +lora_cfg = LoRAConfig(**lora_info["lora_config"]) + +# Load model with LoRA model = VoxCPM.from_pretrained( - hf_model_id="openbmb/VoxCPM1.5", - lora_weights_path="/path/to/lora_checkpoint", # Will auto-create LoRAConfig + hf_model_id=base_model, + lora_config=lora_cfg, + lora_weights_path=lora_ckpt_dir, ) ``` +Or use the test script directly: + +```bash +python scripts/test_voxcpm_lora_infer.py \ + --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \ + --text "Hello world" +``` + ### Method Reference | Method | Description | torch.compile Compatible | @@ -354,7 +441,6 @@ model = VoxCPM.from_pretrained( - Increase `r` (LoRA rank) - Adjust `alpha` (try `alpha = r/2` or `alpha = r`) -- Ensure `enable_dit: true` (required for voice cloning) - Increase training steps - Add more target modules @@ -366,11 +452,13 @@ model = VoxCPM.from_pretrained( ### 4. LoRA Not Taking Effect at Inference -- Ensure inference config matches training config LoRA parameters +- Check that `lora_config.json` exists in the checkpoint directory - Check `load_lora()` return value - `skipped_keys` should be empty - Verify `set_lora_enabled(True)` is called ### 5. 
Checkpoint Loading Errors -- Full fine-tuning: checkpoint directory should contain `model.safetensors`(or `pytorch_model.bin`), `config.json`, `audiovae.pth` -- LoRA: checkpoint directory should contain `lora_weights.safetensors` (or `lora_weights.ckpt`) +- Full fine-tuning: checkpoint directory should contain `model.safetensors` (or `pytorch_model.bin`), `config.json`, `audiovae.pth` +- LoRA: checkpoint directory should contain: + - `lora_weights.safetensors` (or `lora_weights.ckpt`) - LoRA weights + - `lora_config.json` - LoRA config and base model path diff --git a/lora_ft_webui.py b/lora_ft_webui.py new file mode 100644 index 0000000..9f679e6 --- /dev/null +++ b/lora_ft_webui.py @@ -0,0 +1,1253 @@ +import os +import sys +import time +import glob +import json +import yaml +import shutil +import datetime +import subprocess +import threading +import gradio as gr +import torch +import soundfile as sf +from pathlib import Path +from typing import Optional, List + +# Add src to sys.path +project_root = Path(__file__).parent +sys.path.insert(0, str(project_root / "src")) + +# Default pretrained model path relative to this repo +default_pretrained_path = str(project_root / "models" / "openbmb__VoxCPM1.5") + +from voxcpm.core import VoxCPM +from voxcpm.model.voxcpm import LoRAConfig +import numpy as np +from funasr import AutoModel + +# --- Localization --- +LANG_DICT = { + "en": { + "title": "VoxCPM LoRA WebUI", + "tab_train": "Training", + "tab_infer": "Inference", + "pretrained_path": "Pretrained Model Path", + "train_manifest": "Train Manifest (jsonl)", + "val_manifest": "Validation Manifest (Optional)", + "lr": "Learning Rate", + "max_iters": "Max Iterations", + "batch_size": "Batch Size", + "lora_rank": "LoRA Rank", + "lora_alpha": "LoRA Alpha", + "save_interval": "Save Interval", + "start_train": "Start Training", + "stop_train": "Stop Training", + "train_logs": "Training Logs", + "text_to_synth": "Text to Synthesize", + "voice_cloning": "### Voice Cloning (Optional)", + "ref_audio": "Reference Audio", + "ref_text": "Reference Text (Optional)", + "select_lora": "Select LoRA Checkpoint", + "cfg_scale": "CFG Scale", + "infer_steps": "Inference Steps", + "seed": "Seed", + "gen_audio": "Generate Audio", + "gen_output": "Generated Audio", + "status": "Status", + "lang_select": "Language / 语言", + "refresh": "Refresh", + "output_name": "Output Name (Optional, resume if exists)", + }, + "zh": { + "title": "VoxCPM LoRA WebUI", + "tab_train": "训练 (Training)", + "tab_infer": "推理 (Inference)", + "pretrained_path": "预训练模型路径", + "train_manifest": "训练数据清单 (jsonl)", + "val_manifest": "验证数据清单 (可选)", + "lr": "学习率 (Learning Rate)", + "max_iters": "最大迭代次数", + "batch_size": "批次大小 (Batch Size)", + "lora_rank": "LoRA Rank", + "lora_alpha": "LoRA Alpha", + "save_interval": "保存间隔 (Steps)", + "start_train": "开始训练", + "stop_train": "停止训练", + "train_logs": "训练日志", + "text_to_synth": "合成文本", + "voice_cloning": "### 声音克隆 (可选)", + "ref_audio": "参考音频", + "ref_text": "参考文本 (可选)", + "select_lora": "选择 LoRA 模型", + "cfg_scale": "CFG Scale (引导系数)", + "infer_steps": "推理步数", + "seed": "随机种子 (Seed)", + "gen_audio": "生成音频", + "gen_output": "生成结果", + "status": "状态", + "lang_select": "Language / 语言", + "refresh": "刷新", + "output_name": "输出目录名称 (可选,若存在则继续训练)", + } +} + +# Global variables +current_model: Optional[VoxCPM] = None +asr_model: Optional[AutoModel] = None +training_process: Optional[subprocess.Popen] = None +training_log = "" + +def get_timestamp_str(): + return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + +def 
get_or_load_asr_model(): + global asr_model + if asr_model is None: + print("Loading ASR model (SenseVoiceSmall)...") + device = "cuda:0" if torch.cuda.is_available() else "cpu" + asr_model = AutoModel( + model="iic/SenseVoiceSmall", + disable_update=True, + log_level='ERROR', + device=device, + ) + return asr_model + +def recognize_audio(audio_path): + if not audio_path: + return "" + try: + model = get_or_load_asr_model() + res = model.generate(input=audio_path, language="auto", use_itn=True) + text = res[0]["text"].split('|>')[-1] + return text + except Exception as e: + print(f"ASR Error: {e}") + return "" + +def scan_lora_checkpoints(root_dir="lora", with_info=False): + """ + Scans for LoRA checkpoints in the lora directory. + + Args: + root_dir: Directory to scan for LoRA checkpoints + with_info: If True, returns list of (path, base_model) tuples + + Returns: + List of checkpoint paths, or list of (path, base_model) tuples if with_info=True + """ + checkpoints = [] + if not os.path.exists(root_dir): + os.makedirs(root_dir, exist_ok=True) + + # Look for lora_weights.safetensors recursively + for root, dirs, files in os.walk(root_dir): + if "lora_weights.safetensors" in files: + # Use the relative path from root_dir as the ID + rel_path = os.path.relpath(root, root_dir) + + if with_info: + # Try to read base_model from lora_config.json + base_model = None + lora_config_file = os.path.join(root, "lora_config.json") + if os.path.exists(lora_config_file): + try: + with open(lora_config_file, "r", encoding="utf-8") as f: + lora_info = json.load(f) + base_model = lora_info.get("base_model", "Unknown") + except: + pass + checkpoints.append((rel_path, base_model)) + else: + checkpoints.append(rel_path) + + # Also check for checkpoints in the default location if they exist + default_ckpt = "checkpoints/finetune_lora" + if os.path.exists(os.path.join(root_dir, default_ckpt)): + # This might be covered by the walk, but good to be sure + pass + + return sorted(checkpoints, reverse=True) + +def load_lora_config_from_checkpoint(lora_path): + """Load LoRA config from lora_config.json if available.""" + lora_config_file = os.path.join(lora_path, "lora_config.json") + if os.path.exists(lora_config_file): + try: + with open(lora_config_file, "r", encoding="utf-8") as f: + lora_info = json.load(f) + lora_cfg_dict = lora_info.get("lora_config", {}) + if lora_cfg_dict: + return LoRAConfig(**lora_cfg_dict), lora_info.get("base_model") + except Exception as e: + print(f"Warning: Failed to load lora_config.json: {e}") + return None, None + +def get_default_lora_config(): + """Return default LoRA config for hot-swapping support.""" + return LoRAConfig( + enable_lm=True, + enable_dit=True, + r=32, + alpha=16, + target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"], + target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"] + ) + +def load_model(pretrained_path, lora_path=None): + global current_model + print(f"Loading model from {pretrained_path}...") + + lora_config = None + lora_weights_path = None + + if lora_path: + full_lora_path = os.path.join("lora", lora_path) + if os.path.exists(full_lora_path): + lora_weights_path = full_lora_path + # Try to load LoRA config from lora_config.json + lora_config, _ = load_lora_config_from_checkpoint(full_lora_path) + if lora_config: + print(f"Loaded LoRA config from {full_lora_path}/lora_config.json") + else: + # Fallback to default config for old checkpoints + lora_config = get_default_lora_config() + print("Using default LoRA config (lora_config.json not 
found)") + + # Always init with a default LoRA config to allow hot-swapping later + if lora_config is None: + lora_config = get_default_lora_config() + + current_model = VoxCPM.from_pretrained( + hf_model_id=pretrained_path, + load_denoiser=False, + optimize=False, + lora_config=lora_config, + lora_weights_path=lora_weights_path, + ) + return "Model loaded successfully!" + +def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed): + global current_model + + # 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model + if current_model is None: + base_model_path = default_pretrained_path # 默认路径 + + # 如果选择了 LoRA,尝试从其 config 读取 base_model + if lora_selection and lora_selection != "None": + full_lora_path = os.path.join("lora", lora_selection) + lora_config_file = os.path.join(full_lora_path, "lora_config.json") + + if os.path.exists(lora_config_file): + try: + with open(lora_config_file, "r", encoding="utf-8") as f: + lora_info = json.load(f) + saved_base_model = lora_info.get("base_model") + + if saved_base_model: + # 优先使用保存的 base_model 路径 + if os.path.exists(saved_base_model): + base_model_path = saved_base_model + print(f"Using base model from LoRA config: {base_model_path}") + else: + print(f"Warning: Saved base_model path not found: {saved_base_model}") + print(f"Falling back to default: {base_model_path}") + except Exception as e: + print(f"Warning: Failed to read base_model from LoRA config: {e}") + + # 加载模型 + try: + print(f"Loading base model: {base_model_path}") + status_msg = load_model(base_model_path) + if lora_selection and lora_selection != "None": + print(f"Model loaded for LoRA: {lora_selection}") + except Exception as e: + error_msg = f"Failed to load model from {base_model_path}: {str(e)}" + print(error_msg) + return None, error_msg + + # Handle LoRA hot-swapping + if lora_selection and lora_selection != "None": + full_lora_path = os.path.join("lora", lora_selection) + print(f"Hot-loading LoRA: {full_lora_path}") + try: + current_model.load_lora(full_lora_path) + current_model.set_lora_enabled(True) + except Exception as e: + print(f"Error loading LoRA: {e}") + return None, f"Error loading LoRA: {e}" + else: + print("Disabling LoRA") + current_model.set_lora_enabled(False) + + if seed != -1: + torch.manual_seed(seed) + np.random.seed(seed) + + # 处理 prompt 参数:必须同时为 None 或同时有值 + final_prompt_wav = None + final_prompt_text = None + + if prompt_wav and prompt_wav.strip(): + # 有参考音频 + final_prompt_wav = prompt_wav + + # 如果没有提供参考文本,尝试自动识别 + if not prompt_text or not prompt_text.strip(): + print("参考音频已提供但缺少文本,自动识别中...") + try: + final_prompt_text = recognize_audio(prompt_wav) + if final_prompt_text: + print(f"自动识别文本: {final_prompt_text}") + else: + return None, "错误:无法识别参考音频内容,请手动填写参考文本" + except Exception as e: + return None, f"错误:自动识别参考音频失败 - {str(e)}" + else: + final_prompt_text = prompt_text.strip() + # 如果没有参考音频,两个都设为 None(用于零样本 TTS) + + try: + audio_np = current_model.generate( + text=text, + prompt_wav_path=final_prompt_wav, + prompt_text=final_prompt_text, + cfg_value=cfg_scale, + inference_timesteps=steps, + denoise=False + ) + return (current_model.tts_model.sample_rate, audio_np), "Generation Success" + except Exception as e: + import traceback + traceback.print_exc() + return None, f"Error: {str(e)}" + +def start_training( + pretrained_path, + train_manifest, + val_manifest, + learning_rate, + num_iters, + batch_size, + lora_rank, + lora_alpha, + save_interval, + output_name="", + # Advanced options + grad_accum_steps=1, + num_workers=2, + 
log_interval=10, + valid_interval=1000, + weight_decay=0.01, + warmup_steps=100, + max_steps=None, + sample_rate=44100, + # LoRA advanced + enable_lm=True, + enable_dit=True, + enable_proj=False, + dropout=0.0, + tensorboard_path="", + # Distribution options + hf_model_id="", + distribute=False, +): + global training_process, training_log + + if training_process is not None and training_process.poll() is None: + return "Training is already running!" + + if output_name and output_name.strip(): + timestamp = output_name.strip() + else: + timestamp = get_timestamp_str() + + save_dir = os.path.join("lora", timestamp) + checkpoints_dir = os.path.join(save_dir, "checkpoints") + logs_dir = os.path.join(save_dir, "logs") + + os.makedirs(checkpoints_dir, exist_ok=True) + os.makedirs(logs_dir, exist_ok=True) + + # Create config dictionary + # Resolve max_steps default + resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters) + + config = { + "pretrained_path": pretrained_path, + "train_manifest": train_manifest, + "val_manifest": val_manifest, + "sample_rate": int(sample_rate), + "batch_size": int(batch_size), + "grad_accum_steps": int(grad_accum_steps), + "num_workers": int(num_workers), + "num_iters": int(num_iters), + "log_interval": int(log_interval), + "valid_interval": int(valid_interval), + "save_interval": int(save_interval), + "learning_rate": float(learning_rate), + "weight_decay": float(weight_decay), + "warmup_steps": int(warmup_steps), + "max_steps": resolved_max_steps, + "save_path": checkpoints_dir, + "tensorboard": tensorboard_path if tensorboard_path else logs_dir, + "lambdas": { + "loss/diff": 1.0, + "loss/stop": 1.0 + }, + "lora": { + "enable_lm": bool(enable_lm), + "enable_dit": bool(enable_dit), + "enable_proj": bool(enable_proj), + "r": int(lora_rank), + "alpha": int(lora_alpha), + "dropout": float(dropout), + "target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"], + "target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"] + }, + } + + # Add distribution options if provided + if hf_model_id and hf_model_id.strip(): + config["hf_model_id"] = hf_model_id.strip() + if distribute: + config["distribute"] = True + + config_path = os.path.join(save_dir, "train_config.yaml") + with open(config_path, "w") as f: + yaml.dump(config, f) + + cmd = [ + sys.executable, + "scripts/train_voxcpm_finetune.py", + "--config_path", + config_path + ] + + training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n" + + def run_process(): + global training_process, training_log + training_process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1 + ) + + for line in training_process.stdout: + training_log += line + # Keep log size manageable + if len(training_log) > 100000: + training_log = training_log[-100000:] + + training_process.wait() + training_log += f"\nTraining finished with code {training_process.returncode}" + + threading.Thread(target=run_process, daemon=True).start() + + return f"Training started! Check 'lora/{timestamp}'" + +def get_training_log(): + return training_log + +def stop_training(): + global training_process, training_log + if training_process is not None and training_process.poll() is None: + training_process.terminate() + training_log += "\nTraining terminated by user." + return "Training stopped." + return "No training running." 
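+
+# On-disk layout produced by start_training() above and later picked up by the Inference
+# tab via scan_lora_checkpoints(). Illustrative only; exact step-folder names come from
+# save_checkpoint() in scripts/train_voxcpm_finetune.py:
+#
+#   lora/<output_name or timestamp>/
+#     train_config.yaml         # YAML passed to scripts/train_voxcpm_finetune.py
+#     logs/                     # TensorBoard event files (unless tensorboard_path overrides it)
+#     checkpoints/
+#       step_0002000/           # lora_weights.safetensors, lora_config.json, optimizer.pth, scheduler.pth
+#       latest -> step_0002000  # symlink (or copy) refreshed on every save
+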
+ +# --- GUI Layout --- + +# 自定义CSS样式 +custom_css = """ +/* 整体主题样式 */ +.gradio-container { + background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); + font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; +} + +/* 标题区域样式 - 扁平化设计 */ +.title-section { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + border-radius: 8px; + padding: 15px 25px; + margin-bottom: 15px; + border: none; + box-shadow: 0 2px 8px rgba(0,0,0,0.1); +} + +.title-section h1 { + color: white; + text-shadow: none; + font-weight: 600; + margin: 0; + font-size: 28px; + line-height: 1.2; +} + +.title-section h3 { + color: rgba(255, 255, 255, 0.9); + font-weight: 400; + margin-top: 5px; + font-size: 14px; + line-height: 1.3; +} + +.title-section p { + color: rgba(255, 255, 255, 0.85); + font-size: 13px; + margin: 5px 0 0 0; + line-height: 1.3; +} + +/* 标签页样式 */ +.tabs { + background: white; + border-radius: 15px; + padding: 10px; + box-shadow: 0 4px 20px rgba(0,0,0,0.08); +} + +/* 按钮样式增强 */ +.button-primary { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + border: none; + border-radius: 12px; + padding: 12px 30px; + font-weight: 600; + color: white; + transition: all 0.3s ease; + box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3); +} + +.button-primary:hover { + transform: translateY(-2px); + box-shadow: 0 6px 25px rgba(102, 126, 234, 0.4); +} + +.button-stop { + background: linear-gradient(135deg, #fa709a 0%, #fee140 100%); + border: none; + border-radius: 12px; + padding: 12px 30px; + font-weight: 600; + color: white; + transition: all 0.3s ease; + box-shadow: 0 4px 15px rgba(250, 112, 154, 0.3); +} + +.button-stop:hover { + transform: translateY(-2px); + box-shadow: 0 6px 25px rgba(250, 112, 154, 0.4); +} + +.button-refresh { + background: linear-gradient(135deg, #84fab0 0%, #8fd3f4 100%); + border: none; + border-radius: 10px; + padding: 8px 20px; + font-weight: 500; + color: white; + transition: all 0.3s ease; + box-shadow: 0 2px 10px rgba(132, 250, 176, 0.3); +} + +.button-refresh:hover { + transform: translateY(-1px); + box-shadow: 0 4px 15px rgba(132, 250, 176, 0.4); +} + +/* 表单区域样式 */ +.form-section { + background: white; + border-radius: 20px; + padding: 30px; + margin: 15px 0; + box-shadow: 0 8px 30px rgba(0,0,0,0.08); + border: 1px solid rgba(0,0,0,0.05); +} + +/* 输入框样式 */ +.input-field { + border-radius: 12px; + border: 2px solid #e0e0e0; + padding: 12px 16px; + transition: all 0.3s ease; + background: #fafafa; +} + +.input-field:focus { + border-color: #667eea; + box-shadow: 0 0 0 4px rgba(102, 126, 234, 0.1); + background: white; +} + +/* 滑块样式 */ +.slider { + -webkit-appearance: none; + appearance: none; + width: 100%; + height: 6px; + border-radius: 3px; + background: linear-gradient(90deg, #667eea, #764ba2); + outline: none; + opacity: 0.8; + transition: opacity 0.2s; +} + +.slider:hover { + opacity: 1; +} + +.slider::-webkit-slider-thumb { + -webkit-appearance: none; + appearance: none; + width: 18px; + height: 18px; + border-radius: 50%; + background: white; + cursor: pointer; + border: 3px solid #667eea; + box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3); +} + +.slider::-moz-range-thumb { + width: 18px; + height: 18px; + border-radius: 50%; + background: white; + cursor: pointer; + border: 3px solid #667eea; + box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3); +} + +/* 折叠面板样式 */ +.accordion { + border-radius: 12px; + border: 2px solid #e0e0e0; + overflow: hidden; + background: white; +} + +.accordion-header { + background: linear-gradient(135deg, #f5f7fa 0%, #e3e7ed 
100%); + padding: 15px 20px; + font-weight: 600; + color: #333; +} + +/* 状态显示样式 */ +.status-success { + background: linear-gradient(135deg, #84fab0 0%, #8fd3f4 100%); + color: white; + padding: 12px 20px; + border-radius: 12px; + font-weight: 500; + box-shadow: 0 4px 15px rgba(132, 250, 176, 0.3); +} + +.status-error { + background: linear-gradient(135deg, #fa709a 0%, #fee140 100%); + color: white; + padding: 12px 20px; + border-radius: 12px; + font-weight: 500; + box-shadow: 0 4px 15px rgba(250, 112, 154, 0.3); +} + +/* 语言切换按钮样式 - 扁平化 */ +.lang-selector { + background: rgba(255, 255, 255, 0.25); + backdrop-filter: blur(10px); + border-radius: 8px; + padding: 8px 12px; + border: 1px solid rgba(255, 255, 255, 0.4); +} + +.lang-selector label.gr-box { + color: white !important; + font-weight: 600; + margin-bottom: 8px !important; +} + +/* 单选按钮组样式 */ +.lang-selector fieldset, +.lang-selector .gr-form { + gap: 10px !important; + display: flex !important; +} + +/* 单选按钮容器 - 扁平化 (未选中状态 - 较浅的深色) */ +.lang-selector label.gr-radio-label { + background: linear-gradient(135deg, rgba(102, 126, 234, 0.6), rgba(118, 75, 162, 0.6)) !important; + border: 1px solid rgba(255, 255, 255, 0.5) !important; + border-radius: 6px !important; + padding: 8px 18px !important; + color: white !important; + font-weight: 500 !important; + transition: all 0.2s ease !important; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important; + cursor: pointer !important; + margin: 0 4px !important; +} + +/* 选中的单选按钮 - 扁平化 (更深的深色背景) */ +.lang-selector input[type="radio"]:checked + label, +.lang-selector label.gr-radio-label:has(input:checked) { + background: linear-gradient(135deg, #5568d3, #6b4c9a) !important; + color: white !important; + border: 1px solid rgba(255, 255, 255, 0.6) !important; + font-weight: 600 !important; + box-shadow: 0 3px 12px rgba(0, 0, 0, 0.2) !important; + transform: none !important; +} + +/* 未选中的单选按钮悬停效果 - 扁平化 */ +.lang-selector label.gr-radio-label:hover { + background: linear-gradient(135deg, rgba(102, 126, 234, 0.75), rgba(118, 75, 162, 0.75)) !important; + border-color: rgba(255, 255, 255, 0.7) !important; + transform: none !important; + box-shadow: 0 2px 10px rgba(0, 0, 0, 0.15) !important; +} + +/* 隐藏原始的单选按钮圆点 */ +.lang-selector input[type="radio"] { + opacity: 0; + position: absolute; +} + +/* Gradio Radio 特定样式 - 扁平化 */ +.lang-selector .wrap { + gap: 8px !important; +} + +.lang-selector .wrap > label { + background: linear-gradient(135deg, rgba(102, 126, 234, 0.6), rgba(118, 75, 162, 0.6)) !important; + border: 1px solid rgba(255, 255, 255, 0.5) !important; + border-radius: 6px !important; + padding: 8px 18px !important; + color: white !important; + font-weight: 500 !important; + transition: all 0.2s ease !important; + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important; +} + +.lang-selector .wrap > label.selected { + background: linear-gradient(135deg, #5568d3, #6b4c9a) !important; + color: white !important; + border: 1px solid rgba(255, 255, 255, 0.6) !important; + font-weight: 600 !important; + box-shadow: 0 3px 12px rgba(0, 0, 0, 0.2) !important; +} + +/* 标签样式优化 */ +label { + color: #333; + font-weight: 500; + margin-bottom: 8px; +} + +/* Markdown 标题样式 */ +.markdown-text h4 { + color: #667eea; + font-weight: 600; + margin-top: 15px; + margin-bottom: 10px; +} + +/* 参数组件间距优化 */ +.form-section > div { + margin-bottom: 15px; +} + +/* Slider 组件样式优化 */ +.gr-slider { + padding: 10px 0; +} + +/* Number 输入框优化 */ +.gr-number { + max-width: 100%; +} + +/* 按钮容器优化 */ +.gr-button { + min-height: 45px; + 
font-size: 16px; +} + +/* 三栏布局优化 */ +#component-0 .gr-row { + gap: 20px; +} + +/* 生成按钮特殊样式 */ +.button-primary.gr-button-lg { + min-height: 55px; + font-size: 18px; + font-weight: 700; + margin-top: 20px; + margin-bottom: 10px; +} + +/* 刷新按钮小尺寸 */ +.button-refresh.gr-button-sm { + min-height: 38px; + font-size: 14px; + margin-top: 5px; + margin-bottom: 15px; +} + +/* 信息提示文字样式 */ +.gr-info { + font-size: 13px; + color: #666; + margin-top: 5px; +} + +/* 区域标题样式优化 */ +.form-section h4 { + color: #667eea; + font-weight: 600; + margin-top: 0; + margin-bottom: 15px; + padding-bottom: 10px; + border-bottom: 2px solid #f0f0f0; +} + +.form-section strong { + color: #667eea; + font-size: 15px; + display: block; + margin: 15px 0 10px 0; +} +""" + +with gr.Blocks( + title="VoxCPM LoRA WebUI", + theme=gr.themes.Soft(), + css=custom_css +) as app: + + # State for language + lang_state = gr.State("zh") # Default to Chinese + + # 标题区域 + with gr.Row(elem_classes="title-section"): + with gr.Column(scale=3): + title_md = gr.Markdown(""" + # 🎵 VoxCPM LoRA WebUI + ### 强大的语音合成和 LoRA 微调工具 + + 支持语音克隆、LoRA 模型训练和推理的完整解决方案 + """) + with gr.Column(scale=1): + lang_btn = gr.Radio( + choices=["en", "zh"], + value="zh", + label="🌐 Language / 语言", + elem_classes="lang-selector" + ) + + with gr.Tabs(elem_classes="tabs") as tabs: + # === Training Tab === + with gr.Tab("🚀 训练 (Training)") as tab_train: + gr.Markdown(""" + ### 🎯 模型训练设置 + 配置你的 LoRA 微调训练参数 + """) + + with gr.Row(): + with gr.Column(scale=2, elem_classes="form-section"): + gr.Markdown("#### 📁 基础配置") + + train_pretrained_path = gr.Textbox( + label="📂 预训练模型路径", + value=default_pretrained_path, + elem_classes="input-field" + ) + train_manifest = gr.Textbox( + label="📋 训练数据清单 (jsonl)", + value="examples/train_data_example.jsonl", + elem_classes="input-field" + ) + val_manifest = gr.Textbox( + label="📊 验证数据清单 (可选)", + value="", + elem_classes="input-field" + ) + + gr.Markdown("#### ⚙️ 训练参数") + + with gr.Row(): + lr = gr.Number( + label="📈 学习率 (Learning Rate)", + value=1e-4, + elem_classes="input-field" + ) + num_iters = gr.Number( + label="🔄 最大迭代次数", + value=2000, + precision=0, + elem_classes="input-field" + ) + batch_size = gr.Number( + label="📦 批次大小 (Batch Size)", + value=1, + precision=0, + elem_classes="input-field" + ) + + with gr.Row(): + lora_rank = gr.Number( + label="🎯 LoRA Rank", + value=32, + precision=0, + elem_classes="input-field" + ) + lora_alpha = gr.Number( + label="⚖️ LoRA Alpha", + value=16, + precision=0, + elem_classes="input-field" + ) + save_interval = gr.Number( + label="💾 保存间隔 (Steps)", + value=1000, + precision=0, + elem_classes="input-field" + ) + + output_name = gr.Textbox( + label="📁 输出目录名称 (可选,若存在则继续训练)", + value="", + elem_classes="input-field" + ) + + with gr.Row(): + start_btn = gr.Button( + "▶️ 开始训练", + variant="primary", + elem_classes="button-primary" + ) + stop_btn = gr.Button( + "⏹️ 停止训练", + variant="stop", + elem_classes="button-stop" + ) + + with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"): + with gr.Row(): + grad_accum_steps = gr.Number(label="梯度累积 (grad_accum_steps)", value=1, precision=0) + num_workers = gr.Number(label="数据加载线程 (num_workers)", value=2, precision=0) + log_interval = gr.Number(label="日志间隔 (log_interval)", value=10, precision=0) + with gr.Row(): + valid_interval = gr.Number(label="验证间隔 (valid_interval)", value=1000, precision=0) + weight_decay = gr.Number(label="权重衰减 (weight_decay)", value=0.01) + warmup_steps = gr.Number(label="warmup_steps", value=100, precision=0) + with gr.Row(): 
+ max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0) + sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0) + tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="") + with gr.Row(): + enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True) + enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True) + enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False) + dropout = gr.Number(label="LoRA Dropout", value=0.0) + + gr.Markdown("#### 分发选项 (Distribution)") + with gr.Row(): + hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5") + distribute = gr.Checkbox(label="分发模式 (distribute)", value=False) + + with gr.Column(scale=2, elem_classes="form-section"): + gr.Markdown("#### 📊 训练日志") + logs_out = gr.TextArea( + label="", + lines=20, + max_lines=30, + interactive=False, + elem_classes="input-field", + show_label=False + ) + + start_btn.click( + start_training, + inputs=[ + train_pretrained_path, train_manifest, val_manifest, + lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval, + output_name, + # advanced + grad_accum_steps, num_workers, log_interval, valid_interval, + weight_decay, warmup_steps, max_steps, sample_rate, + enable_lm, enable_dit, enable_proj, dropout, tensorboard_path, + # distribution + hf_model_id, distribute + ], + outputs=[logs_out] # Initial message + ) + stop_btn.click(stop_training, outputs=[logs_out]) + + # Log refresher + timer = gr.Timer(1) + timer.tick(get_training_log, outputs=logs_out) + + # === Inference Tab === + with gr.Tab("🎵 推理 (Inference)") as tab_infer: + gr.Markdown(""" + ### 🎤 语音合成 + 使用训练好的 LoRA 模型生成语音,支持 LoRA 微调和声音克隆 + """) + + with gr.Row(): + # 左栏:输入配置 (35%) + with gr.Column(scale=35, elem_classes="form-section"): + gr.Markdown("#### 📝 输入配置") + + infer_text = gr.TextArea( + label="💬 合成文本", + value="Hello, this is a test of the VoxCPM LoRA model.", + elem_classes="input-field", + lines=4, + placeholder="输入要合成的文本内容..." 
+ ) + + gr.Markdown("**🎭 声音克隆(可选)**") + + prompt_wav = gr.Audio( + label="🎵 参考音频", + type="filepath", + elem_classes="input-field" + ) + + prompt_text = gr.Textbox( + label="📝 参考文本(可选)", + elem_classes="input-field", + placeholder="如不填写,将自动识别参考音频内容" + ) + + # 中栏:模型选择和参数配置 (35%) + with gr.Column(scale=35, elem_classes="form-section"): + gr.Markdown("#### 🤖 模型选择") + + lora_select = gr.Dropdown( + label="🎯 LoRA 模型", + choices=["None"] + scan_lora_checkpoints(), + value="None", + interactive=True, + elem_classes="input-field", + info="选择训练好的 LoRA 模型,或选择 None 使用基础模型" + ) + + refresh_lora_btn = gr.Button( + "🔄 刷新模型列表", + elem_classes="button-refresh", + size="sm" + ) + + gr.Markdown("#### ⚙️ 生成参数") + + cfg_scale = gr.Slider( + label="🎛️ CFG Scale", + minimum=1.0, + maximum=5.0, + value=2.0, + step=0.1, + info="引导系数,值越大越贴近提示" + ) + + steps = gr.Slider( + label="🔢 推理步数", + minimum=1, + maximum=50, + value=10, + step=1, + info="生成质量与步数成正比,但耗时更长" + ) + + seed = gr.Number( + label="🎲 随机种子", + value=-1, + precision=0, + elem_classes="input-field", + info="-1 为随机,固定值可复现结果" + ) + + generate_btn = gr.Button( + "🎵 生成音频", + variant="primary", + elem_classes="button-primary", + size="lg" + ) + + # 右栏:生成结果 (30%) + with gr.Column(scale=30, elem_classes="form-section"): + gr.Markdown("#### 🎧 生成结果") + + audio_out = gr.Audio( + label="", + elem_classes="input-field", + show_label=False + ) + + gr.Markdown("#### 📋 状态信息") + + status_out = gr.Textbox( + label="", + interactive=False, + elem_classes="input-field", + show_label=False, + lines=3, + placeholder="等待生成..." + ) + + def refresh_loras(): + # 获取 LoRA checkpoints 及其 base model 信息 + checkpoints_with_info = scan_lora_checkpoints(with_info=True) + choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info] + + # 输出调试信息 + print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点") + for ckpt_path, base_model in checkpoints_with_info: + if base_model: + print(f" - {ckpt_path} (Base Model: {base_model})") + else: + print(f" - {ckpt_path}") + + return gr.update(choices=choices, value="None") + + refresh_lora_btn.click(refresh_loras, outputs=[lora_select]) + + # Auto-recognize audio when uploaded + prompt_wav.change( + fn=recognize_audio, + inputs=[prompt_wav], + outputs=[prompt_text] + ) + + generate_btn.click( + run_inference, + inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed], + outputs=[audio_out, status_out] + ) + + # --- Language Switching Logic --- + def change_language(lang): + d = LANG_DICT[lang] + # Labels for advanced options + if lang == "zh": + adv = { + 'grad_accum_steps': "梯度累积 (grad_accum_steps)", + 'num_workers': "数据加载线程 (num_workers)", + 'log_interval': "日志间隔 (log_interval)", + 'valid_interval': "验证间隔 (valid_interval)", + 'weight_decay': "权重衰减 (weight_decay)", + 'warmup_steps': "warmup_steps", + 'max_steps': "最大步数 (max_steps)", + 'sample_rate': "采样率 (sample_rate)", + 'enable_lm': "启用 LoRA LM (enable_lm)", + 'enable_dit': "启用 LoRA DIT (enable_dit)", + 'enable_proj': "启用投影 (enable_proj)", + 'dropout': "LoRA Dropout", + 'tensorboard_path': "Tensorboard 路径 (可选)", + 'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", + 'distribute': "分发模式 (distribute)", + } + else: + adv = { + 'grad_accum_steps': "Grad Accum Steps", + 'num_workers': "Num Workers", + 'log_interval': "Log Interval", + 'valid_interval': "Valid Interval", + 'weight_decay': "Weight Decay", + 'warmup_steps': "Warmup Steps", + 'max_steps': "Max Steps", + 'sample_rate': "Sample Rate", + 'enable_lm': "Enable LoRA LM", + 'enable_dit': "Enable LoRA 
DIT", + 'enable_proj': "Enable Projection", + 'dropout': "LoRA Dropout", + 'tensorboard_path': "Tensorboard Path (Optional)", + 'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", + 'distribute': "Distribute Mode", + } + + return ( + gr.update(value=f"# {d['title']}"), + gr.update(label=d['tab_train']), + gr.update(label=d['tab_infer']), + gr.update(label=d['pretrained_path']), + gr.update(label=d['train_manifest']), + gr.update(label=d['val_manifest']), + gr.update(label=d['lr']), + gr.update(label=d['max_iters']), + gr.update(label=d['batch_size']), + gr.update(label=d['lora_rank']), + gr.update(label=d['lora_alpha']), + gr.update(label=d['save_interval']), + gr.update(label=d['output_name']), + gr.update(value=d['start_train']), + gr.update(value=d['stop_train']), + gr.update(label=d['train_logs']), + # Advanced options (must match outputs order) + gr.update(label=adv['grad_accum_steps']), + gr.update(label=adv['num_workers']), + gr.update(label=adv['log_interval']), + gr.update(label=adv['valid_interval']), + gr.update(label=adv['weight_decay']), + gr.update(label=adv['warmup_steps']), + gr.update(label=adv['max_steps']), + gr.update(label=adv['sample_rate']), + gr.update(label=adv['enable_lm']), + gr.update(label=adv['enable_dit']), + gr.update(label=adv['enable_proj']), + gr.update(label=adv['dropout']), + gr.update(label=adv['tensorboard_path']), + # Distribution options + gr.update(label=adv['hf_model_id']), + gr.update(label=adv['distribute']), + # Inference section + gr.update(label=d['text_to_synth']), + gr.update(label=d['ref_audio']), + gr.update(label=d['ref_text']), + gr.update(label=d['select_lora']), + gr.update(value=d['refresh']), + gr.update(label=d['cfg_scale']), + gr.update(label=d['infer_steps']), + gr.update(label=d['seed']), + gr.update(value=d['gen_audio']), + gr.update(label=d['gen_output']), + gr.update(label=d['status']), + ) + + lang_btn.change( + change_language, + inputs=[lang_btn], + outputs=[ + title_md, tab_train, tab_infer, + train_pretrained_path, train_manifest, val_manifest, + lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval, + output_name, + start_btn, stop_btn, logs_out, + # advanced outputs + grad_accum_steps, num_workers, log_interval, valid_interval, + weight_decay, warmup_steps, max_steps, sample_rate, + enable_lm, enable_dit, enable_proj, dropout, tensorboard_path, + # distribution outputs + hf_model_id, distribute, + infer_text, prompt_wav, prompt_text, + lora_select, refresh_lora_btn, cfg_scale, steps, seed, + generate_btn, audio_out, status_out + ] + ) + +if __name__ == "__main__": + # Ensure lora directory exists + os.makedirs("lora", exist_ok=True) + app.queue().launch(server_name="0.0.0.0", server_port=7860) \ No newline at end of file diff --git a/scripts/test_voxcpm_lora_infer.py b/scripts/test_voxcpm_lora_infer.py index 84e0c6d..36dc414 100644 --- a/scripts/test_voxcpm_lora_infer.py +++ b/scripts/test_voxcpm_lora_infer.py @@ -5,7 +5,6 @@ LoRA inference test script. Usage: python scripts/test_voxcpm_lora_infer.py \ - --config_path conf/voxcpm/voxcpm_finetune_test.yaml \ --lora_ckpt checkpoints/step_0002000 \ --text "Hello, this is LoRA finetuned result." \ --output lora_test.wav @@ -13,37 +12,39 @@ Usage: With voice cloning: python scripts/test_voxcpm_lora_infer.py \ - --config_path conf/voxcpm/voxcpm_finetune_test.yaml \ --lora_ckpt checkpoints/step_0002000 \ --text "This is voice cloning result." 
\ --prompt_audio path/to/ref.wav \ --prompt_text "Reference audio transcript" \ --output lora_clone.wav + +Note: The script reads base_model path and lora_config from lora_config.json + in the checkpoint directory (saved automatically during training). """ import argparse +import json from pathlib import Path import soundfile as sf from voxcpm.core import VoxCPM from voxcpm.model.voxcpm import LoRAConfig -from voxcpm.training.config import load_yaml_config def parse_args(): parser = argparse.ArgumentParser("VoxCPM LoRA inference test") - parser.add_argument( - "--config_path", - type=str, - required=True, - help="Training YAML config path (contains pretrained_path and lora config)", - ) parser.add_argument( "--lora_ckpt", type=str, required=True, - help="LoRA checkpoint directory (contains lora_weights.ckpt with lora_A/lora_B only)", + help="LoRA checkpoint directory (contains lora_weights.safetensors and lora_config.json)", + ) + parser.add_argument( + "--base_model", + type=str, + default="", + help="Optional: override base model path (default: read from lora_config.json)", ) parser.add_argument( "--text", @@ -98,26 +99,44 @@ def parse_args(): def main(): args = parse_args() - # 1. Load YAML config - cfg = load_yaml_config(args.config_path) - pretrained_path = cfg["pretrained_path"] - lora_cfg_dict = cfg.get("lora", {}) or {} - lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None - - # 2. Check LoRA checkpoint - ckpt_dir = args.lora_ckpt - if not Path(ckpt_dir).exists(): + # 1. Check LoRA checkpoint directory + ckpt_dir = Path(args.lora_ckpt) + if not ckpt_dir.exists(): raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}") + # 2. Load lora_config.json from checkpoint + lora_config_path = ckpt_dir / "lora_config.json" + if not lora_config_path.exists(): + raise FileNotFoundError( + f"lora_config.json not found in {ckpt_dir}. " + "Make sure the checkpoint was saved with the updated training script." + ) + + with open(lora_config_path, "r", encoding="utf-8") as f: + lora_info = json.load(f) + + # Get base model path (command line arg overrides config) + pretrained_path = args.base_model if args.base_model else lora_info.get("base_model") + if not pretrained_path: + raise ValueError("base_model not found in lora_config.json and --base_model not provided") + + # Get LoRA config + lora_cfg_dict = lora_info.get("lora_config", {}) + lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None + + print(f"Loaded config from: {lora_config_path}") + print(f" Base model: {pretrained_path}") + print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None") + # 3. Load model with LoRA (no denoiser) - print(f"[1/2] Loading model with LoRA: {pretrained_path}") + print(f"\n[1/2] Loading model with LoRA: {pretrained_path}") print(f" LoRA weights: {ckpt_dir}") model = VoxCPM.from_pretrained( hf_model_id=pretrained_path, load_denoiser=False, optimize=True, lora_config=lora_cfg, - lora_weights_path=ckpt_dir, + lora_weights_path=str(ckpt_dir), ) # 4. 
Synthesize audio @@ -197,7 +216,7 @@ def main(): # === Test 5: Hot-reload LoRA (load_lora) === print(f"\n [Test 5] Hot-reload LoRA (load_lora)...") - loaded, skipped = model.load_lora(str(ckpt_dir)) + loaded, skipped = model.load_lora(ckpt_dir) print(f" Reloaded {len(loaded)} parameters") audio_np = model.generate( text=args.text, diff --git a/scripts/train_voxcpm_finetune.py b/scripts/train_voxcpm_finetune.py index c89ad5d..e17e46f 100644 --- a/scripts/train_voxcpm_finetune.py +++ b/scripts/train_voxcpm_finetune.py @@ -14,6 +14,8 @@ import torch from tensorboardX import SummaryWriter from torch.optim import AdamW from transformers import get_cosine_schedule_with_warmup +import signal +import os try: from safetensors.torch import save_file @@ -56,8 +58,16 @@ def train( lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0}, lora: dict = None, config_path: str = "", + # Distribution options (for LoRA checkpoints) + hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5") + distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path ): _ = config_path + + # Validate distribution options + if lora is not None and distribute and not hf_model_id: + raise ValueError("hf_model_id is required when distribute=True") + accelerator = Accelerator(amp=True) save_dir = Path(save_path) @@ -171,6 +181,39 @@ def train( num_training_steps=total_training_steps, ) + # Try to load checkpoint and resume training + start_step = 0 + if accelerator.rank == 0: + start_step = load_checkpoint(model, optimizer, scheduler, save_dir) + # Broadcast start_step to all processes + if hasattr(accelerator, 'all_reduce'): + start_step_tensor = torch.tensor(start_step, device=accelerator.device) + accelerator.all_reduce(start_step_tensor) + start_step = int(start_step_tensor.item()) + + if start_step > 0 and accelerator.rank == 0: + tracker.print(f"Resuming training from step {start_step}") + + # Resume tracker for signal handler to read current step + resume = {"step": start_step} + + # Register signal handler to save checkpoint on termination (SIGTERM/SIGINT) + def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume): + try: + cur_step = int(_resume.get("step", start_step)) + except Exception: + cur_step = start_step + print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...") + try: + save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist) + print("Checkpoint saved. 
Exiting.") + except Exception as e: + print(f"Error saving checkpoint on signal: {e}") + os._exit(0) + + signal.signal(signal.SIGTERM, _signal_handler) + signal.signal(signal.SIGINT, _signal_handler) + # Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch() grad_accum_steps = max(int(grad_accum_steps), 1) data_epoch = 0 @@ -191,7 +234,9 @@ def train( return next(train_iter) with tracker.live(): - for step in range(num_iters): + for step in range(start_step, num_iters): + # update resume step so signal handler can save current progress + resume["step"] = step tracker.step = step optimizer.zero_grad(set_to_none=True) @@ -255,10 +300,10 @@ def train( validate(model, val_loader, batch_processor, accelerator, tracker, lambdas) if step % save_interval == 0 and accelerator.rank == 0: - save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path) + save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute) if accelerator.rank == 0: - save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path) + save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path, hf_model_id, distribute) if writer: writer.close() @@ -301,7 +346,77 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas): model.train() -def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None): +def load_checkpoint(model, optimizer, scheduler, save_dir: Path): + """ + Load the latest checkpoint if it exists. + Returns the step number to resume from, or 0 if no checkpoint found. + """ + latest_folder = save_dir / "latest" + if not latest_folder.exists(): + return 0 + + unwrapped = model.module if hasattr(model, "module") else model + lora_cfg = unwrapped.lora_config + + # Load model weights + if lora_cfg is not None: + # LoRA: load lora_weights + lora_weights_path = latest_folder / "lora_weights.safetensors" + if not lora_weights_path.exists(): + lora_weights_path = latest_folder / "lora_weights.ckpt" + + if lora_weights_path.exists(): + if lora_weights_path.suffix == ".safetensors": + from safetensors.torch import load_file + state_dict = load_file(str(lora_weights_path)) + else: + ckpt = torch.load(lora_weights_path, map_location="cpu") + state_dict = ckpt.get("state_dict", ckpt) + + # Load only lora weights + unwrapped.load_state_dict(state_dict, strict=False) + print(f"Loaded LoRA weights from {lora_weights_path}") + else: + # Full finetune: load model.safetensors or pytorch_model.bin + model_path = latest_folder / "model.safetensors" + if not model_path.exists(): + model_path = latest_folder / "pytorch_model.bin" + + if model_path.exists(): + if model_path.suffix == ".safetensors": + from safetensors.torch import load_file + state_dict = load_file(str(model_path)) + else: + ckpt = torch.load(model_path, map_location="cpu") + state_dict = ckpt.get("state_dict", ckpt) + + unwrapped.load_state_dict(state_dict, strict=False) + print(f"Loaded model weights from {model_path}") + + # Load optimizer state + optimizer_path = latest_folder / "optimizer.pth" + if optimizer_path.exists(): + optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu")) + print(f"Loaded optimizer state from {optimizer_path}") + + # Load scheduler state + scheduler_path = latest_folder / "scheduler.pth" + if scheduler_path.exists(): + scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu")) + print(f"Loaded scheduler state 
from {scheduler_path}") + + # Try to infer step from checkpoint folders + step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")] + if step_folders: + steps = [int(d.name.split("_")[1]) for d in step_folders] + resume_step = max(steps) + print(f"Resuming from step {resume_step}") + return resume_step + + return 0 + + +def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False): """ Save checkpoint with different strategies for full finetune vs LoRA: - Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable) @@ -325,6 +440,17 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret save_file(state_dict, folder / "lora_weights.safetensors") else: torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt") + + # Save LoRA config and base model path to a separate JSON file + # If distribute=True, save hf_model_id; otherwise save local pretrained_path + import json + base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None) + lora_info = { + "base_model": base_model_to_save, + "lora_config": lora_cfg.model_dump() if hasattr(lora_cfg, "model_dump") else vars(lora_cfg), + } + with open(folder / "lora_config.json", "w", encoding="utf-8") as f: + json.dump(lora_info, f, indent=2, ensure_ascii=False) else: # Full finetune: save non-vae weights to model.safetensors state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")} @@ -345,6 +471,29 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret torch.save(optimizer.state_dict(), folder / "optimizer.pth") torch.save(scheduler.state_dict(), folder / "scheduler.pth") + # Update (or create) a `latest` symlink pointing to the most recent checkpoint folder + latest_link = save_dir / "latest" + try: + if latest_link.exists() or latest_link.is_symlink(): + # remove existing link or directory + if latest_link.is_dir() and not latest_link.is_symlink(): + shutil.rmtree(latest_link) + else: + latest_link.unlink() + # Create a symlink pointing to the new folder + os.symlink(str(folder), str(latest_link)) + except Exception: + # If symlink creation fails (e.g., on Windows or permission issues), fall back to copying + try: + if latest_link.exists(): + if latest_link.is_dir(): + shutil.rmtree(latest_link) + else: + latest_link.unlink() + shutil.copytree(folder, latest_link) + except Exception: + print(f"Warning: failed to update latest checkpoint link at {latest_link}") + if __name__ == "__main__": from voxcpm.training.config import load_yaml_config @@ -358,5 +507,4 @@ if __name__ == "__main__": else: # Otherwise use command line args (parsed by argbind) with argbind.scope(args): - train() - + train() \ No newline at end of file diff --git a/src/voxcpm/core.py b/src/voxcpm/core.py index a2c4290..f4ccc5b 100644 --- a/src/voxcpm/core.py +++ b/src/voxcpm/core.py @@ -55,11 +55,12 @@ class VoxCPM: self.denoiser = ZipEnhancer(zipenhancer_model_path) else: self.denoiser = None - print("Warm up VoxCPMModel...") - self.tts_model.generate( - target_text="Hello, this is the first test sentence.", - max_len=10, - ) + if optimize: + print("Warm up VoxCPMModel...") + self.tts_model.generate( + target_text="Hello, this is the first test sentence.", + max_len=10, + ) @classmethod def from_pretrained(cls,