1253 lines
42 KiB
Python
1253 lines
42 KiB
Python
import os
|
||
import sys
|
||
import time
|
||
import glob
|
||
import json
|
||
import yaml
|
||
import shutil
|
||
import datetime
|
||
import subprocess
|
||
import threading
|
||
import gradio as gr
|
||
import torch
|
||
import soundfile as sf
|
||
from pathlib import Path
|
||
from typing import Optional, List
|
||
|
||
# Add src to sys.path
# NOTE(review): presumably this lets the ``voxcpm`` imports below resolve from
# the in-repo ``src`` directory — confirm against the repository layout.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root / "src"))

# Default pretrained model path relative to this repo
# (used as the inference fallback base model and as the initial value of the
# training tab's pretrained-path field).
default_pretrained_path = str(project_root / "models" / "openbmb__VoxCPM1.5")
|
||
|
||
from voxcpm.core import VoxCPM
|
||
from voxcpm.model.voxcpm import LoRAConfig
|
||
import numpy as np
|
||
from funasr import AutoModel
|
||
|
||
# --- Localization ---
|
||
# UI label translations, keyed first by language code ("en" / "zh") and then
# by label id. Both languages define the same set of label ids.
LANG_DICT = {
    "en": {
        "title": "VoxCPM LoRA WebUI",
        "tab_train": "Training",
        "tab_infer": "Inference",
        "pretrained_path": "Pretrained Model Path",
        "train_manifest": "Train Manifest (jsonl)",
        "val_manifest": "Validation Manifest (Optional)",
        "lr": "Learning Rate",
        "max_iters": "Max Iterations",
        "batch_size": "Batch Size",
        "lora_rank": "LoRA Rank",
        "lora_alpha": "LoRA Alpha",
        "save_interval": "Save Interval",
        "start_train": "Start Training",
        "stop_train": "Stop Training",
        "train_logs": "Training Logs",
        "text_to_synth": "Text to Synthesize",
        "voice_cloning": "### Voice Cloning (Optional)",
        "ref_audio": "Reference Audio",
        "ref_text": "Reference Text (Optional)",
        "select_lora": "Select LoRA Checkpoint",
        "cfg_scale": "CFG Scale",
        "infer_steps": "Inference Steps",
        "seed": "Seed",
        "gen_audio": "Generate Audio",
        "gen_output": "Generated Audio",
        "status": "Status",
        "lang_select": "Language / 语言",
        "refresh": "Refresh",
        "output_name": "Output Name (Optional, resume if exists)",
    },
    "zh": {
        "title": "VoxCPM LoRA WebUI",
        "tab_train": "训练 (Training)",
        "tab_infer": "推理 (Inference)",
        "pretrained_path": "预训练模型路径",
        "train_manifest": "训练数据清单 (jsonl)",
        "val_manifest": "验证数据清单 (可选)",
        "lr": "学习率 (Learning Rate)",
        "max_iters": "最大迭代次数",
        "batch_size": "批次大小 (Batch Size)",
        "lora_rank": "LoRA Rank",
        "lora_alpha": "LoRA Alpha",
        "save_interval": "保存间隔 (Steps)",
        "start_train": "开始训练",
        "stop_train": "停止训练",
        "train_logs": "训练日志",
        "text_to_synth": "合成文本",
        "voice_cloning": "### 声音克隆 (可选)",
        "ref_audio": "参考音频",
        "ref_text": "参考文本 (可选)",
        "select_lora": "选择 LoRA 模型",
        "cfg_scale": "CFG Scale (引导系数)",
        "infer_steps": "推理步数",
        "seed": "随机种子 (Seed)",
        "gen_audio": "生成音频",
        "gen_output": "生成结果",
        "status": "状态",
        "lang_select": "Language / 语言",
        "refresh": "刷新",
        "output_name": "输出目录名称 (可选,若存在则继续训练)",
    }
}
|
||
|
||
# Global variables
# Shared module state. Each value starts empty and is assigned later by the
# functions below through ``global`` statements.
current_model: Optional[VoxCPM] = None  # assigned by load_model()
asr_model: Optional[AutoModel] = None  # created on first get_or_load_asr_model() call
training_process: Optional[subprocess.Popen] = None  # handle of the training subprocess
training_log: str = ""  # accumulated training output, returned by get_training_log()
|
||
|
||
def get_timestamp_str():
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
|
||
|
||
def get_or_load_asr_model():
    """Return the shared ASR model, creating it on first use.

    The instance is cached in the module-level ``asr_model`` so later calls
    reuse it. The model is placed on ``cuda:0`` when CUDA is available and on
    the CPU otherwise.
    """
    global asr_model
    if asr_model is not None:
        return asr_model

    print("Loading ASR model (SenseVoiceSmall)...")
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"

    asr_model = AutoModel(
        model="iic/SenseVoiceSmall",
        disable_update=True,
        log_level='ERROR',
        device=device,
    )
    return asr_model
|
||
|
||
def recognize_audio(audio_path):
    """Transcribe an audio file with the shared ASR model.

    Args:
        audio_path: Path of the audio file; a falsy value short-circuits.

    Returns:
        The part of the first result's ``"text"`` field that follows the last
        ``|>`` marker, or ``""`` when no path is given or anything fails
        (the error is printed, not raised).
    """
    if not audio_path:
        return ""

    try:
        results = get_or_load_asr_model().generate(
            input=audio_path, language="auto", use_itn=True
        )
        raw_text = results[0]["text"]
        # Keep only what follows the final "|>" tag delimiter.
        return raw_text.rsplit('|>', 1)[-1]
    except Exception as e:
        print(f"ASR Error: {e}")
        return ""
