VoxCPM-use/app.py
import os
import numpy as np
import torch
import gradio as gr
import spaces
from typing import Optional, Tuple
from funasr import AutoModel
from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
import voxcpm


class VoxCPMDemo:
    def __init__(self) -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Running on device: {self.device}")
        # ASR model for prompt text recognition
        self.asr_model_id = "iic/SenseVoiceSmall"
        self.asr_model: Optional[AutoModel] = AutoModel(
            model=self.asr_model_id,
            disable_update=True,
            log_level='DEBUG',
            device="cuda:0" if self.device == "cuda" else "cpu",
        )
        # TTS model (lazy init)
        self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
        self.default_local_model_dir = "./models/VoxCPM1.5"

    # ---------- Model helpers ----------
    def _resolve_model_dir(self) -> str:
        """
        Resolve the model directory:
        1) Use the local checkpoint directory if it exists.
        2) If the HF_REPO_ID env var is set, download into models/{repo}.
        3) Fall back to 'models'.
        """
        if os.path.isdir(self.default_local_model_dir):
            return self.default_local_model_dir
        repo_id = os.environ.get("HF_REPO_ID", "").strip()
        if len(repo_id) > 0:
            target_dir = os.path.join("models", repo_id.replace("/", "__"))
            # Download only if the directory is missing or lacks config.json
            if not os.path.isdir(target_dir) or not os.path.exists(os.path.join(target_dir, "config.json")):
                try:
                    from huggingface_hub import snapshot_download  # type: ignore
                    os.makedirs(target_dir, exist_ok=True)
                    print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
                    snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
                except Exception as e:
                    print(f"Warning: HF download failed: {e}. Falling back to 'models'.")
                    return "models"
            return target_dir
        return "models"

    def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
        if self.voxcpm_model is not None:
            return self.voxcpm_model
        print("Model not loaded, initializing...")
        model_dir = self._resolve_model_dir()
        print(f"Using model dir: {model_dir}")
        self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
        print("Model loaded successfully.")
        return self.voxcpm_model

    # ---------- Functional endpoints ----------
    def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
        if prompt_wav is None:
            return ""
        res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
        # Keep only the transcript after the last '|>' to drop the leading tag tokens
        text = res[0]["text"].split('|>')[-1]
        return text

    def generate_tts_audio(
        self,
        text_input: str,
        prompt_wav_path_input: Optional[str] = None,
        prompt_text_input: Optional[str] = None,
        cfg_value_input: float = 2.0,
        inference_timesteps_input: int = 10,
        do_normalize: bool = True,
        denoise: bool = True,
    ) -> Tuple[int, np.ndarray]:
        """
        Generate speech from text using VoxCPM; optional reference audio guides the voice style.
        Returns (sample_rate, waveform_numpy).
        """
        current_model = self.get_or_load_voxcpm()
        text = (text_input or "").strip()
        if len(text) == 0:
            raise ValueError("Please input text to synthesize.")
        prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
        prompt_text = prompt_text_input if prompt_text_input else None
        print(f"Generating audio for text: '{text[:60]}...'")
        wav = current_model.generate(
            text=text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            cfg_value=float(cfg_value_input),
            inference_timesteps=int(inference_timesteps_input),
            normalize=do_normalize,
            denoise=denoise,
        )
        return (current_model.tts_model.sample_rate, wav)
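
    # Example call shape (sketch; the argument values here are illustrative only):
    #   demo = VoxCPMDemo()
    #   sr, wav = demo.generate_tts_audio(
    #       "Hello from VoxCPM.",
    #       prompt_wav_path_input="./examples/example.wav",
    #       cfg_value_input=2.0,
    #       inference_timesteps_input=10,
    #   )
    #   # sr is the integer sample rate, wav a NumPy waveform ready for gr.Audio.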


# ---------- UI Builders ----------
def create_demo_interface(demo: VoxCPMDemo):
    """Build the Gradio UI for the VoxCPM demo."""
    # Static assets (logo path)
    gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
    with gr.Blocks(
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="gray",
            neutral_hue="slate",
            font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
        ),
        css="""
        .logo-container {
            text-align: center;
            margin: 0.5rem 0 1rem 0;
        }
        .logo-container img {
            height: 80px;
            width: auto;
            max-width: 200px;
            display: inline-block;
        }
        /* Bold accordion labels */
        #acc_quick details > summary,
        #acc_tips details > summary {
            font-weight: 600 !important;
            font-size: 1.1em !important;
        }
        /* Bold labels for specific checkboxes */
        #chk_denoise label,
        #chk_denoise span,
        #chk_normalize label,
        #chk_normalize span {
            font-weight: 600;
        }
        """,
    ) as interface:
        # Header logo
        gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')

        # Quick Start
        with gr.Accordion("📋 Quick Start", open=False, elem_id="acc_quick"):
            gr.Markdown("""
            ### How to use
            1. **(Optional) Provide a reference voice** - Upload or record an audio clip to supply timbre, intonation, and emotion for the synthesized voice.
            2. **(Optional) Enter the reference text** - If reference audio is provided, enter its transcript (automatic recognition is supported).
            3. **Enter the target text** - Type the text you want the model to read aloud.
            4. **Generate speech** - Click the "Generate Speech" button to create the audio.
            """)

        # Pro Tips
        with gr.Accordion("💡 Pro Tips", open=False, elem_id="acc_tips"):
            gr.Markdown("""
            ### Reference audio denoising
            - **Enabled**: Removes background noise with the ZipEnhancer component, but limits the audio sample rate to 16 kHz, which caps cloning fidelity.
            - **Disabled**: Keeps all information in the original audio, including background ambience; voice cloning is supported at up to 44.1 kHz.
            ### Text normalization
            - **Enabled**: Uses the WeTextProcessing component to normalize common text patterns.
            - **Disabled**: Relies on VoxCPM's built-in text understanding, which for example supports phoneme input (Chinese pinyin such as {ni3}{hao3}, English CMUDict such as {HH AH0 L OW1}) and formula/symbol synthesis - give it a try!
            ### CFG value
            - **Lower it**: If the generated speech sounds unnatural or overly exaggerated, or long text inputs become unstable.
            - **Raise it**: To follow the prompt audio's style or the input text more closely, or if very short inputs become unstable.
            ### Inference timesteps
            - **Lower**: Faster synthesis.
            - **Higher**: Better synthesis quality.
            """)

        # Main controls
        with gr.Row():
            with gr.Column():
                prompt_wav = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Reference audio (optional; leave empty to let VoxCPM improvise)",
                    value="./examples/example.wav",
                )
                DoDenoisePromptAudio = gr.Checkbox(
                    value=False,
                    label="Reference audio enhancement",
                    elem_id="chk_denoise",
                    info="Denoise the reference audio with the ZipEnhancer model.",
                )
                with gr.Row():
                    prompt_text = gr.Textbox(
                        value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
                        label="Reference text",
                        placeholder="Enter the reference text. It is recognized automatically, and you can edit the result...",
                    )
                run_btn = gr.Button("Generate Speech", variant="primary")
            with gr.Column():
                cfg_value = gr.Slider(
                    minimum=1.0,
                    maximum=3.0,
                    value=2.0,
                    step=0.1,
                    label="CFG value (guidance scale)",
                    info="Higher values adhere more closely to the prompt; lower values allow more creativity.",
                )
                inference_timesteps = gr.Slider(
                    minimum=4,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Inference timesteps",
                    info="Number of inference timesteps for generation (higher may improve quality but is slower).",
                )
                with gr.Row():
                    text = gr.Textbox(
                        value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
                        label="Target text",
                    )
                with gr.Row():
                    DoNormalizeText = gr.Checkbox(
                        value=False,
                        label="Text normalization",
                        elem_id="chk_normalize",
                        info="Normalize the input text with the wetext library.",
                    )
                audio_output = gr.Audio(label="Output audio")

        # Wiring
        run_btn.click(
            fn=demo.generate_tts_audio,
            inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
            outputs=[audio_output],
            show_progress=True,
            api_name="generate",
        )
        prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])

    return interface
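

# Sketch of calling the named endpoint remotely, assuming the separate
# gradio_client package (not part of this app); positional args follow the
# `inputs` list wired to the button above:
#   from gradio_client import Client
#   client = Client("http://localhost:7860/")
#   result = client.predict(
#       "Text to speak", None, "", 2.0, 10, False, False,
#       api_name="/generate",
#   )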


def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
    demo = VoxCPMDemo()
    interface = create_demo_interface(demo)
    # Enabling the queue is recommended on Spaces for better throughput
    interface.queue(max_size=10, default_concurrency_limit=1).launch(
        server_name=server_name, server_port=server_port, show_error=show_error
    )


if __name__ == "__main__":
    run_demo()
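
# For example (sketch), to expose the demo on all network interfaces instead of
# localhost only:
#   run_demo(server_name="0.0.0.0", server_port=7860)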