Mirror of https://github.com/OpenBMB/VoxCPM (synced 2025-12-12 11:58:11 +00:00)

Compare commits (12 commits)
| SHA1 |
|---|
| d1bb6aaf41 |
| 2eb4d39719 |
| fbf8984d4e |
| 41752dc0fa |
| b0714adcaa |
| 89f4d917a0 |
| 5c5da0dbe6 |
| 961569e76d |
| 5f56d5ff5d |
| 169c17ddfd |
| 996c69a1a8 |
| f26a1ea2f7 |
.gitignore (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
launch.json
__pycache__
voxcpm.egg-info

README.md (52 changes)
@@ -1,7 +1,7 @@
## 🎙️ VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning

[](https://github.com/OpenBMB/VoxCPM/) [](https://huggingface.co/openbmb/VoxCPM-0.5B) [](https://modelscope.cn/models/OpenBMB/VoxCPM-0.5B) [](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [](https://openbmb.github.io/VoxCPM-demopage)
[](https://github.com/OpenBMB/VoxCPM/) [](https://arxiv.org/abs/2509.24650) [](https://huggingface.co/openbmb/VoxCPM-0.5B) [](https://modelscope.cn/models/OpenBMB/VoxCPM-0.5B) [](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [](https://openbmb.github.io/VoxCPM-demopage)

@@ -16,6 +16,7 @@
</div>

## News
* [2025.09.30] 🔥 🔥 🔥 We Release VoxCPM [Technical Report](https://arxiv.org/abs/2509.24650)!
* [2025.09.16] 🔥 🔥 🔥 We Open Source the VoxCPM-0.5B [weights](https://huggingface.co/openbmb/VoxCPM-0.5B)!
* [2025.09.16] 🎉 🎉 🎉 We Provide the [Gradio PlayGround](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) for VoxCPM-0.5B, try it now!

@@ -50,7 +51,7 @@ By default, when you first run the script, the model will be downloaded automati
- Download VoxCPM-0.5B
```
from huggingface_hub import snapshot_download
snapshot_download("openbmb/VoxCPM-0.5B",local_files_only=local_files_only)
snapshot_download("openbmb/VoxCPM-0.5B")
```
- Download ZipEnhancer and SenseVoice-Small. We use ZipEnhancer to enhance speech prompts and SenseVoice-Small for speech prompt ASR in the web demo.
```
@@ -62,10 +63,12 @@ By default, when you first run the script, the model will be downloaded automati
### 2. Basic Usage
```python
import soundfile as sf
import numpy as np
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")

# Non-streaming
wav = model.generate(
    text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
    prompt_wav_path=None, # optional: path to a prompt speech for voice cloning
@@ -81,6 +84,18 @@ wav = model.generate(

sf.write("output.wav", wav, 16000)
print("saved: output.wav")

# Streaming
chunks = []
for chunk in model.generate_streaming(
    text = "Streaming text to speech is easy with VoxCPM!",
    # supports same args as above
):
    chunks.append(chunk)
wav = np.concatenate(chunks)

sf.write("output_streaming.wav", wav, 16000)
print("saved: output_streaming.wav")
```
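For long inputs it may be preferable not to buffer every chunk in memory. A minimal sketch (not from the README) that writes streamed chunks to disk as they arrive, using soundfile's incremental writer and the same 16 kHz rate as above:

```python
# Illustrative sketch: stream chunks straight into a wav file instead of
# concatenating them at the end.
import soundfile as sf
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")

with sf.SoundFile("output_streaming.wav", mode="w", samplerate=16000,
                  channels=1, subtype="PCM_16") as f:
    for chunk in model.generate_streaming(
        text="Streaming text to speech is easy with VoxCPM!",
    ):
        f.write(chunk)  # each chunk is a 1-D float32 numpy array
```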

### 3. CLI Usage
@@ -98,6 +113,13 @@ voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, desi
    --output out.wav \
    --denoise

# (Optional) Voice cloning (reference audio + transcript file)
voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
    --prompt-audio path/to/voice.wav \
    --prompt-file "/path/to/text-file" \
    --output out.wav \
    --denoise

# 3) Batch processing (one text per line)
voxcpm --input examples/input.txt --output-dir outs
# (optional) Batch + cloning
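For reference, a hedged sketch of what the batch input can look like; the file name `examples/input.txt` and the final command come from the hunk above, while the sentences themselves are purely illustrative:

```bash
# Illustrative only: build a two-line batch file, then synthesize one wav per line.
cat > examples/input.txt <<'EOF'
This is the first sentence to synthesize.
This is the second sentence to synthesize.
EOF
voxcpm --input examples/input.txt --output-dir outs
```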

@@ -174,6 +196,19 @@ Happy creating! 🎉 Start with the default settings and tweak from there to sui
---

## 🌟 Community Projects

We're excited to see the VoxCPM community growing! Here are some amazing projects and features built by our community:

- **[ComfyUI-VoxCPM](https://github.com/wildminder/ComfyUI-VoxCPM)**
- **[ComfyUI-VoxCPMTTS](https://github.com/1038lab/ComfyUI-VoxCPMTTS)**
- **[WebUI-VoxCPM](https://github.com/rsxdalv/tts_webui_extension.vox_cpm)**
- **[PR: Streaming API Support (by AbrahamSanders)](https://github.com/OpenBMB/VoxCPM/pull/26)**

*Have you built something cool with VoxCPM? We'd love to feature it here! Please open an issue or pull request to add your project.*

## 📊 Performance Highlights

@@ -242,7 +277,7 @@ VoxCPM achieves competitive results on public zero-shot TTS benchmarks:

## 📝TO-DO List
Please stay tuned for updates!
- [ ] Release the VoxCPM technical report.
- [x] Release the VoxCPM technical report.
- [ ] Support higher sampling rate (next version).

@@ -273,16 +308,13 @@ This project is developed by the following institutions:

## 📚 Citation

The techical report is coming soon, please wait for the release 😊

If you find our model helpful, please consider citing our projects 📝 and starring us ⭐️!

```bib
@misc{voxcpm2025,
    author = {{Yixuan Zhou, Guoyang Zeng, Xin Liu, Xiang Li, Renjie Yu, Ziyang Wang, Runchuan Ye, Weiyue Sun, Jiancheng Gui, Kehan Li, Zhiyong Wu, Zhiyuan Liu}},
    title = {{VoxCPM}},
@article{voxcpm2025,
    title = {VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning},
    author = {Zhou, Yixuan and Zeng, Guoyang and Liu, Xin and Li, Xiang and Yu, Renjie and Wang, Ziyang and Ye, Runchuan and Sun, Weiyue and Gui, Jiancheng and Li, Kehan and Wu, Zhiyong and Liu, Zhiyuan},
    journal = {arXiv preprint arXiv:2509.24650},
    year = {2025},
    publish = {\url{https://github.com/OpenBMB/VoxCPM}},
    note = {GitHub repository}
}
```

pyproject.toml
@@ -20,12 +20,10 @@ classifiers = [
    "Intended Audience :: Developers",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
]
requires-python = ">=3.8"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.5.0",
    "torchaudio>=2.5.0",

@@ -78,7 +76,7 @@ version_scheme = "post-release"

[tool.black]
line-length = 120
target-version = ['py38']
target-version = ['py310']
include = '\.pyi?$'
extend-exclude = '''
/(

@@ -240,6 +240,7 @@ Examples:
    # Prompt audio (for voice cloning)
    parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
    parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
    parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio")
    parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")

    # Generation parameters

@@ -279,6 +280,12 @@ def main():

    # If prompt audio+text provided → voice cloning
    if args.prompt_audio or args.prompt_text:
        if not args.prompt_text and args.prompt_file:
            assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."

            with open(args.prompt_file, 'r', encoding='utf-8') as f:
                args.prompt_text = f.read()

        if not args.prompt_audio or not args.prompt_text:
            print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
            sys.exit(1)
@@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from typing import Generator
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel
|
||||
|
||||
@@ -11,6 +11,7 @@ class VoxCPM:
|
||||
voxcpm_model_path : str,
|
||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser : bool = True,
|
||||
optimize: bool = True,
|
||||
):
|
||||
"""Initialize VoxCPM TTS pipeline.
|
||||
|
||||
@@ -21,9 +22,10 @@ class VoxCPM:
|
||||
zipenhancer_model_path: ModelScope acoustic noise suppression model
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
"""
|
||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
|
||||
self.text_normalizer = None
|
||||
if enable_denoiser and zipenhancer_model_path is not None:
|
||||
from .zipenhancer import ZipEnhancer
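Taken together with the `**kwargs` forwarding added to `from_pretrained` below, the new `optimize` flag should also be reachable from the high-level API; a hedged sketch, not taken from the diff itself:

```python
# Assumed usage: disable torch.compile optimization for debugging via the new
# ``optimize`` flag, forwarded through from_pretrained's **kwargs to the
# VoxCPM constructor and on to VoxCPMModel.from_local.
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B", optimize=False)
wav = model.generate(text="Quick eager-mode smoke test.")
```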

@@ -43,6 +45,7 @@ class VoxCPM:
        zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
        cache_dir: str = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.

@@ -54,6 +57,8 @@ class VoxCPM:
            cache_dir: Custom cache directory for the snapshot.
            local_files_only: If True, only use local files and do not attempt
                to download.
        Kwargs:
            Additional keyword arguments passed to the ``VoxCPM`` constructor.

        Returns:
            VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to

@@ -82,9 +87,16 @@ class VoxCPM:
            voxcpm_model_path=local_path,
            zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
            enable_denoiser=load_denoiser,
            **kwargs,
        )

    def generate(self,
    def generate(self, *args, **kwargs) -> np.ndarray:
        return next(self._generate(*args, streaming=False, **kwargs))

    def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
        return self._generate(*args, streaming=True, **kwargs)

    def _generate(self,
        text : str,
        prompt_wav_path : str = None,
        prompt_text : str = None,
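The new `generate` / `generate_streaming` pair is a thin wrapper around a single generator. A self-contained sketch of the same pattern (illustrative code, not the repository's):

```python
# Minimal sketch of the wrapper pattern above: one generator does the work,
# the non-streaming wrapper consumes its single final yield, and the streaming
# wrapper hands the generator back to the caller.
from typing import Generator
import numpy as np

class _Sketch:
    def generate(self, *args, **kwargs) -> np.ndarray:
        return next(self._generate(*args, streaming=False, **kwargs))

    def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
        return self._generate(*args, streaming=True, **kwargs)

    def _generate(self, n_chunks: int = 3, streaming: bool = False):
        chunks = []
        for i in range(n_chunks):
            chunk = np.full(4, float(i), dtype=np.float32)  # stand-in for a decoded audio chunk
            if streaming:
                yield chunk                  # emit each chunk as soon as it is ready
            else:
                chunks.append(chunk)
        if not streaming:
            yield np.concatenate(chunks)     # single final waveform
```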

@@ -96,7 +108,8 @@ class VoxCPM:
        retry_badcase : bool = True,
        retry_badcase_max_times : int = 3,
        retry_badcase_ratio_threshold : float = 6.0,
    ):
        streaming: bool = False,
    ) -> Generator[np.ndarray, None, None]:
        """Synthesize speech for the given text and return a single waveform.

        This method optionally builds and reuses a prompt cache. If an external

@@ -118,8 +131,11 @@
            retry_badcase: Whether to retry badcase.
            retry_badcase_max_times: Maximum number of times to retry badcase.
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
            streaming: Whether to return a generator of audio chunks.
        Returns:
            numpy.ndarray: 1D waveform array (float32) on CPU.
            Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
                Yields audio chunks for each generation step if ``streaming=True``,
                otherwise yields a single array containing the final audio.
        """
        if not text.strip() or not isinstance(text, str):
            raise ValueError("target text must be a non-empty string")

@@ -155,7 +171,7 @@
                self.text_normalizer = TextNormalizer()
            text = self.text_normalizer.normalize(text)

            wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
            generate_result = self.tts_model._generate_with_prompt_cache(
                target_text=text,
                prompt_cache=fixed_prompt_cache,
                min_len=2,

@@ -165,9 +181,11 @@
                retry_badcase=retry_badcase,
                retry_badcase_max_times=retry_badcase_max_times,
                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
                streaming=streaming,
            )

            return wav.squeeze(0).cpu().numpy()
            for wav, _, _ in generate_result:
                yield wav.squeeze(0).cpu().numpy()

        finally:
            if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):

@@ -19,11 +19,12 @@ limitations under the License.
"""

import os
from typing import Dict, Optional, Tuple, Union
from typing import Tuple, Union, Generator, List

import torch
import torch.nn as nn
import torchaudio
import warnings
from einops import rearrange
from pydantic import BaseModel
from tqdm import tqdm

@@ -84,11 +85,15 @@ class VoxCPMModel(nn.Module):
        self.patch_size = config.patch_size
        self.device = config.device
        if not torch.cuda.is_available():
            self.device = "cpu"
            if torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        print(f"Running on device: {self.device}, dtype: {self.config.dtype}")

        # Text-Semantic LM
        self.base_lm = MiniCPMModel(config.lm_config)
        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

        self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        self.audio_start_token = 101

@@ -99,7 +104,7 @@ class VoxCPMModel(nn.Module):
        residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
        residual_lm_config.vocab_size = 0
        self.residual_lm = MiniCPMModel(residual_lm_config)
        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

        # Local Encoder
        encoder_config = config.lm_config.model_copy(deep=True)

@@ -147,8 +152,10 @@ class VoxCPMModel(nn.Module):
        self.sample_rate = audio_vae.sample_rate


    def optimize(self):
    def optimize(self, disable: bool = False):
        try:
            if disable:
                raise ValueError("Optimization disabled by user")
            if self.device != "cuda":
                raise ValueError("VoxCPMModel can only be optimized on CUDA device")
            try:

@@ -169,8 +176,14 @@
        return self


    def generate(self, *args, **kwargs) -> torch.Tensor:
        return next(self._generate(*args, streaming=False, **kwargs))

    def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
        return self._generate(*args, streaming=True, **kwargs)

    @torch.inference_mode()
    def generate(
    def _generate(
        self,
        target_text: str,
        prompt_text: str = "",
@@ -182,7 +195,11 @@
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
    ):
        streaming: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        if retry_badcase and streaming:
            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
            retry_badcase = False
        if len(prompt_wav_path) == 0:
            text = target_text
            text_token = torch.LongTensor(self.text_tokenizer(text))

@@ -258,14 +275,14 @@

        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
        audio_mask = audio_mask.unsqueeze(0).to(self.device)

        target_text_length = len(self.text_tokenizer(target_text))

        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
            latent_pred, pred_audio_feat = self.inference(
            inference_result = self._inference(
                text_token,
                text_mask,
                audio_feat,

@@ -274,20 +291,31 @@
                max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                streaming=streaming,
            )
            if retry_badcase:
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
                    retry_badcase_times += 1
                    continue
            if streaming:
                patch_len = self.patch_size * self.chunk_size
                for latent_pred, _ in inference_result:
                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
                    yield decode_audio
                break
            else:
                latent_pred, pred_audio_feat = next(inference_result)
                if retry_badcase:
                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
                        retry_badcase_times += 1
                        continue
                    else:
                        break
                else:
                    break
            else:
                break

        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
        decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
        return decode_audio
        if not streaming:
            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
            decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
            yield decode_audio

    @torch.inference_mode()
    def build_prompt_cache(

@@ -377,8 +405,16 @@

        return merged_cache

    def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))

    def generate_with_prompt_cache_streaming(
        self, *args, **kwargs
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
        return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)

    @torch.inference_mode()
    def generate_with_prompt_cache(
    def _generate_with_prompt_cache(
        self,
        target_text: str,
        prompt_cache: dict,
@@ -389,7 +425,8 @@
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
    ):
        streaming: bool = False,
    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
        """
        Generate audio using pre-built prompt cache.

@@ -403,10 +440,17 @@
            retry_badcase: Whether to retry on bad cases
            retry_badcase_max_times: Maximum retry attempts
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
            streaming: Whether to return a generator of audio chunks

        Returns:
            tuple: (decoded audio tensor, new text tokens, new audio features)
            Generator of Tuple containing:
                - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
                - Tensor of new text tokens
                - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
        """
        if retry_badcase and streaming:
            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
            retry_badcase = False
        # get prompt from cache
        if prompt_cache is None:
            prompt_text_token = torch.empty(0, dtype=torch.int32)

@@ -444,14 +488,14 @@

        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
        audio_mask = audio_mask.unsqueeze(0).to(self.device)

        # run inference
        target_text_length = len(self.text_tokenizer(target_text))
        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
            latent_pred, pred_audio_feat = self.inference(
            inference_result = self._inference(
                text_token,
                text_mask,
                audio_feat,

@@ -460,27 +504,48 @@
                max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                streaming=streaming,
            )
            if retry_badcase:
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
                    retry_badcase_times += 1
                    continue
            if streaming:
                patch_len = self.patch_size * self.chunk_size
                for latent_pred, pred_audio_feat in inference_result:
                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
                    yield (
                        decode_audio,
                        target_text_token,
                        pred_audio_feat
                    )
                break
            else:
                latent_pred, pred_audio_feat = next(inference_result)
                if retry_badcase:
                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
                        retry_badcase_times += 1
                        continue
                    else:
                        break
                else:
                    break
            else:
                break
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
        decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
        if not streaming:
            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
            decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio

            return (
                decode_audio,
                target_text_token,
                pred_audio_feat
            )
            yield (
                decode_audio,
                target_text_token,
                pred_audio_feat
            )

    def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        return next(self._inference(*args, streaming=False, **kwargs))

    def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
        return self._inference(*args, streaming=True, **kwargs)

    @torch.inference_mode()
    def inference(
    def _inference(
        self,
        text: torch.Tensor,
        text_mask: torch.Tensor,

@@ -490,7 +555,8 @@
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        streaming: bool = False,
    ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
        """Core inference method for audio generation.

        This is the main inference loop that generates audio features

@@ -505,11 +571,12 @@
            max_len: Maximum generation length
            inference_timesteps: Number of diffusion steps
            cfg_value: Classifier-free guidance value
            streaming: Whether to yield each step latent feature or just the final result

        Returns:
            Tuple containing:
                - Predicted latent features
                - Predicted audio feature sequence
            Generator of Tuple containing:
                - Predicted latent feature at the current step if ``streaming=True``, else final latent features
                - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
        """
        B, T, P, D = feat.shape

@@ -567,6 +634,12 @@
            pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
            prefix_feat_cond = pred_feat

            if streaming:
                # return the last three predicted latent features to provide enough context for smooth decoding
                pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
                feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
                yield feat_pred, pred_feat_seq

            stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
            if i > min_len and stop_flag == 1:
                break
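The streaming branch above decodes a short window of latents rather than a single step. A hedged numeric sketch of the idea follows; the shapes and the decoder are hypothetical stand-ins, not the repository's `audio_vae`:

```python
# Illustrative sketch: decode the last three latent patches together so the
# decoder sees some left context, then keep only the samples that belong to
# the newest patch (the trailing patch_len samples) as the streamed chunk.
import torch

patch_size, chunk_size = 2, 320            # assumed values, for illustration only
patch_len = patch_size * chunk_size        # samples contributed by one generation step

def fake_decode(latents: torch.Tensor) -> torch.Tensor:
    # stand-in for audio_vae.decode: each latent frame becomes chunk_size samples
    b, d, t = latents.shape
    return torch.zeros(b, 1, t * chunk_size)

pred_feat_seq = [torch.randn(1, 1, patch_size, 64) for _ in range(5)]  # one (b, 1, p, d) patch per step
window = torch.cat(pred_feat_seq[-3:], dim=1)                          # last three patches
window = window.flatten(1, 2).transpose(1, 2)                          # "b t p d -> b d (t p)"
audio = fake_decode(window)
new_chunk = audio[..., -patch_len:]        # only the newest patch's samples are yielded
print(new_chunk.shape)                     # torch.Size([1, 1, 640])
```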

@@ -581,13 +654,14 @@
                lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
            ).clone()

        pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
        if not streaming:
            pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d

        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
        return feat_pred, pred_feat_seq.squeeze(0).cpu()
            feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
            yield feat_pred, pred_feat_seq.squeeze(0).cpu()

    @classmethod
    def from_local(cls, path: str):
    def from_local(cls, path: str, optimize: bool = True):
        config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())

        tokenizer = LlamaTokenizerFast.from_pretrained(path)

@@ -600,7 +674,7 @@ class VoxCPMModel(nn.Module):
        )["state_dict"]

        model = cls(config, tokenizer, audio_vae)
        lm_dtype = get_dtype(config.dtype)
        lm_dtype = get_dtype(model.config.dtype)
        model = model.to(lm_dtype)
        model.audio_vae = model.audio_vae.to(torch.float32)

@@ -613,4 +687,4 @@ class VoxCPMModel(nn.Module):
        for kw, val in vae_state_dict.items():
            model_state_dict[f"audio_vae.{kw}"] = val
        model.load_state_dict(model_state_dict, strict=True)
        return model.to(model.device).eval().optimize()
        return model.to(model.device).eval().optimize(disable=not optimize)

@@ -154,6 +154,11 @@ class MiniCPMAttention(nn.Module):
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # ref: https://github.com/pytorch/pytorch/issues/163597
        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,

@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):

        attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id

        # ref: https://github.com/pytorch/pytorch/issues/163597
        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
        query_states = query_states.contiguous()
        key_cache = key_cache.contiguous()
        value_cache = value_cache.contiguous()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_cache,
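The `.contiguous()` calls above are applied unconditionally. A hedged sketch of a device-gated variant of the same workaround, offered as an assumption rather than the repository's code, which skips the extra copies when not running on MPS:

```python
# Illustrative wrapper around scaled_dot_product_attention: force contiguity
# only on MPS, where non-contiguous inputs can misbehave (pytorch/pytorch#163597).
import torch
import torch.nn.functional as F

def sdpa_mps_safe(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                  attn_mask: torch.Tensor | None = None) -> torch.Tensor:
    if q.device.type == "mps":
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```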