|
||
|
||
def scan_lora_checkpoints(root_dir="lora", with_info=False):
    """Scan ``root_dir`` recursively for LoRA checkpoints.

    A directory counts as a checkpoint when it contains a file named
    ``lora_weights.safetensors``. ``root_dir`` is created when it is missing.

    Args:
        root_dir: Directory to scan for LoRA checkpoints.
        with_info: If True, return ``(path, base_model)`` tuples, where
            ``base_model`` comes from the checkpoint's ``lora_config.json``
            (``"Unknown"`` when the key is absent, ``None`` when the file is
            missing or cannot be read/parsed).

    Returns:
        Checkpoint paths relative to ``root_dir`` (or tuples, see
        ``with_info``), sorted in descending order.
    """
    if not os.path.exists(root_dir):
        os.makedirs(root_dir, exist_ok=True)

    checkpoints = []
    # Look for lora_weights.safetensors recursively
    for root, _dirs, files in os.walk(root_dir):
        if "lora_weights.safetensors" not in files:
            continue

        # Use the relative path from root_dir as the ID
        rel_path = os.path.relpath(root, root_dir)
        if not with_info:
            checkpoints.append(rel_path)
            continue

        # Try to read base_model from lora_config.json
        base_model = None
        lora_config_file = os.path.join(root, "lora_config.json")
        if os.path.exists(lora_config_file):
            try:
                with open(lora_config_file, "r", encoding="utf-8") as f:
                    lora_info = json.load(f)
                base_model = lora_info.get("base_model", "Unknown")
            except (OSError, ValueError, AttributeError) as e:
                # Best effort: an unreadable, malformed or non-object config
                # leaves base_model as None, but must not hide Ctrl-C/exit.
                print(f"Warning: Failed to read {lora_config_file}: {e}")
        checkpoints.append((rel_path, base_model))

    return sorted(checkpoints, reverse=True)
|
||
|
||
def load_lora_config_from_checkpoint(lora_path):
    """Read ``lora_config.json`` from a checkpoint directory.

    Args:
        lora_path: Directory expected to contain ``lora_config.json``.

    Returns:
        ``(LoRAConfig, base_model)`` when the file exists and holds a
        non-empty ``"lora_config"`` entry; ``(None, None)`` otherwise,
        including on any read/parse/construction failure (which is printed).
    """
    config_file = os.path.join(lora_path, "lora_config.json")
    if not os.path.exists(config_file):
        return None, None

    try:
        with open(config_file, "r", encoding="utf-8") as f:
            info = json.load(f)
        cfg_kwargs = info.get("lora_config", {})
        if cfg_kwargs:
            return LoRAConfig(**cfg_kwargs), info.get("base_model")
    except Exception as e:
        print(f"Warning: Failed to load lora_config.json: {e}")

    return None, None
|
||
|
||
def get_default_lora_config():
    """Build the fallback LoRA config used when a checkpoint ships none.

    The result enables LoRA on both the LM and the DiT with rank 32 and
    alpha 16, targeting the q/v/k/o projection modules of each.
    """
    projection_modules = ("q_proj", "v_proj", "k_proj", "o_proj")
    settings = dict(
        enable_lm=True,
        enable_dit=True,
        r=32,
        alpha=16,
        # Separate list objects, so the two fields never alias each other.
        target_modules_lm=list(projection_modules),
        target_modules_dit=list(projection_modules),
    )
    return LoRAConfig(**settings)
|
||
|
||
def load_model(pretrained_path, lora_path=None):
    """Load the base model into the module-level ``current_model``.

    Args:
        pretrained_path: Passed to ``VoxCPM.from_pretrained`` as ``hf_model_id``.
        lora_path: Optional checkpoint name under ``lora/``. When that
            directory exists its weights are loaded, using the checkpoint's
            own LoRA config if it has one and the default config otherwise.

    Returns:
        A fixed success message; failures propagate as exceptions.
    """
    global current_model
    print(f"Loading model from {pretrained_path}...")

    weights_dir = None
    config = None

    candidate = os.path.join("lora", lora_path) if lora_path else None
    if candidate is not None and os.path.exists(candidate):
        weights_dir = candidate
        # Prefer the config stored next to the weights.
        config, _base_model = load_lora_config_from_checkpoint(candidate)
        if config:
            print(f"Loaded LoRA config from {candidate}/lora_config.json")
        else:
            # Older checkpoints carry no lora_config.json.
            config = get_default_lora_config()
            print("Using default LoRA config (lora_config.json not found)")

    # The model is always built with some LoRA config so that a LoRA can be
    # hot-swapped in later.
    if config is None:
        config = get_default_lora_config()

    current_model = VoxCPM.from_pretrained(
        hf_model_id=pretrained_path,
        load_denoiser=False,
        optimize=False,
        lora_config=config,
        lora_weights_path=weights_dir,
    )
    return "Model loaded successfully!"
