mirror of https://github.com/OpenBMB/VoxCPM
synced 2025-12-12 19:58:12 +00:00

Compare commits (8 commits)
- dc6b6d1d1c
- cef6aefb3d
- 1a46c5d1ad
- 5257ec3dc5
- bdd516b579
- 11568f0776
- e5bcb735f0
- 1fa9e2ca02
@@ -267,6 +267,8 @@ This project is developed by the following institutions:

- <img src="assets/thuhcsi_logo.png" width="28px"> [THUHCSI](https://github.com/thuhcsi)

## ⭐ Star History

[![Star History Chart](https://api.star-history.com/svg?repos=OpenBMB/VoxCPM&type=Date)](https://star-history.com/#OpenBMB/VoxCPM&Date)

## 📚 Citation
app.py (9 changes)
@@ -194,10 +194,6 @@ def create_demo_interface(demo: VoxCPMDemo):
        **调低**:合成速度更快。
        - **Higher** for better synthesis quality.
        **调高**:合成质量更佳。

        ### Long Text (e.g., >5 min speech)|长文本 (如 >5分钟的合成语音)
        While VoxCPM can handle long texts directly, we recommend using empty lines to break very long content into paragraphs; the model will then synthesize each paragraph individually.
        虽然 VoxCPM 支持直接生成长文本,但如果目标文本过长,我们建议使用换行符将内容分段;模型将对每个段落分别合成。
        """)

        # Main controls
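The long-text guidance above matches the chunked `generate` implementation later in this diff. A minimal usage sketch, assuming the package's `VoxCPM.from_pretrained` loader; the model id, output sample rate, and the `soundfile` output step are illustrative assumptions, not confirmed by this diff:

```python
# Hedged sketch: newline-separated paragraphs are synthesized chunk by chunk
# and concatenated into one waveform (per the behavior described above).
import soundfile as sf
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")  # illustrative model id
script = (
    "First paragraph of a long script.\n"
    "Second paragraph, synthesized as its own chunk.\n"
    "Third paragraph, appended to the final audio."
)
wav = model.generate(text=script)  # 1D float32 numpy array on CPU
sf.write("out.wav", wav, 16000)    # sample rate is an assumption
```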
@@ -206,7 +202,7 @@ def create_demo_interface(demo: VoxCPMDemo):
            prompt_wav = gr.Audio(
                sources=["upload", 'microphone'],
                type="filepath",
-               label="Prompt Speech",
+               label="Prompt Speech (Optional, or let VoxCPM improvise)",
                value="./examples/example.wav",
            )
            DoDenoisePromptAudio = gr.Checkbox(
@@ -244,14 +240,13 @@ def create_demo_interface(demo: VoxCPMDemo):
            text = gr.Textbox(
                value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
                label="Target Text",
                info="Default processing splits text on \\n into paragraphs; each is synthesized as a chunk and then concatenated into the final audio."
            )
            with gr.Row():
                DoNormalizeText = gr.Checkbox(
                    value=False,
                    label="Text Normalization",
                    elem_id="chk_normalize",
-                   info="We use WeTextPorcessing library to normalize the input text."
+                   info="We use wetext library to normalize the input text."
                )
            audio_output = gr.Audio(label="Output Audio")
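The relabeled prompt input means reference audio is now optional. A hedged sketch of both call modes against the `generate` API shown later in this diff, reusing `model` from the earlier sketch (`normalize` mirrors the Text Normalization checkbox and is an assumed keyword):

```python
# Voice cloning: prompt audio and its transcript must be supplied together.
wav = model.generate(
    text="Target text to speak.",
    prompt_wav_path="./examples/example.wav",
    prompt_text="Transcript of the prompt audio.",
)

# Improvised voice: omit both prompt fields and VoxCPM picks a voice itself.
wav = model.generate(text="Target text to speak.", normalize=False)
```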
@@ -36,7 +36,7 @@ dependencies = [
    "addict",
    "wetext",
    "modelscope>=1.22.0",
-   "datasets>=2,<4",
+   "datasets>=3,<4",
    "huggingface-hub",
    "pydantic",
    "tqdm",
@@ -1,6 +1,7 @@
import torch
import torchaudio
import os
import re
import tempfile
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel
@@ -120,10 +121,19 @@ class VoxCPM:
        Returns:
            numpy.ndarray: 1D waveform array (float32) on CPU.
        """
        texts = text.split("\n")
        texts = [t.strip() for t in texts if t.strip()]
        final_wav = []
        temp_prompt_wav_path = None
        if not text.strip() or not isinstance(text, str):
            raise ValueError("target text must be a non-empty string")

        if prompt_wav_path is not None:
            if not os.path.exists(prompt_wav_path):
                raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")

        if (prompt_wav_path is None) != (prompt_text is None):
            raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")

        text = text.replace("\n", " ")
        text = re.sub(r'\s+', ' ', text)
        temp_prompt_wav_path = None

        try:
            if prompt_wav_path is not None and prompt_text is not None:
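A short sketch of how the validation added in this hunk surfaces to callers (reusing `model` from the earlier sketch):

```python
# The pairing rule above: prompt_wav_path and prompt_text go together.
try:
    model.generate(text="Hello.", prompt_wav_path="./examples/example.wav")
except ValueError as err:
    print(err)  # prompt_wav_path and prompt_text must both be provided or both be None
```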
@@ -139,35 +149,25 @@ class VoxCPM:
            else:
                fixed_prompt_cache = None  # will be built from the first inference

            for sub_text in texts:
                if sub_text.strip() == "":
                    continue
                print("sub_text:", sub_text)
                if normalize:
                    if self.text_normalizer is None:
                        from .utils.text_normalize import TextNormalizer
                        self.text_normalizer = TextNormalizer()
                    sub_text = self.text_normalizer.normalize(sub_text)
                wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
                    target_text=sub_text,
                    prompt_cache=fixed_prompt_cache,
                    min_len=2,
                    max_len=max_length,
                    inference_timesteps=inference_timesteps,
                    cfg_value=cfg_value,
                    retry_badcase=retry_badcase,
                    retry_badcase_max_times=retry_badcase_max_times,
                    retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
                )
                if fixed_prompt_cache is None:
                    fixed_prompt_cache = self.tts_model.merge_prompt_cache(
                        original_cache=None,
                        new_text_token=target_text_token,
                        new_audio_feat=generated_audio_feat
                    )
                final_wav.append(wav)
            if normalize:
                if self.text_normalizer is None:
                    from .utils.text_normalize import TextNormalizer
                    self.text_normalizer = TextNormalizer()
                text = self.text_normalizer.normalize(text)

            wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
                target_text=text,
                prompt_cache=fixed_prompt_cache,
                min_len=2,
                max_len=max_length,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                retry_badcase=retry_badcase,
                retry_badcase_max_times=retry_badcase_max_times,
                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
            )

            return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
            return wav.squeeze(0).cpu().numpy()

        finally:
            if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
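The core of this hunk is a voice-locking pattern: when the user supplies no prompt, the first chunk's own output is merged into a prompt cache (`merge_prompt_cache` with `original_cache=None`) and reused, so every later chunk keeps the same improvised voice. A distilled sketch with method names as in the diff; `tts_model` stands for `self.tts_model`, and the numeric arguments are illustrative:

```python
# Distilled from the loop above: lock in the improvised voice of chunk 1.
fixed_prompt_cache = None
chunks = []
for sub_text in ["Paragraph one.", "Paragraph two."]:
    wav, text_tok, audio_feat = tts_model.generate_with_prompt_cache(
        target_text=sub_text,
        prompt_cache=fixed_prompt_cache,  # None only for the first chunk
        min_len=2, max_len=4096,          # illustrative limits
        inference_timesteps=10, cfg_value=2.0,
    )
    if fixed_prompt_cache is None:
        fixed_prompt_cache = tts_model.merge_prompt_cache(
            original_cache=None,
            new_text_token=text_tok,
            new_audio_feat=audio_feat,
        )
    chunks.append(wav)
```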
@@ -151,12 +151,17 @@ class VoxCPMModel(nn.Module):
        try:
            if self.device != "cuda":
                raise ValueError("VoxCPMModel can only be optimized on CUDA device")
+           try:
+               import triton
+           except:
+               raise ValueError("triton is not installed")
            self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
            self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
-       except:
-           print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
+       except Exception as e:
+           print(f"Error: {e}")
+           print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
            self.base_lm.forward_step = self.base_lm.forward_step
            self.residual_lm.forward_step = self.residual_lm.forward_step
            self.feat_encoder_step = self.feat_encoder
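The hunk above follows a common guarded-compilation pattern: attempt `torch.compile` on the hot step functions, and fall back to eager execution when CUDA or triton is unavailable. A self-contained sketch of the same pattern (the helper name is hypothetical):

```python
# Guarded torch.compile: return a compiled fn on CUDA+triton, else the original.
import torch

def maybe_compile(fn, device: str):
    if device != "cuda":
        return fn
    try:
        import triton  # noqa: F401  # torch.compile's GPU backend needs triton
        return torch.compile(fn, mode="reduce-overhead", fullgraph=True)
    except Exception as e:
        print(f"Warning: torch.compile unavailable ({e}); keeping eager fn")
        return fn
```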
@@ -278,8 +283,11 @@ class VoxCPMModel(nn.Module):
                else:
                    break
            else:
                break
-       return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
                break

+       decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+       decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
+       return decode_audio
    @torch.inference_mode()
    def build_prompt_cache(
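The 640-sample trim introduced above (and applied again in a later hunk) cuts boundary artifacts from both ends of the decoded waveform. A runnable toy illustration of the slicing:

```python
# Drop 640 samples from each end of a [batch, samples] waveform.
import torch

decode_audio = torch.randn(1, 16000)  # stand-in for audio_vae.decode output
trimmed = decode_audio[..., 640:-640]
print(trimmed.shape)                  # torch.Size([1, 14720])
```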
@@ -317,7 +325,7 @@ class VoxCPMModel(nn.Module):
            audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))

        # extract audio features
-       audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
+       audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()

        audio_feat = audio_feat.view(
            self.audio_vae.latent_dim,
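The change from `.cuda()` to `.to(self.device)` makes prompt-cache building portable to non-CUDA hosts. The general pattern, as a runnable sketch:

```python
# Portable device placement: .cuda() raises on CPU-only machines, .to() does not.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
audio = torch.randn(1, 16000)
audio = audio.to(device)  # works on cuda, cpu, mps, ...
```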
@@ -463,7 +471,8 @@ class VoxCPMModel(nn.Module):
            else:
                break
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+       decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio

        return (
            decode_audio,
            target_text_token,
@@ -575,7 +584,6 @@ class VoxCPMModel(nn.Module):
        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d

        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
-       feat_pred = feat_pred[..., 1:-1]  # trick: remove the first and last token
        return feat_pred, pred_feat_seq.squeeze(0).cpu()

    @classmethod
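The `rearrange` call above flattens per-step patches into one feature-frame axis; the removed `[..., 1:-1]` line previously dropped the first and last feature frames. A runnable toy example of the reshape:

```python
# einops.rearrange "b t p d -> b d (t p)": merge time and patch axes.
import torch
from einops import rearrange

B, T, P, D = 1, 4, 2, 64                 # illustrative sizes; P is patch_size
pred_feat_seq = torch.randn(B, T, P, D)
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=P)
print(feat_pred.shape)                   # torch.Size([1, 64, 8])
```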