"""VoxCPM LoRA WebUI (lora_ft_webui.py): a Gradio interface for LoRA fine-tuning and inference with VoxCPM."""
import os
import sys
import json
import yaml
import datetime
import subprocess
import threading
import gradio as gr
import torch
from pathlib import Path
from typing import Optional
# Add src to sys.path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root / "src"))
# Default pretrained model path relative to this repo
default_pretrained_path = str(project_root / "models" / "openbmb__VoxCPM1.5")
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
import numpy as np
from funasr import AutoModel
# --- Localization ---
LANG_DICT = {
"en": {
"title": "VoxCPM LoRA WebUI",
"tab_train": "Training",
"tab_infer": "Inference",
"pretrained_path": "Pretrained Model Path",
"train_manifest": "Train Manifest (jsonl)",
"val_manifest": "Validation Manifest (Optional)",
"lr": "Learning Rate",
"max_iters": "Max Iterations",
"batch_size": "Batch Size",
"lora_rank": "LoRA Rank",
"lora_alpha": "LoRA Alpha",
"save_interval": "Save Interval",
"start_train": "Start Training",
"stop_train": "Stop Training",
"train_logs": "Training Logs",
"text_to_synth": "Text to Synthesize",
"voice_cloning": "### Voice Cloning (Optional)",
"ref_audio": "Reference Audio",
"ref_text": "Reference Text (Optional)",
"select_lora": "Select LoRA Checkpoint",
"cfg_scale": "CFG Scale",
"infer_steps": "Inference Steps",
"seed": "Seed",
"gen_audio": "Generate Audio",
"gen_output": "Generated Audio",
"status": "Status",
"lang_select": "Language / 语言",
"refresh": "Refresh",
"output_name": "Output Name (Optional, resume if exists)",
},
"zh": {
"title": "VoxCPM LoRA WebUI",
"tab_train": "训练 (Training)",
"tab_infer": "推理 (Inference)",
"pretrained_path": "预训练模型路径",
"train_manifest": "训练数据清单 (jsonl)",
"val_manifest": "验证数据清单 (可选)",
"lr": "学习率 (Learning Rate)",
"max_iters": "最大迭代次数",
"batch_size": "批次大小 (Batch Size)",
"lora_rank": "LoRA Rank",
"lora_alpha": "LoRA Alpha",
"save_interval": "保存间隔 (Steps)",
"start_train": "开始训练",
"stop_train": "停止训练",
"train_logs": "训练日志",
"text_to_synth": "合成文本",
"voice_cloning": "### 声音克隆 (可选)",
"ref_audio": "参考音频",
"ref_text": "参考文本 (可选)",
"select_lora": "选择 LoRA 模型",
"cfg_scale": "CFG Scale (引导系数)",
"infer_steps": "推理步数",
"seed": "随机种子 (Seed)",
"gen_audio": "生成音频",
"gen_output": "生成结果",
"status": "状态",
"lang_select": "Language / 语言",
"refresh": "刷新",
"output_name": "输出目录名称 (可选,若存在则继续训练)",
}
}
# Global variables
current_model: Optional[VoxCPM] = None
asr_model: Optional[AutoModel] = None
training_process: Optional[subprocess.Popen] = None
training_log = ""
def get_timestamp_str():
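"""Return a timestamp string (YYYYMMDD_HHMMSS), used as the default LoRA output directory name."""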
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def get_or_load_asr_model():
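"""Lazily load and cache the SenseVoiceSmall ASR model, preferring GPU when available."""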
global asr_model
if asr_model is None:
print("Loading ASR model (SenseVoiceSmall)...")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
asr_model = AutoModel(
model="iic/SenseVoiceSmall",
disable_update=True,
log_level='ERROR',
device=device,
)
return asr_model
def recognize_audio(audio_path):
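"""Transcribe an audio file with SenseVoice; returns an empty string on failure or empty input."""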
if not audio_path:
return ""
try:
model = get_or_load_asr_model()
res = model.generate(input=audio_path, language="auto", use_itn=True)
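# SenseVoice prefixes its output with tags such as <|zh|><|NEUTRAL|><|Speech|>; keep only the text after the last tag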
text = res[0]["text"].split('|>')[-1]
return text
except Exception as e:
print(f"ASR Error: {e}")
return ""
def scan_lora_checkpoints(root_dir="lora", with_info=False):
"""
Scans for LoRA checkpoints in the lora directory.
Args:
root_dir: Directory to scan for LoRA checkpoints
with_info: If True, returns list of (path, base_model) tuples
Returns:
List of checkpoint paths, or list of (path, base_model) tuples if with_info=True
"""
checkpoints = []
if not os.path.exists(root_dir):
os.makedirs(root_dir, exist_ok=True)
# Look for lora_weights.safetensors recursively
for root, dirs, files in os.walk(root_dir):
if "lora_weights.safetensors" in files:
# Use the relative path from root_dir as the ID
rel_path = os.path.relpath(root, root_dir)
if with_info:
# Try to read base_model from lora_config.json
base_model = None
lora_config_file = os.path.join(root, "lora_config.json")
if os.path.exists(lora_config_file):
try:
with open(lora_config_file, "r", encoding="utf-8") as f:
lora_info = json.load(f)
base_model = lora_info.get("base_model", "Unknown")
except Exception:
pass  # Malformed lora_config.json: fall back to listing the path without base model info
checkpoints.append((rel_path, base_model))
else:
checkpoints.append(rel_path)
# Checkpoints under the default location (checkpoints/finetune_lora) are already covered by the walk above.
return sorted(checkpoints, reverse=True)
def load_lora_config_from_checkpoint(lora_path):
"""Load LoRA config from lora_config.json if available."""
lora_config_file = os.path.join(lora_path, "lora_config.json")
if os.path.exists(lora_config_file):
try:
with open(lora_config_file, "r", encoding="utf-8") as f:
lora_info = json.load(f)
lora_cfg_dict = lora_info.get("lora_config", {})
if lora_cfg_dict:
return LoRAConfig(**lora_cfg_dict), lora_info.get("base_model")
except Exception as e:
print(f"Warning: Failed to load lora_config.json: {e}")
return None, None
def get_default_lora_config():
"""Return default LoRA config for hot-swapping support."""
return LoRAConfig(
enable_lm=True,
enable_dit=True,
r=32,
alpha=16,
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"]
)
def load_model(pretrained_path, lora_path=None):
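"""Load the base VoxCPM model, optionally applying LoRA weights from lora/<lora_path>. A LoRA config is always attached so adapters can be hot-swapped later. Returns a status string."""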
global current_model
print(f"Loading model from {pretrained_path}...")
lora_config = None
lora_weights_path = None
if lora_path:
full_lora_path = os.path.join("lora", lora_path)
if os.path.exists(full_lora_path):
lora_weights_path = full_lora_path
# Try to load LoRA config from lora_config.json
lora_config, _ = load_lora_config_from_checkpoint(full_lora_path)
if lora_config:
print(f"Loaded LoRA config from {full_lora_path}/lora_config.json")
else:
# Fallback to default config for old checkpoints
lora_config = get_default_lora_config()
print("Using default LoRA config (lora_config.json not found)")
# Always init with a default LoRA config to allow hot-swapping later
if lora_config is None:
lora_config = get_default_lora_config()
current_model = VoxCPM.from_pretrained(
hf_model_id=pretrained_path,
load_denoiser=False,
optimize=False,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
)
return "Model loaded successfully!"
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed):
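"""Generate audio for `text`, optionally cloning a reference voice and hot-swapping a LoRA adapter. Returns ((sample_rate, waveform), status) on success, or (None, error message) on failure."""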
global current_model
# If a LoRA model is selected and no model is loaded yet, try to read base_model from the LoRA config
if current_model is None:
base_model_path = default_pretrained_path  # default path
# If a LoRA is selected, try to read its base_model from the saved config
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
lora_config_file = os.path.join(full_lora_path, "lora_config.json")
if os.path.exists(lora_config_file):
try:
with open(lora_config_file, "r", encoding="utf-8") as f:
lora_info = json.load(f)
saved_base_model = lora_info.get("base_model")
if saved_base_model:
# Prefer the saved base_model path
if os.path.exists(saved_base_model):
base_model_path = saved_base_model
print(f"Using base model from LoRA config: {base_model_path}")
else:
print(f"Warning: Saved base_model path not found: {saved_base_model}")
print(f"Falling back to default: {base_model_path}")
except Exception as e:
print(f"Warning: Failed to read base_model from LoRA config: {e}")
# Load the model
try:
print(f"Loading base model: {base_model_path}")
status_msg = load_model(base_model_path)
if lora_selection and lora_selection != "None":
print(f"Model loaded for LoRA: {lora_selection}")
except Exception as e:
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
print(error_msg)
return None, error_msg
# Handle LoRA hot-swapping
if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection)
print(f"Hot-loading LoRA: {full_lora_path}")
try:
current_model.load_lora(full_lora_path)
current_model.set_lora_enabled(True)
except Exception as e:
print(f"Error loading LoRA: {e}")
return None, f"Error loading LoRA: {e}"
else:
print("Disabling LoRA")
current_model.set_lora_enabled(False)
if seed != -1:
torch.manual_seed(seed)
np.random.seed(seed)
# Handle prompt arguments: both must be None, or both must be set
final_prompt_wav = None
final_prompt_text = None
if prompt_wav and prompt_wav.strip():
# Reference audio provided
final_prompt_wav = prompt_wav
# If no reference text was given, try to recognize it automatically
if not prompt_text or not prompt_text.strip():
print("Reference audio provided without text; recognizing automatically...")
try:
final_prompt_text = recognize_audio(prompt_wav)
if final_prompt_text:
print(f"自动识别文本: {final_prompt_text}")
else:
return None, "错误:无法识别参考音频内容,请手动填写参考文本"
except Exception as e:
return None, f"错误:自动识别参考音频失败 - {str(e)}"
else:
final_prompt_text = prompt_text.strip()
# Without reference audio, both stay None (zero-shot TTS)
try:
audio_np = current_model.generate(
text=text,
prompt_wav_path=final_prompt_wav,
prompt_text=final_prompt_text,
cfg_value=cfg_scale,
inference_timesteps=steps,
denoise=False
)
return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
except Exception as e:
import traceback
traceback.print_exc()
return None, f"Error: {str(e)}"
def start_training(
pretrained_path,
train_manifest,
val_manifest,
learning_rate,
num_iters,
batch_size,
lora_rank,
lora_alpha,
save_interval,
output_name="",
# Advanced options
grad_accum_steps=1,
num_workers=2,
log_interval=10,
valid_interval=1000,
weight_decay=0.01,
warmup_steps=100,
max_steps=None,
sample_rate=44100,
# LoRA advanced
enable_lm=True,
enable_dit=True,
enable_proj=False,
dropout=0.0,
tensorboard_path="",
# Distribution options
hf_model_id="",
distribute=False,
):
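"""Write a training config YAML under lora/<output_name or timestamp>/ and launch scripts/train_voxcpm_finetune.py as a background subprocess. Each manifest line is assumed to be one JSON object per utterance, e.g. {"audio": "path/to/utt.wav", "text": "transcript"} (assumed field names; the actual schema is defined by the training script)."""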
global training_process, training_log
if training_process is not None and training_process.poll() is None:
return "Training is already running!"
if output_name and output_name.strip():
timestamp = output_name.strip()
else:
timestamp = get_timestamp_str()
save_dir = os.path.join("lora", timestamp)
checkpoints_dir = os.path.join(save_dir, "checkpoints")
logs_dir = os.path.join(save_dir, "logs")
os.makedirs(checkpoints_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)
# Resolve max_steps default (0 or empty falls back to num_iters)
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
# Build the training config dictionary
config = {
"pretrained_path": pretrained_path,
"train_manifest": train_manifest,
"val_manifest": val_manifest,
"sample_rate": int(sample_rate),
"batch_size": int(batch_size),
"grad_accum_steps": int(grad_accum_steps),
"num_workers": int(num_workers),
"num_iters": int(num_iters),
"log_interval": int(log_interval),
"valid_interval": int(valid_interval),
"save_interval": int(save_interval),
"learning_rate": float(learning_rate),
"weight_decay": float(weight_decay),
"warmup_steps": int(warmup_steps),
"max_steps": resolved_max_steps,
"save_path": checkpoints_dir,
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
"lambdas": {
"loss/diff": 1.0,
"loss/stop": 1.0
},
"lora": {
"enable_lm": bool(enable_lm),
"enable_dit": bool(enable_dit),
"enable_proj": bool(enable_proj),
"r": int(lora_rank),
"alpha": int(lora_alpha),
"dropout": float(dropout),
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"]
},
}
# Add distribution options if provided
if hf_model_id and hf_model_id.strip():
config["hf_model_id"] = hf_model_id.strip()
if distribute:
config["distribute"] = True
config_path = os.path.join(save_dir, "train_config.yaml")
with open(config_path, "w") as f:
yaml.dump(config, f)
cmd = [
sys.executable,
"scripts/train_voxcpm_finetune.py",
"--config_path",
config_path
]
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
def run_process():
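"""Stream the trainer's stdout into the global log buffer, trimming it to roughly the last 100 KB."""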
global training_process, training_log
training_process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1
)
for line in training_process.stdout:
training_log += line
# Keep log size manageable
if len(training_log) > 100000:
training_log = training_log[-100000:]
training_process.wait()
training_log += f"\nTraining finished with code {training_process.returncode}"
threading.Thread(target=run_process, daemon=True).start()
return f"Training started! Check 'lora/{timestamp}'"
def get_training_log():
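"""Return the accumulated training log (polled by the UI timer)."""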
return training_log
def stop_training():
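"""Terminate the running training subprocess, if any."""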
global training_process, training_log
if training_process is not None and training_process.poll() is None:
training_process.terminate()
training_log += "\nTraining terminated by user."
return "Training stopped."
return "No training running."
# --- GUI Layout ---
# Custom CSS styles
custom_css = """
/* Overall theme styles */
.gradio-container {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* Title section - flat design */
.title-section {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 8px;
padding: 15px 25px;
margin-bottom: 15px;
border: none;
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}
.title-section h1 {
color: white;
text-shadow: none;
font-weight: 600;
margin: 0;
font-size: 28px;
line-height: 1.2;
}
.title-section h3 {
color: rgba(255, 255, 255, 0.9);
font-weight: 400;
margin-top: 5px;
font-size: 14px;
line-height: 1.3;
}
.title-section p {
color: rgba(255, 255, 255, 0.85);
font-size: 13px;
margin: 5px 0 0 0;
line-height: 1.3;
}
/* Tab styles */
.tabs {
background: white;
border-radius: 15px;
padding: 10px;
box-shadow: 0 4px 20px rgba(0,0,0,0.08);
}
/* Button style enhancements */
.button-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
border-radius: 12px;
padding: 12px 30px;
font-weight: 600;
color: white;
transition: all 0.3s ease;
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
}
.button-primary:hover {
transform: translateY(-2px);
box-shadow: 0 6px 25px rgba(102, 126, 234, 0.4);
}
.button-stop {
background: linear-gradient(135deg, #fa709a 0%, #fee140 100%);
border: none;
border-radius: 12px;
padding: 12px 30px;
font-weight: 600;
color: white;
transition: all 0.3s ease;
box-shadow: 0 4px 15px rgba(250, 112, 154, 0.3);
}
.button-stop:hover {
transform: translateY(-2px);
box-shadow: 0 6px 25px rgba(250, 112, 154, 0.4);
}
.button-refresh {
background: linear-gradient(135deg, #84fab0 0%, #8fd3f4 100%);
border: none;
border-radius: 10px;
padding: 8px 20px;
font-weight: 500;
color: white;
transition: all 0.3s ease;
box-shadow: 0 2px 10px rgba(132, 250, 176, 0.3);
}
.button-refresh:hover {
transform: translateY(-1px);
box-shadow: 0 4px 15px rgba(132, 250, 176, 0.4);
}
/* Form section styles */
.form-section {
background: white;
border-radius: 20px;
padding: 30px;
margin: 15px 0;
box-shadow: 0 8px 30px rgba(0,0,0,0.08);
border: 1px solid rgba(0,0,0,0.05);
}
/* Input field styles */
.input-field {
border-radius: 12px;
border: 2px solid #e0e0e0;
padding: 12px 16px;
transition: all 0.3s ease;
background: #fafafa;
}
.input-field:focus {
border-color: #667eea;
box-shadow: 0 0 0 4px rgba(102, 126, 234, 0.1);
background: white;
}
/* Slider styles */
.slider {
-webkit-appearance: none;
appearance: none;
width: 100%;
height: 6px;
border-radius: 3px;
background: linear-gradient(90deg, #667eea, #764ba2);
outline: none;
opacity: 0.8;
transition: opacity 0.2s;
}
.slider:hover {
opacity: 1;
}
.slider::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
width: 18px;
height: 18px;
border-radius: 50%;
background: white;
cursor: pointer;
border: 3px solid #667eea;
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
}
.slider::-moz-range-thumb {
width: 18px;
height: 18px;
border-radius: 50%;
background: white;
cursor: pointer;
border: 3px solid #667eea;
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
}
/* Accordion styles */
.accordion {
border-radius: 12px;
border: 2px solid #e0e0e0;
overflow: hidden;
background: white;
}
.accordion-header {
background: linear-gradient(135deg, #f5f7fa 0%, #e3e7ed 100%);
padding: 15px 20px;
font-weight: 600;
color: #333;
}
/* Status display styles */
.status-success {
background: linear-gradient(135deg, #84fab0 0%, #8fd3f4 100%);
color: white;
padding: 12px 20px;
border-radius: 12px;
font-weight: 500;
box-shadow: 0 4px 15px rgba(132, 250, 176, 0.3);
}
.status-error {
background: linear-gradient(135deg, #fa709a 0%, #fee140 100%);
color: white;
padding: 12px 20px;
border-radius: 12px;
font-weight: 500;
box-shadow: 0 4px 15px rgba(250, 112, 154, 0.3);
}
/* Language selector - flat design */
.lang-selector {
background: rgba(255, 255, 255, 0.25);
backdrop-filter: blur(10px);
border-radius: 8px;
padding: 8px 12px;
border: 1px solid rgba(255, 255, 255, 0.4);
}
.lang-selector label.gr-box {
color: white !important;
font-weight: 600;
margin-bottom: 8px !important;
}
/* Radio button group styles */
.lang-selector fieldset,
.lang-selector .gr-form {
gap: 10px !important;
display: flex !important;
}
/* Radio button container - flat (unselected state: lighter dark shade) */
.lang-selector label.gr-radio-label {
background: linear-gradient(135deg, rgba(102, 126, 234, 0.6), rgba(118, 75, 162, 0.6)) !important;
border: 1px solid rgba(255, 255, 255, 0.5) !important;
border-radius: 6px !important;
padding: 8px 18px !important;
color: white !important;
font-weight: 500 !important;
transition: all 0.2s ease !important;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
cursor: pointer !important;
margin: 0 4px !important;
}
/* Selected radio button - flat (darker background) */
.lang-selector input[type="radio"]:checked + label,
.lang-selector label.gr-radio-label:has(input:checked) {
background: linear-gradient(135deg, #5568d3, #6b4c9a) !important;
color: white !important;
border: 1px solid rgba(255, 255, 255, 0.6) !important;
font-weight: 600 !important;
box-shadow: 0 3px 12px rgba(0, 0, 0, 0.2) !important;
transform: none !important;
}
/* Hover effect for unselected radio buttons - flat */
.lang-selector label.gr-radio-label:hover {
background: linear-gradient(135deg, rgba(102, 126, 234, 0.75), rgba(118, 75, 162, 0.75)) !important;
border-color: rgba(255, 255, 255, 0.7) !important;
transform: none !important;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.15) !important;
}
/* Hide the native radio dots */
.lang-selector input[type="radio"] {
opacity: 0;
position: absolute;
}
/* Gradio Radio specific styles - flat */
.lang-selector .wrap {
gap: 8px !important;
}
.lang-selector .wrap > label {
background: linear-gradient(135deg, rgba(102, 126, 234, 0.6), rgba(118, 75, 162, 0.6)) !important;
border: 1px solid rgba(255, 255, 255, 0.5) !important;
border-radius: 6px !important;
padding: 8px 18px !important;
color: white !important;
font-weight: 500 !important;
transition: all 0.2s ease !important;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
}
.lang-selector .wrap > label.selected {
background: linear-gradient(135deg, #5568d3, #6b4c9a) !important;
color: white !important;
border: 1px solid rgba(255, 255, 255, 0.6) !important;
font-weight: 600 !important;
box-shadow: 0 3px 12px rgba(0, 0, 0, 0.2) !important;
}
/* Label styling */
label {
color: #333;
font-weight: 500;
margin-bottom: 8px;
}
/* Markdown heading styles */
.markdown-text h4 {
color: #667eea;
font-weight: 600;
margin-top: 15px;
margin-bottom: 10px;
}
/* Spacing between parameter widgets */
.form-section > div {
margin-bottom: 15px;
}
/* Slider widget tweaks */
.gr-slider {
padding: 10px 0;
}
/* Number input tweaks */
.gr-number {
max-width: 100%;
}
/* Button container tweaks */
.gr-button {
min-height: 45px;
font-size: 16px;
}
/* Three-column layout tweaks */
#component-0 .gr-row {
gap: 20px;
}
/* Special styling for the generate button */
.button-primary.gr-button-lg {
min-height: 55px;
font-size: 18px;
font-weight: 700;
margin-top: 20px;
margin-bottom: 10px;
}
/* Small refresh button */
.button-refresh.gr-button-sm {
min-height: 38px;
font-size: 14px;
margin-top: 5px;
margin-bottom: 15px;
}
/* Info hint text styles */
.gr-info {
font-size: 13px;
color: #666;
margin-top: 5px;
}
/* Section heading tweaks */
.form-section h4 {
color: #667eea;
font-weight: 600;
margin-top: 0;
margin-bottom: 15px;
padding-bottom: 10px;
border-bottom: 2px solid #f0f0f0;
}
.form-section strong {
color: #667eea;
font-size: 15px;
display: block;
margin: 15px 0 10px 0;
}
"""
with gr.Blocks(
title="VoxCPM LoRA WebUI",
theme=gr.themes.Soft(),
css=custom_css
) as app:
# State for language
lang_state = gr.State("zh") # Default to Chinese
# Title section
with gr.Row(elem_classes="title-section"):
with gr.Column(scale=3):
title_md = gr.Markdown("""
# 🎵 VoxCPM LoRA WebUI
### 强大的语音合成和 LoRA 微调工具
支持语音克隆、LoRA 模型训练和推理的完整解决方案
""")
with gr.Column(scale=1):
lang_btn = gr.Radio(
choices=["en", "zh"],
value="zh",
label="🌐 Language / 语言",
elem_classes="lang-selector"
)
with gr.Tabs(elem_classes="tabs") as tabs:
# === Training Tab ===
with gr.Tab("🚀 训练 (Training)") as tab_train:
gr.Markdown("""
### 🎯 模型训练设置
配置你的 LoRA 微调训练参数
""")
with gr.Row():
with gr.Column(scale=2, elem_classes="form-section"):
gr.Markdown("#### 📁 基础配置")
train_pretrained_path = gr.Textbox(
label="📂 预训练模型路径",
value=default_pretrained_path,
elem_classes="input-field"
)
train_manifest = gr.Textbox(
label="📋 训练数据清单 (jsonl)",
value="examples/train_data_example.jsonl",
elem_classes="input-field"
)
val_manifest = gr.Textbox(
label="📊 验证数据清单 (可选)",
value="",
elem_classes="input-field"
)
gr.Markdown("#### ⚙️ 训练参数")
with gr.Row():
lr = gr.Number(
label="📈 学习率 (Learning Rate)",
value=1e-4,
elem_classes="input-field"
)
num_iters = gr.Number(
label="🔄 最大迭代次数",
value=2000,
precision=0,
elem_classes="input-field"
)
batch_size = gr.Number(
label="📦 批次大小 (Batch Size)",
value=1,
precision=0,
elem_classes="input-field"
)
with gr.Row():
lora_rank = gr.Number(
label="🎯 LoRA Rank",
value=32,
precision=0,
elem_classes="input-field"
)
lora_alpha = gr.Number(
label="⚖️ LoRA Alpha",
value=16,
precision=0,
elem_classes="input-field"
)
save_interval = gr.Number(
label="💾 保存间隔 (Steps)",
value=1000,
precision=0,
elem_classes="input-field"
)
output_name = gr.Textbox(
label="📁 输出目录名称 (可选,若存在则继续训练)",
value="",
elem_classes="input-field"
)
with gr.Row():
start_btn = gr.Button(
"▶️ 开始训练",
variant="primary",
elem_classes="button-primary"
)
stop_btn = gr.Button(
"⏹️ 停止训练",
variant="stop",
elem_classes="button-stop"
)
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
with gr.Row():
grad_accum_steps = gr.Number(label="梯度累积 (grad_accum_steps)", value=1, precision=0)
num_workers = gr.Number(label="数据加载线程 (num_workers)", value=2, precision=0)
log_interval = gr.Number(label="日志间隔 (log_interval)", value=10, precision=0)
with gr.Row():
valid_interval = gr.Number(label="验证间隔 (valid_interval)", value=1000, precision=0)
weight_decay = gr.Number(label="权重衰减 (weight_decay)", value=0.01)
warmup_steps = gr.Number(label="warmup_steps", value=100, precision=0)
with gr.Row():
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
with gr.Row():
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
dropout = gr.Number(label="LoRA Dropout", value=0.0)
gr.Markdown("#### 分发选项 (Distribution)")
with gr.Row():
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5")
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
with gr.Column(scale=2, elem_classes="form-section"):
gr.Markdown("#### 📊 训练日志")
logs_out = gr.TextArea(
label="",
lines=20,
max_lines=30,
interactive=False,
elem_classes="input-field",
show_label=False
)
start_btn.click(
start_training,
inputs=[
train_pretrained_path, train_manifest, val_manifest,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
output_name,
# advanced
grad_accum_steps, num_workers, log_interval, valid_interval,
weight_decay, warmup_steps, max_steps, sample_rate,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
# distribution
hf_model_id, distribute
],
outputs=[logs_out]  # shows the initial status message; the timer below streams the live log
)
stop_btn.click(stop_training, outputs=[logs_out])
# Log refresher
timer = gr.Timer(1)
timer.tick(get_training_log, outputs=logs_out)
# === Inference Tab ===
with gr.Tab("🎵 推理 (Inference)") as tab_infer:
gr.Markdown("""
### 🎤 语音合成
使用训练好的 LoRA 模型生成语音,支持 LoRA 微调和声音克隆
""")
with gr.Row():
# Left column: input configuration (35%)
with gr.Column(scale=35, elem_classes="form-section"):
gr.Markdown("#### 📝 输入配置")
infer_text = gr.TextArea(
label="💬 合成文本",
value="Hello, this is a test of the VoxCPM LoRA model.",
elem_classes="input-field",
lines=4,
placeholder="输入要合成的文本内容..."
)
gr.Markdown("**🎭 声音克隆(可选)**")
prompt_wav = gr.Audio(
label="🎵 参考音频",
type="filepath",
elem_classes="input-field"
)
prompt_text = gr.Textbox(
label="📝 参考文本(可选)",
elem_classes="input-field",
placeholder="如不填写,将自动识别参考音频内容"
)
# Middle column: model selection and generation parameters (35%)
with gr.Column(scale=35, elem_classes="form-section"):
gr.Markdown("#### 🤖 模型选择")
lora_select = gr.Dropdown(
label="🎯 LoRA 模型",
choices=["None"] + scan_lora_checkpoints(),
value="None",
interactive=True,
elem_classes="input-field",
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型"
)
refresh_lora_btn = gr.Button(
"🔄 刷新模型列表",
elem_classes="button-refresh",
size="sm"
)
gr.Markdown("#### ⚙️ 生成参数")
cfg_scale = gr.Slider(
label="🎛️ CFG Scale",
minimum=1.0,
maximum=5.0,
value=2.0,
step=0.1,
info="引导系数,值越大越贴近提示"
)
steps = gr.Slider(
label="🔢 推理步数",
minimum=1,
maximum=50,
value=10,
step=1,
info="生成质量与步数成正比,但耗时更长"
)
seed = gr.Number(
label="🎲 随机种子",
value=-1,
precision=0,
elem_classes="input-field",
info="-1 为随机,固定值可复现结果"
)
generate_btn = gr.Button(
"🎵 生成音频",
variant="primary",
elem_classes="button-primary",
size="lg"
)
# Right column: generated output (30%)
with gr.Column(scale=30, elem_classes="form-section"):
gr.Markdown("#### 🎧 生成结果")
audio_out = gr.Audio(
label="",
elem_classes="input-field",
show_label=False
)
gr.Markdown("#### 📋 状态信息")
status_out = gr.Textbox(
label="",
interactive=False,
elem_classes="input-field",
show_label=False,
lines=3,
placeholder="等待生成..."
)
def refresh_loras():
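"""Rescan the lora/ directory and rebuild the dropdown choices."""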
# Fetch LoRA checkpoints along with their base-model info
checkpoints_with_info = scan_lora_checkpoints(with_info=True)
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
# Debug output
print(f"Refreshed LoRA list: found {len(checkpoints_with_info)} checkpoint(s)")
for ckpt_path, base_model in checkpoints_with_info:
if base_model:
print(f" - {ckpt_path} (Base Model: {base_model})")
else:
print(f" - {ckpt_path}")
return gr.update(choices=choices, value="None")
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
# Auto-recognize audio when uploaded
prompt_wav.change(
fn=recognize_audio,
inputs=[prompt_wav],
outputs=[prompt_text]
)
generate_btn.click(
run_inference,
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed],
outputs=[audio_out, status_out]
)
# --- Language Switching Logic ---
def change_language(lang):
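"""Build the tuple of gr.update() relabel operations for the selected language; its order must match the outputs list wired to lang_btn.change below."""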
d = LANG_DICT[lang]
# Labels for advanced options
if lang == "zh":
adv = {
'grad_accum_steps': "梯度累积 (grad_accum_steps)",
'num_workers': "数据加载线程 (num_workers)",
'log_interval': "日志间隔 (log_interval)",
'valid_interval': "验证间隔 (valid_interval)",
'weight_decay': "权重衰减 (weight_decay)",
'warmup_steps': "warmup_steps",
'max_steps': "最大步数 (max_steps)",
'sample_rate': "采样率 (sample_rate)",
'enable_lm': "启用 LoRA LM (enable_lm)",
'enable_dit': "启用 LoRA DIT (enable_dit)",
'enable_proj': "启用投影 (enable_proj)",
'dropout': "LoRA Dropout",
'tensorboard_path': "Tensorboard 路径 (可选)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "分发模式 (distribute)",
}
else:
adv = {
'grad_accum_steps': "Grad Accum Steps",
'num_workers': "Num Workers",
'log_interval': "Log Interval",
'valid_interval': "Valid Interval",
'weight_decay': "Weight Decay",
'warmup_steps': "Warmup Steps",
'max_steps': "Max Steps",
'sample_rate': "Sample Rate",
'enable_lm': "Enable LoRA LM",
'enable_dit': "Enable LoRA DIT",
'enable_proj': "Enable Projection",
'dropout': "LoRA Dropout",
'tensorboard_path': "Tensorboard Path (Optional)",
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
'distribute': "Distribute Mode",
}
return (
gr.update(value=f"# {d['title']}"),
gr.update(label=d['tab_train']),
gr.update(label=d['tab_infer']),
gr.update(label=d['pretrained_path']),
gr.update(label=d['train_manifest']),
gr.update(label=d['val_manifest']),
gr.update(label=d['lr']),
gr.update(label=d['max_iters']),
gr.update(label=d['batch_size']),
gr.update(label=d['lora_rank']),
gr.update(label=d['lora_alpha']),
gr.update(label=d['save_interval']),
gr.update(label=d['output_name']),
gr.update(value=d['start_train']),
gr.update(value=d['stop_train']),
gr.update(label=d['train_logs']),
# Advanced options (must match outputs order)
gr.update(label=adv['grad_accum_steps']),
gr.update(label=adv['num_workers']),
gr.update(label=adv['log_interval']),
gr.update(label=adv['valid_interval']),
gr.update(label=adv['weight_decay']),
gr.update(label=adv['warmup_steps']),
gr.update(label=adv['max_steps']),
gr.update(label=adv['sample_rate']),
gr.update(label=adv['enable_lm']),
gr.update(label=adv['enable_dit']),
gr.update(label=adv['enable_proj']),
gr.update(label=adv['dropout']),
gr.update(label=adv['tensorboard_path']),
# Distribution options
gr.update(label=adv['hf_model_id']),
gr.update(label=adv['distribute']),
# Inference section
gr.update(label=d['text_to_synth']),
gr.update(label=d['ref_audio']),
gr.update(label=d['ref_text']),
gr.update(label=d['select_lora']),
gr.update(value=d['refresh']),
gr.update(label=d['cfg_scale']),
gr.update(label=d['infer_steps']),
gr.update(label=d['seed']),
gr.update(value=d['gen_audio']),
gr.update(label=d['gen_output']),
gr.update(label=d['status']),
)
lang_btn.change(
change_language,
inputs=[lang_btn],
outputs=[
title_md, tab_train, tab_infer,
train_pretrained_path, train_manifest, val_manifest,
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
output_name,
start_btn, stop_btn, logs_out,
# advanced outputs
grad_accum_steps, num_workers, log_interval, valid_interval,
weight_decay, warmup_steps, max_steps, sample_rate,
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
# distribution outputs
hf_model_id, distribute,
infer_text, prompt_wav, prompt_text,
lora_select, refresh_lora_btn, cfg_scale, steps, seed,
generate_btn, audio_out, status_out
]
)
if __name__ == "__main__":
# Ensure lora directory exists
os.makedirs("lora", exist_ok=True)
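# Bind to all interfaces; the UI is then reachable at http://<host>:7860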
app.queue().launch(server_name="0.0.0.0", server_port=7860)