|
||
|
||
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed):
    """Synthesize speech for ``text``, optionally with a LoRA and a voice prompt.

    The base model is loaded on first use. When a LoRA is selected, its
    ``lora_config.json`` may name the base model to load; that path is used
    only if it exists on disk, otherwise ``default_pretrained_path`` is used.

    Args:
        text: Text to synthesize.
        prompt_wav: Optional reference-audio file path for voice cloning.
        prompt_text: Optional transcript of the reference audio. When blank
            and ``prompt_wav`` is given, it is recognized automatically.
        lora_selection: Checkpoint name under ``lora/``, or ``"None"``/empty
            to disable LoRA.
        cfg_scale: Forwarded to ``generate`` as ``cfg_value``.
        steps: Forwarded to ``generate`` as ``inference_timesteps``.
        seed: Non-negative value to seed torch and NumPy; ``-1``, any other
            negative value, or ``None`` leaves the RNGs unseeded.

    Returns:
        ``((sample_rate, audio), status_message)`` on success, or
        ``(None, error_message)`` on failure.
    """
    global current_model

    use_lora = bool(lora_selection) and lora_selection != "None"

    # Lazily load the base model on first use.
    if current_model is None:
        base_model_path = default_pretrained_path  # default path

        # If a LoRA is selected, try to read base_model from its config.
        if use_lora:
            full_lora_path = os.path.join("lora", lora_selection)
            lora_config_file = os.path.join(full_lora_path, "lora_config.json")

            if os.path.exists(lora_config_file):
                try:
                    with open(lora_config_file, "r", encoding="utf-8") as f:
                        lora_info = json.load(f)
                    saved_base_model = lora_info.get("base_model")

                    if saved_base_model:
                        # Prefer the saved base_model path when it exists.
                        if os.path.exists(saved_base_model):
                            base_model_path = saved_base_model
                            print(f"Using base model from LoRA config: {base_model_path}")
                        else:
                            print(f"Warning: Saved base_model path not found: {saved_base_model}")
                            print(f"Falling back to default: {base_model_path}")
                except Exception as e:
                    print(f"Warning: Failed to read base_model from LoRA config: {e}")

        # Load the model.
        try:
            print(f"Loading base model: {base_model_path}")
            load_model(base_model_path)
            if use_lora:
                print(f"Model loaded for LoRA: {lora_selection}")
        except Exception as e:
            error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
            print(error_msg)
            return None, error_msg

    # Handle LoRA hot-swapping
    if use_lora:
        full_lora_path = os.path.join("lora", lora_selection)
        print(f"Hot-loading LoRA: {full_lora_path}")
        try:
            current_model.load_lora(full_lora_path)
            current_model.set_lora_enabled(True)
        except Exception as e:
            print(f"Error loading LoRA: {e}")
            return None, f"Error loading LoRA: {e}"
    else:
        print("Disabling LoRA")
        current_model.set_lora_enabled(False)

    # Seed only for a usable value: a cleared field arrives as None and
    # negative values are treated like -1 ("random"), so neither reaches
    # the seeding calls and raises outside the try block below.
    if seed is not None and seed >= 0:
        seed = int(seed)
        torch.manual_seed(seed)
        # NumPy's legacy seeding only accepts values in [0, 2**32 - 1].
        np.random.seed(seed % (2 ** 32))

    # Prompt arguments must be both None or both set.
    final_prompt_wav = None
    final_prompt_text = None

    if prompt_wav and prompt_wav.strip():
        # Reference audio was provided.
        final_prompt_wav = prompt_wav

        # Without a reference text, try to recognize it automatically.
        if not prompt_text or not prompt_text.strip():
            print("参考音频已提供但缺少文本,自动识别中...")
            try:
                final_prompt_text = recognize_audio(prompt_wav)
                if final_prompt_text:
                    print(f"自动识别文本: {final_prompt_text}")
                else:
                    return None, "错误:无法识别参考音频内容,请手动填写参考文本"
            except Exception as e:
                return None, f"错误:自动识别参考音频失败 - {str(e)}"
        else:
            final_prompt_text = prompt_text.strip()
    # Without reference audio both stay None (zero-shot TTS).

    try:
        audio_np = current_model.generate(
            text=text,
            prompt_wav_path=final_prompt_wav,
            prompt_text=final_prompt_text,
            cfg_value=cfg_scale,
            inference_timesteps=steps,
            denoise=False
        )
        return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
|
||
|
||
def start_training(
    pretrained_path,
    train_manifest,
    val_manifest,
    learning_rate,
    num_iters,
    batch_size,
    lora_rank,
    lora_alpha,
    save_interval,
    output_name="",
    # Advanced options
    grad_accum_steps=1,
    num_workers=2,
    log_interval=10,
    valid_interval=1000,
    weight_decay=0.01,
    warmup_steps=100,
    max_steps=None,
    sample_rate=44100,
    # LoRA advanced
    enable_lm=True,
    enable_dit=True,
    enable_proj=False,
    dropout=0.0,
    tensorboard_path="",
    # Distribution options
    hf_model_id="",
    distribute=False,
):
    """Write a training config and launch the fine-tuning script.

    Creates ``lora/<name>/checkpoints`` and ``lora/<name>/logs``, dumps the
    settings to ``lora/<name>/train_config.yaml`` and runs
    ``scripts/train_voxcpm_finetune.py --config_path <yaml>`` with the current
    interpreter. The subprocess is started in the calling thread, so the
    "already running" check of a second call sees it immediately and launch
    failures are reported to the caller; a daemon thread then appends its
    output to the module-level ``training_log``.

    Args:
        pretrained_path, train_manifest, val_manifest: Stored verbatim in the
            config.
        learning_rate, num_iters, batch_size, lora_rank, lora_alpha,
        save_interval, grad_accum_steps, num_workers, log_interval,
        valid_interval, weight_decay, warmup_steps, sample_rate, dropout:
            Coerced to ``int``/``float`` and stored in the config.
        output_name: Run directory name under ``lora/``; a timestamp is used
            when blank.
        max_steps: Falls back to ``num_iters`` when ``None``, ``""`` or ``0``.
        enable_lm, enable_dit, enable_proj: Stored as booleans in the
            ``lora`` section.
        tensorboard_path: Falls back to the run's ``logs`` directory when empty.
        hf_model_id, distribute: Written to the config only when set.

    Returns:
        A status message for the UI.
    """
    global training_process, training_log

    if training_process is not None and training_process.poll() is None:
        return "Training is already running!"

    if output_name and output_name.strip():
        timestamp = output_name.strip()
    else:
        timestamp = get_timestamp_str()

    save_dir = os.path.join("lora", timestamp)
    checkpoints_dir = os.path.join(save_dir, "checkpoints")
    logs_dir = os.path.join(save_dir, "logs")

    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(logs_dir, exist_ok=True)

    # Resolve max_steps default
    resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)

    # Create config dictionary
    config = {
        "pretrained_path": pretrained_path,
        "train_manifest": train_manifest,
        "val_manifest": val_manifest,
        "sample_rate": int(sample_rate),
        "batch_size": int(batch_size),
        "grad_accum_steps": int(grad_accum_steps),
        "num_workers": int(num_workers),
        "num_iters": int(num_iters),
        "log_interval": int(log_interval),
        "valid_interval": int(valid_interval),
        "save_interval": int(save_interval),
        "learning_rate": float(learning_rate),
        "weight_decay": float(weight_decay),
        "warmup_steps": int(warmup_steps),
        "max_steps": resolved_max_steps,
        "save_path": checkpoints_dir,
        "tensorboard": tensorboard_path if tensorboard_path else logs_dir,
        "lambdas": {
            "loss/diff": 1.0,
            "loss/stop": 1.0
        },
        "lora": {
            "enable_lm": bool(enable_lm),
            "enable_dit": bool(enable_dit),
            "enable_proj": bool(enable_proj),
            "r": int(lora_rank),
            "alpha": int(lora_alpha),
            "dropout": float(dropout),
            "target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
            "target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"]
        },
    }

    # Add distribution options if provided
    if hf_model_id and hf_model_id.strip():
        config["hf_model_id"] = hf_model_id.strip()
    if distribute:
        config["distribute"] = True

    config_path = os.path.join(save_dir, "train_config.yaml")
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)

    cmd = [
        sys.executable,
        "scripts/train_voxcpm_finetune.py",
        "--config_path",
        config_path
    ]

    training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"

    # Launch here rather than in the worker thread: the global handle is set
    # before this function returns, and a failed launch is not lost silently.
    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            # An undecodable byte in the child's output must not kill the reader.
            errors="replace",
        )
    except OSError as e:
        training_log += f"\nFailed to launch training process: {e}"
        return f"Failed to start training: {e}"
    training_process = process

    def run_process():
        """Pump the subprocess output into ``training_log`` until it exits."""
        global training_log
        for line in process.stdout:
            training_log += line
            # Keep log size manageable
            if len(training_log) > 100000:
                training_log = training_log[-100000:]

        process.wait()
        training_log += f"\nTraining finished with code {process.returncode}"

    threading.Thread(target=run_process, daemon=True).start()

    return f"Training started! Check 'lora/{timestamp}'"
|
||
|
||
def get_training_log():
    """Return the training log text accumulated so far."""
    snapshot = training_log
    return snapshot
|
||
|
||
def stop_training():
    """Terminate the training subprocess if one is still running.

    Returns:
        A status message saying whether a running process was stopped.
    """
    global training_process, training_log

    is_running = training_process is not None and training_process.poll() is None
    if not is_running:
        return "No training running."

    training_process.terminate()
    training_log += "\nTraining terminated by user."
    return "Training stopped."
|
||
|
||
# --- GUI Layout ---
|
||
|
||
# Custom CSS styles
|
||
custom_css = """
|
||
/* 整体主题样式 */
|
||
.gradio-container {
|
||
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
|
||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||
}
|
||
|
||
/* 标题区域样式 - 扁平化设计 */
|
||
.title-section {
|
||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||
border-radius: 8px;
|
||
padding: 15px 25px;
|
||
margin-bottom: 15px;
|
||
border: none;
|
||
box-shadow: 0 2px 8px rgba(0,0,0,0.1);
|
||
}
|
||
|
||
.title-section h1 {
|
||
color: white;
|
||
text-shadow: none;
|
||
font-weight: 600;
|
||
margin: 0;
|
||
font-size: 28px;
|
||
line-height: 1.2;
|
||
}
|
||
|
||
.title-section h3 {
|
||
color: rgba(255, 255, 255, 0.9);
|
||
font-weight: 400;
|
||
margin-top: 5px;
|
||
font-size: 14px;
|
||
line-height: 1.3;
|
||
}
|
||
|
||
.title-section p {
|
||
color: rgba(255, 255, 255, 0.85);
|
||
font-size: 13px;
|
||
margin: 5px 0 0 0;
|
||
line-height: 1.3;
|
||
}
|
||
|
||
/* 标签页样式 */
|
||
.tabs {
|
||
background: white;
|
||
border-radius: 15px;
|
||
padding: 10px;
|
||
box-shadow: 0 4px 20px rgba(0,0,0,0.08);
|
||
}
|
||
|
||
/* 按钮样式增强 */
|
||
.button-primary {
|
||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||
border: none;
|
||
border-radius: 12px;
|
||
padding: 12px 30px;
|
||
font-weight: 600;
|
||
color: white;
|
||
transition: all 0.3s ease;
|
||
box-shadow: 0 4px 15px rgba(102, 126, 234, 0.3);
|
||
}
|
||
|
||
.button-primary:hover {
|
||
transform: translateY(-2px);
|
||
box-shadow: 0 6px 25px rgba(102, 126, 234, 0.4);
|
||
}
|
||
|
||
.button-stop {
|
||
background: linear-gradient(135deg, #fa709a 0%, #fee140 100%);
|
||
border: none;
|
||
border-radius: 12px;
|
||
padding: 12px 30px;
|
||
font-weight: 600;
|
||
color: white;
|
||
transition: all 0.3s ease;
|
||
box-shadow: 0 4px 15px rgba(250, 112, 154, 0.3);
|
||
}
|
||
|
||
.button-stop:hover {
|
||
transform: translateY(-2px);
|
||
box-shadow: 0 6px 25px rgba(250, 112, 154, 0.4);
|
||
}
|
||
|
||
.button-refresh {
|
||
background: linear-gradient(135deg, #84fab0 0%, #8fd3f4 100%);
|
||
border: none;
|
||
border-radius: 10px;
|
||
padding: 8px 20px;
|
||
font-weight: 500;
|
||
color: white;
|
||
transition: all 0.3s ease;
|
||
box-shadow: 0 2px 10px rgba(132, 250, 176, 0.3);
|
||
}
|
||
|
||
.button-refresh:hover {
|
||
transform: translateY(-1px);
|
||
box-shadow: 0 4px 15px rgba(132, 250, 176, 0.4);
|
||
}
|
||
|
||
/* 表单区域样式 */
|
||
.form-section {
|
||
background: white;
|
||
border-radius: 20px;
|
||
padding: 30px;
|
||
margin: 15px 0;
|
||
box-shadow: 0 8px 30px rgba(0,0,0,0.08);
|
||
border: 1px solid rgba(0,0,0,0.05);
|
||
}
|
||
|
||
/* 输入框样式 */
|
||
.input-field {
|
||
border-radius: 12px;
|
||
border: 2px solid #e0e0e0;
|
||
padding: 12px 16px;
|
||
transition: all 0.3s ease;
|
||
background: #fafafa;
|
||
}
|
||
|
||
.input-field:focus {
|
||
border-color: #667eea;
|
||
box-shadow: 0 0 0 4px rgba(102, 126, 234, 0.1);
|
||
background: white;
|
||
}
|
||
|
||
/* 滑块样式 */
|
||
.slider {
|
||
-webkit-appearance: none;
|
||
appearance: none;
|
||
width: 100%;
|
||
height: 6px;
|
||
border-radius: 3px;
|
||
background: linear-gradient(90deg, #667eea, #764ba2);
|
||
outline: none;
|
||
opacity: 0.8;
|
||
transition: opacity 0.2s;
|
||
}
|
||
|
||
.slider:hover {
|
||
opacity: 1;
|
||
}
|
||
|
||
.slider::-webkit-slider-thumb {
|
||
-webkit-appearance: none;
|
||
appearance: none;
|
||
width: 18px;
|
||
height: 18px;
|
||
border-radius: 50%;
|
||
background: white;
|
||
cursor: pointer;
|
||
border: 3px solid #667eea;
|
||
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
|
||
}
|
||
|
||
.slider::-moz-range-thumb {
|
||
width: 18px;
|
||
height: 18px;
|
||
border-radius: 50%;
|
||
background: white;
|
||
cursor: pointer;
|
||
border: 3px solid #667eea;
|
||
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.3);
|
||
}
|
||
|
||
/* 折叠面板样式 */
|
||
.accordion {
|
||
border-radius: 12px;
|
||
border: 2px solid #e0e0e0;
|
||
overflow: hidden;
|
||
background: white;
|
||
}
|
||
|
||
.accordion-header {
|
||
background: linear-gradient(135deg, #f5f7fa 0%, #e3e7ed 100%);
|
||
padding: 15px 20px;
|
||
font-weight: 600;
|
||
color: #333;
|
||
}
|
||
|
||
/* 状态显示样式 */
|
||
.status-success {
|
||
background: linear-gradient(135deg, #84fab0 0%, #8fd3f4 100%);
|
||
color: white;
|
||
padding: 12px 20px;
|
||
border-radius: 12px;
|
||
font-weight: 500;
|
||
box-shadow: 0 4px 15px rgba(132, 250, 176, 0.3);
|
||
}
|
||
|
||
.status-error {
|
||
background: linear-gradient(135deg, #fa709a 0%, #fee140 100%);
|
||
color: white;
|
||
padding: 12px 20px;
|
||
border-radius: 12px;
|
||
font-weight: 500;
|
||
box-shadow: 0 4px 15px rgba(250, 112, 154, 0.3);
|
||
}
|
||
|
||
/* 语言切换按钮样式 - 扁平化 */
|
||
.lang-selector {
|
||
background: rgba(255, 255, 255, 0.25);
|
||
backdrop-filter: blur(10px);
|
||
border-radius: 8px;
|
||
padding: 8px 12px;
|
||
border: 1px solid rgba(255, 255, 255, 0.4);
|
||
}
|
||
|
||
.lang-selector label.gr-box {
|
||
color: white !important;
|
||
font-weight: 600;
|
||
margin-bottom: 8px !important;
|
||
}
|
||
|
||
/* 单选按钮组样式 */
|
||
.lang-selector fieldset,
|
||
.lang-selector .gr-form {
|
||
gap: 10px !important;
|
||
display: flex !important;
|
||
}
|
||
|
||
/* 单选按钮容器 - 扁平化 (未选中状态 - 较浅的深色) */
|
||
.lang-selector label.gr-radio-label {
|
||
background: linear-gradient(135deg, rgba(102, 126, 234, 0.6), rgba(118, 75, 162, 0.6)) !important;
|
||
border: 1px solid rgba(255, 255, 255, 0.5) !important;
|
||
border-radius: 6px !important;
|
||
padding: 8px 18px !important;
|
||
color: white !important;
|
||
font-weight: 500 !important;
|
||
transition: all 0.2s ease !important;
|
||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
|
||
cursor: pointer !important;
|
||
margin: 0 4px !important;
|
||
}
|
||
|
||
/* 选中的单选按钮 - 扁平化 (更深的深色背景) */
|
||
.lang-selector input[type="radio"]:checked + label,
|
||
.lang-selector label.gr-radio-label:has(input:checked) {
|
||
background: linear-gradient(135deg, #5568d3, #6b4c9a) !important;
|
||
color: white !important;
|
||
border: 1px solid rgba(255, 255, 255, 0.6) !important;
|
||
font-weight: 600 !important;
|
||
box-shadow: 0 3px 12px rgba(0, 0, 0, 0.2) !important;
|
||
transform: none !important;
|
||
}
|
||
|
||
/* 未选中的单选按钮悬停效果 - 扁平化 */
|
||
.lang-selector label.gr-radio-label:hover {
|
||
background: linear-gradient(135deg, rgba(102, 126, 234, 0.75), rgba(118, 75, 162, 0.75)) !important;
|
||
border-color: rgba(255, 255, 255, 0.7) !important;
|
||
transform: none !important;
|
||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.15) !important;
|
||
}
|
||
|
||
/* 隐藏原始的单选按钮圆点 */
|
||
.lang-selector input[type="radio"] {
|
||
opacity: 0;
|
||
position: absolute;
|
||
}
|
||
|
||
/* Gradio Radio 特定样式 - 扁平化 */
|
||
.lang-selector .wrap {
|
||
gap: 8px !important;
|
||
}
|
||
|
||
.lang-selector .wrap > label {
|
||
background: linear-gradient(135deg, rgba(102, 126, 234, 0.6), rgba(118, 75, 162, 0.6)) !important;
|
||
border: 1px solid rgba(255, 255, 255, 0.5) !important;
|
||
border-radius: 6px !important;
|
||
padding: 8px 18px !important;
|
||
color: white !important;
|
||
font-weight: 500 !important;
|
||
transition: all 0.2s ease !important;
|
||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1) !important;
|
||
}
|
||
|
||
.lang-selector .wrap > label.selected {
|
||
background: linear-gradient(135deg, #5568d3, #6b4c9a) !important;
|
||
color: white !important;
|
||
border: 1px solid rgba(255, 255, 255, 0.6) !important;
|
||
font-weight: 600 !important;
|
||
box-shadow: 0 3px 12px rgba(0, 0, 0, 0.2) !important;
|
||
}
|
||
|
||
/* 标签样式优化 */
|
||
label {
|
||
color: #333;
|
||
font-weight: 500;
|
||
margin-bottom: 8px;
|
||
}
|
||
|
||
/* Markdown 标题样式 */
|
||
.markdown-text h4 {
|
||
color: #667eea;
|
||
font-weight: 600;
|
||
margin-top: 15px;
|
||
margin-bottom: 10px;
|
||
}
|
||
|
||
/* 参数组件间距优化 */
|
||
.form-section > div {
|
||
margin-bottom: 15px;
|
||
}
|
||
|
||
/* Slider 组件样式优化 */
|
||
.gr-slider {
|
||
padding: 10px 0;
|
||
}
|
||
|
||
/* Number 输入框优化 */
|
||
.gr-number {
|
||
max-width: 100%;
|
||
}
|
||
|
||
/* 按钮容器优化 */
|
||
.gr-button {
|
||
min-height: 45px;
|
||
font-size: 16px;
|
||
}
|
||
|
||
/* 三栏布局优化 */
|
||
#component-0 .gr-row {
|
||
gap: 20px;
|
||
}
|
||
|
||
/* 生成按钮特殊样式 */
|
||
.button-primary.gr-button-lg {
|
||
min-height: 55px;
|
||
font-size: 18px;
|
||
font-weight: 700;
|
||
margin-top: 20px;
|
||
margin-bottom: 10px;
|
||
}
|
||
|
||
/* 刷新按钮小尺寸 */
|
||
.button-refresh.gr-button-sm {
|
||
min-height: 38px;
|
||
font-size: 14px;
|
||
margin-top: 5px;
|
||
margin-bottom: 15px;
|
||
}
|
||
|
||
/* 信息提示文字样式 */
|
||
.gr-info {
|
||
font-size: 13px;
|
||
color: #666;
|
||
margin-top: 5px;
|
||
}
|
||
|
||
/* 区域标题样式优化 */
|
||
.form-section h4 {
|
||
color: #667eea;
|
||
font-weight: 600;
|
||
margin-top: 0;
|
||
margin-bottom: 15px;
|
||
padding-bottom: 10px;
|
||
border-bottom: 2px solid #f0f0f0;
|
||
}
|
||
|
||
.form-section strong {
|
||
color: #667eea;
|
||
font-size: 15px;
|
||
display: block;
|
||
margin: 15px 0 10px 0;
|
||
}
|
||
"""
|
||
|
||
with gr.Blocks(
|
||
title="VoxCPM LoRA WebUI",
|
||
theme=gr.themes.Soft(),
|
||
css=custom_css
|
||
) as app:
|
||
|
||
# State for language
|
||
lang_state = gr.State("zh") # Default to Chinese
|
||
|
||
# 标题区域
|
||
with gr.Row(elem_classes="title-section"):
|
||
with gr.Column(scale=3):
|
||
title_md = gr.Markdown("""
|
||
# 🎵 VoxCPM LoRA WebUI
|
||
### 强大的语音合成和 LoRA 微调工具
|
||
|
||
支持语音克隆、LoRA 模型训练和推理的完整解决方案
|
||
""")
|
||
with gr.Column(scale=1):
|
||
lang_btn = gr.Radio(
|
||
choices=["en", "zh"],
|
||
value="zh",
|
||
label="🌐 Language / 语言",
|
||
elem_classes="lang-selector"
|
||
)
|
||
|
||
with gr.Tabs(elem_classes="tabs") as tabs:
|
||
# === Training Tab ===
|
||
with gr.Tab("🚀 训练 (Training)") as tab_train:
|
||
gr.Markdown("""
|
||
### 🎯 模型训练设置
|
||
配置你的 LoRA 微调训练参数
|
||
""")
|
||
|
||
with gr.Row():
|
||
with gr.Column(scale=2, elem_classes="form-section"):
|
||
gr.Markdown("#### 📁 基础配置")
|
||
|
||
train_pretrained_path = gr.Textbox(
|
||
label="📂 预训练模型路径",
|
||
value=default_pretrained_path,
|
||
elem_classes="input-field"
|
||
)
|
||
train_manifest = gr.Textbox(
|
||
label="📋 训练数据清单 (jsonl)",
|
||
value="examples/train_data_example.jsonl",
|
||
elem_classes="input-field"
|
||
)
|
||
val_manifest = gr.Textbox(
|
||
label="📊 验证数据清单 (可选)",
|
||
value="",
|
||
elem_classes="input-field"
|
||
)
|
||
|
||
gr.Markdown("#### ⚙️ 训练参数")
|
||
|
||
with gr.Row():
|
||
lr = gr.Number(
|
||
label="📈 学习率 (Learning Rate)",
|
||
value=1e-4,
|
||
elem_classes="input-field"
|
||
)
|
||
num_iters = gr.Number(
|
||
label="🔄 最大迭代次数",
|
||
value=2000,
|
||
precision=0,
|
||
elem_classes="input-field"
|
||
)
|
||
batch_size = gr.Number(
|
||
label="📦 批次大小 (Batch Size)",
|
||
value=1,
|
||
precision=0,
|
||
elem_classes="input-field"
|
||
)
|
||
|
||
with gr.Row():
|
||
lora_rank = gr.Number(
|
||
label="🎯 LoRA Rank",
|
||
value=32,
|
||
precision=0,
|
||
elem_classes="input-field"
|
||
)
|
||
lora_alpha = gr.Number(
|
||
label="⚖️ LoRA Alpha",
|
||
value=16,
|
||
precision=0,
|
||
elem_classes="input-field"
|
||
)
|
||
save_interval = gr.Number(
|
||
label="💾 保存间隔 (Steps)",
|
||
value=1000,
|
||
precision=0,
|
||
elem_classes="input-field"
|
||
)
|
||
|
||
output_name = gr.Textbox(
|
||
label="📁 输出目录名称 (可选,若存在则继续训练)",
|
||
value="",
|
||
elem_classes="input-field"
|
||
)
|
||
|
||
with gr.Row():
|
||
start_btn = gr.Button(
|
||
"▶️ 开始训练",
|
||
variant="primary",
|
||
elem_classes="button-primary"
|
||
)
|
||
stop_btn = gr.Button(
|
||
"⏹️ 停止训练",
|
||
variant="stop",
|
||
elem_classes="button-stop"
|
||
)
|
||
|
||
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
|
||
with gr.Row():
|
||
grad_accum_steps = gr.Number(label="梯度累积 (grad_accum_steps)", value=1, precision=0)
|
||
num_workers = gr.Number(label="数据加载线程 (num_workers)", value=2, precision=0)
|
||
log_interval = gr.Number(label="日志间隔 (log_interval)", value=10, precision=0)
|
||
with gr.Row():
|
||
valid_interval = gr.Number(label="验证间隔 (valid_interval)", value=1000, precision=0)
|
||
weight_decay = gr.Number(label="权重衰减 (weight_decay)", value=0.01)
|
||
warmup_steps = gr.Number(label="warmup_steps", value=100, precision=0)
|
||
with gr.Row():
|
||
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
|
||
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
|
||
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
|
||
with gr.Row():
|
||
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
|
||
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
|
||
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
|
||
dropout = gr.Number(label="LoRA Dropout", value=0.0)
|
||
|
||
gr.Markdown("#### 分发选项 (Distribution)")
|
||
with gr.Row():
|
||
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5")
|
||
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
||
|
||
with gr.Column(scale=2, elem_classes="form-section"):
|
||
gr.Markdown("#### 📊 训练日志")
|
||
logs_out = gr.TextArea(
|
||
label="",
|
||
lines=20,
|
||
max_lines=30,
|
||
interactive=False,
|
||
elem_classes="input-field",
|
||
show_label=False
|
||
)
|
||
|
||
start_btn.click(
|
||
start_training,
|
||
inputs=[
|
||
train_pretrained_path, train_manifest, val_manifest,
|
||
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
|
||
output_name,
|
||
# advanced
|
||
grad_accum_steps, num_workers, log_interval, valid_interval,
|
||
weight_decay, warmup_steps, max_steps, sample_rate,
|
||
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
|
||
# distribution
|
||
hf_model_id, distribute
|
||
],
|
||
outputs=[logs_out] # Initial message
|
||
)
|
||
stop_btn.click(stop_training, outputs=[logs_out])
|
||
|
||
# Log refresher
|
||
timer = gr.Timer(1)
|
||
timer.tick(get_training_log, outputs=logs_out)
|
||
|
||
# === Inference Tab ===
|
||
with gr.Tab("🎵 推理 (Inference)") as tab_infer:
|
||
gr.Markdown("""
|
||
### 🎤 语音合成
|
||
使用训练好的 LoRA 模型生成语音,支持 LoRA 微调和声音克隆
|
||
""")
|
||
|
||
with gr.Row():
|
||
# 左栏:输入配置 (35%)
|
||
with gr.Column(scale=35, elem_classes="form-section"):
|
||
gr.Markdown("#### 📝 输入配置")
|
||
|
||
infer_text = gr.TextArea(
|
||
label="💬 合成文本",
|
||
value="Hello, this is a test of the VoxCPM LoRA model.",
|
||
elem_classes="input-field",
|
||
lines=4,
|
||
placeholder="输入要合成的文本内容..."
|
||
)
|
||
|
||
gr.Markdown("**🎭 声音克隆(可选)**")
|
||
|
||
prompt_wav = gr.Audio(
|
||
label="🎵 参考音频",
|
||
type="filepath",
|
||
elem_classes="input-field"
|
||
)
|
||
|
||
prompt_text = gr.Textbox(
|
||
label="📝 参考文本(可选)",
|
||
elem_classes="input-field",
|
||
placeholder="如不填写,将自动识别参考音频内容"
|
||
)
|
||
|
||
# 中栏:模型选择和参数配置 (35%)
|
||
with gr.Column(scale=35, elem_classes="form-section"):
|
||
gr.Markdown("#### 🤖 模型选择")
|
||
|
||
lora_select = gr.Dropdown(
|
||
label="🎯 LoRA 模型",
|
||
choices=["None"] + scan_lora_checkpoints(),
|
||
value="None",
|
||
interactive=True,
|
||
elem_classes="input-field",
|
||
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型"
|
||
)
|
||
|
||
refresh_lora_btn = gr.Button(
|
||
"🔄 刷新模型列表",
|
||
elem_classes="button-refresh",
|
||
size="sm"
|
||
)
|
||
|
||
gr.Markdown("#### ⚙️ 生成参数")
|
||
|
||
cfg_scale = gr.Slider(
|
||
label="🎛️ CFG Scale",
|
||
minimum=1.0,
|
||
maximum=5.0,
|
||
value=2.0,
|
||
step=0.1,
|
||
info="引导系数,值越大越贴近提示"
|
||
)
|
||
|
||
steps = gr.Slider(
|
||
label="🔢 推理步数",
|
||
minimum=1,
|
||
maximum=50,
|
||
value=10,
|
||
step=1,
|
||
info="生成质量与步数成正比,但耗时更长"
|
||
)
|
||
|
||
seed = gr.Number(
|
||
label="🎲 随机种子",
|
||
value=-1,
|
||
precision=0,
|
||
elem_classes="input-field",
|
||
info="-1 为随机,固定值可复现结果"
|
||
)
|
||
|
||
generate_btn = gr.Button(
|
||
"🎵 生成音频",
|
||
variant="primary",
|
||
elem_classes="button-primary",
|
||
size="lg"
|
||
)
|
||
|
||
# 右栏:生成结果 (30%)
|
||
with gr.Column(scale=30, elem_classes="form-section"):
|
||
gr.Markdown("#### 🎧 生成结果")
|
||
|
||
audio_out = gr.Audio(
|
||
label="",
|
||
elem_classes="input-field",
|
||
show_label=False
|
||
)
|
||
|
||
gr.Markdown("#### 📋 状态信息")
|
||
|
||
status_out = gr.Textbox(
|
||
label="",
|
||
interactive=False,
|
||
elem_classes="input-field",
|
||
show_label=False,
|
||
lines=3,
|
||
placeholder="等待生成..."
|
||
)
|
||
|
||
def refresh_loras():
|
||
# 获取 LoRA checkpoints 及其 base model 信息
|
||
checkpoints_with_info = scan_lora_checkpoints(with_info=True)
|
||
choices = ["None"] + [ckpt[0] for ckpt in checkpoints_with_info]
|
||
|
||
# 输出调试信息
|
||
print(f"刷新 LoRA 列表: 找到 {len(checkpoints_with_info)} 个检查点")
|
||
for ckpt_path, base_model in checkpoints_with_info:
|
||
if base_model:
|
||
print(f" - {ckpt_path} (Base Model: {base_model})")
|
||
else:
|
||
print(f" - {ckpt_path}")
|
||
|
||
return gr.update(choices=choices, value="None")
|
||
|
||
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
|
||
|
||
# Auto-recognize audio when uploaded
|
||
prompt_wav.change(
|
||
fn=recognize_audio,
|
||
inputs=[prompt_wav],
|
||
outputs=[prompt_text]
|
||
)
|
||
|
||
generate_btn.click(
|
||
run_inference,
|
||
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed],
|
||
outputs=[audio_out, status_out]
|
||
)
|
||
|
||
# --- Language Switching Logic ---
|
||
def change_language(lang):
    """Return the component updates that switch the UI text to ``lang``.

    Args:
        lang: Language code used as a key into ``LANG_DICT``. ``"zh"``
            selects the Chinese labels for the advanced options; every
            other value uses the English ones.

    Returns:
        A tuple of 42 ``gr.update`` results. The order must stay in sync
        with the ``outputs`` list of the event handler that calls this
        function.

    Raises:
        KeyError: If ``lang`` is not a key of ``LANG_DICT``.
    """
    strings = LANG_DICT[lang]

    # Labels for the advanced and distribution options live here rather
    # than in LANG_DICT.
    zh_labels = {
        'grad_accum_steps': "梯度累积 (grad_accum_steps)",
        'num_workers': "数据加载线程 (num_workers)",
        'log_interval': "日志间隔 (log_interval)",
        'valid_interval': "验证间隔 (valid_interval)",
        'weight_decay': "权重衰减 (weight_decay)",
        'warmup_steps': "warmup_steps",
        'max_steps': "最大步数 (max_steps)",
        'sample_rate': "采样率 (sample_rate)",
        'enable_lm': "启用 LoRA LM (enable_lm)",
        'enable_dit': "启用 LoRA DIT (enable_dit)",
        'enable_proj': "启用投影 (enable_proj)",
        'dropout': "LoRA Dropout",
        'tensorboard_path': "Tensorboard 路径 (可选)",
        'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
        'distribute': "分发模式 (distribute)",
    }
    en_labels = {
        'grad_accum_steps': "Grad Accum Steps",
        'num_workers': "Num Workers",
        'log_interval': "Log Interval",
        'valid_interval': "Valid Interval",
        'weight_decay': "Weight Decay",
        'warmup_steps': "Warmup Steps",
        'max_steps': "Max Steps",
        'sample_rate': "Sample Rate",
        'enable_lm': "Enable LoRA LM",
        'enable_dit': "Enable LoRA DIT",
        'enable_proj': "Enable Projection",
        'dropout': "LoRA Dropout",
        'tensorboard_path': "Tensorboard Path (Optional)",
        'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
        'distribute': "Distribute Mode",
    }
    adv = zh_labels if lang == "zh" else en_labels

    def relabel(text):
        # Update for components whose visible text is their label.
        return gr.update(label=text)

    def revalue(text):
        # Update for components whose visible text is their value
        # (Markdown title, buttons).
        return gr.update(value=text)

    # Key groups below are listed in output order; do not reorder them
    # without changing the outputs list of the caller accordingly.
    training_keys = (
        'tab_train', 'tab_infer',
        'pretrained_path', 'train_manifest', 'val_manifest',
        'lr', 'max_iters', 'batch_size',
        'lora_rank', 'lora_alpha', 'save_interval',
        'output_name',
    )
    # Advanced options followed by the two distribution options.
    advanced_keys = (
        'grad_accum_steps', 'num_workers', 'log_interval', 'valid_interval',
        'weight_decay', 'warmup_steps', 'max_steps', 'sample_rate',
        'enable_lm', 'enable_dit', 'enable_proj', 'dropout',
        'tensorboard_path',
        'hf_model_id', 'distribute',
    )

    updates = [revalue(f"# {strings['title']}")]
    updates.extend(relabel(strings[key]) for key in training_keys)
    updates.append(revalue(strings['start_train']))
    updates.append(revalue(strings['stop_train']))
    updates.append(relabel(strings['train_logs']))

    updates.extend(relabel(adv[key]) for key in advanced_keys)

    # Inference section
    updates.extend(
        relabel(strings[key])
        for key in ('text_to_synth', 'ref_audio', 'ref_text', 'select_lora')
    )
    updates.append(revalue(strings['refresh']))
    updates.extend(
        relabel(strings[key]) for key in ('cfg_scale', 'infer_steps', 'seed')
    )
    updates.append(revalue(strings['gen_audio']))
    updates.extend(relabel(strings[key]) for key in ('gen_output', 'status'))

    return tuple(updates)
|
||
|
||
# Re-label the whole UI when the language selector changes. The outputs
# list must keep the same order as the tuple returned by change_language.
lang_btn.change(
    fn=change_language,
    inputs=[lang_btn],
    outputs=[
        # title and tabs
        title_md,
        tab_train,
        tab_infer,
        # basic training options
        train_pretrained_path,
        train_manifest,
        val_manifest,
        lr,
        num_iters,
        batch_size,
        lora_rank,
        lora_alpha,
        save_interval,
        output_name,
        start_btn,
        stop_btn,
        logs_out,
        # advanced outputs
        grad_accum_steps,
        num_workers,
        log_interval,
        valid_interval,
        weight_decay,
        warmup_steps,
        max_steps,
        sample_rate,
        enable_lm,
        enable_dit,
        enable_proj,
        dropout,
        tensorboard_path,
        # distribution outputs
        hf_model_id,
        distribute,
        # inference section
        infer_text,
        prompt_wav,
        prompt_text,
        lora_select,
        refresh_lora_btn,
        cfg_scale,
        steps,
        seed,
        generate_btn,
        audio_out,
        status_out,
    ],
)
|
||
|
||
if __name__ == "__main__":
    # Ensure the "lora" directory exists (relative to the current working
    # directory) before the UI starts.
    Path("lora").mkdir(parents=True, exist_ok=True)

    # Enable request queuing, then serve on all interfaces at port 7860.
    queued_app = app.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860)