Merge pull request #26 from AbrahamSanders/main

Add a streaming API for VoxCPM
xliucs
2025-09-22 20:47:07 +08:00
committed by GitHub
4 changed files with 158 additions and 53 deletions

.gitignore (new file)

@@ -0,0 +1,3 @@
+launch.json
+__pycache__
+voxcpm.egg-info

View File

@@ -62,10 +62,12 @@ By default, when you first run the script, the model will be downloaded automati
 ### 2. Basic Usage
 ```python
 import soundfile as sf
+import numpy as np
 from voxcpm import VoxCPM
 
 model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
 
+# Non-streaming
 wav = model.generate(
     text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
     prompt_wav_path=None,  # optional: path to a prompt speech for voice cloning
@@ -81,6 +83,18 @@ wav = model.generate(
sf.write("output.wav", wav, 16000) sf.write("output.wav", wav, 16000)
print("saved: output.wav") print("saved: output.wav")
# Streaming
chunks = []
for chunk in model.generate_streaming(
text = "Streaming text to speech is easy with VoxCPM!",
# supports same args as above
):
chunks.append(chunk)
wav = np.concatenate(chunks)
sf.write("output_streaming.wav", wav, 16000)
print("saved: output_streaming.wav")
``` ```
### 3. CLI Usage ### 3. CLI Usage
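If you prefer not to buffer every chunk in memory, the stream can also be written to disk as it arrives. This is a usage sketch rather than part of the diff; it assumes the same `model` object as in the example above and uses soundfile's incremental `SoundFile` writer:

```python
import soundfile as sf

# Write each streamed chunk as soon as it is generated instead of concatenating at the end.
with sf.SoundFile("output_streaming.wav", "w", samplerate=16000, channels=1) as f:
    for chunk in model.generate_streaming(
        text="Streaming text to speech is easy with VoxCPM!",
    ):
        f.write(chunk)  # each chunk is a 1D float32 numpy array
```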

View File

@@ -1,8 +1,8 @@
-import torch
-import torchaudio
 import os
 import re
 import tempfile
+import numpy as np
+from typing import Generator
 
 from huggingface_hub import snapshot_download
 from .model.voxcpm import VoxCPMModel
@@ -11,6 +11,7 @@ class VoxCPM:
         voxcpm_model_path : str,
         zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
         enable_denoiser : bool = True,
+        optimize: bool = True,
     ):
         """Initialize VoxCPM TTS pipeline.
@@ -21,9 +22,10 @@ class VoxCPM:
             zipenhancer_model_path: ModelScope acoustic noise suppression model
                 id or local path. If None, denoiser will not be initialized.
             enable_denoiser: Whether to initialize the denoiser pipeline.
+            optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
         """
         print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
-        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
+        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
         self.text_normalizer = None
         if enable_denoiser and zipenhancer_model_path is not None:
             from .zipenhancer import ZipEnhancer
@@ -43,6 +45,7 @@ class VoxCPM:
         zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
         cache_dir: str = None,
         local_files_only: bool = False,
+        **kwargs,
     ):
         """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -54,6 +57,8 @@ class VoxCPM:
             cache_dir: Custom cache directory for the snapshot.
             local_files_only: If True, only use local files and do not attempt
                 to download.
+        Kwargs:
+            Additional keyword arguments passed to the ``VoxCPM`` constructor.
 
         Returns:
             VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
@@ -82,9 +87,16 @@ class VoxCPM:
             voxcpm_model_path=local_path,
             zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
             enable_denoiser=load_denoiser,
+            **kwargs,
         )
 
-    def generate(self,
+    def generate(self, *args, **kwargs) -> np.ndarray:
+        return next(self._generate(*args, streaming=False, **kwargs))
+
+    def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
+        return self._generate(*args, streaming=True, **kwargs)
+
+    def _generate(self,
         text : str,
         prompt_wav_path : str = None,
         prompt_text : str = None,
@@ -96,7 +108,8 @@ class VoxCPM:
         retry_badcase : bool = True,
         retry_badcase_max_times : int = 3,
         retry_badcase_ratio_threshold : float = 6.0,
-    ):
+        streaming: bool = False,
+    ) -> Generator[np.ndarray, None, None]:
         """Synthesize speech for the given text and return a single waveform.
 
         This method optionally builds and reuses a prompt cache. If an external
@@ -118,8 +131,11 @@ class VoxCPM:
             retry_badcase: Whether to retry badcase.
             retry_badcase_max_times: Maximum number of times to retry badcase.
             retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
+            streaming: Whether to return a generator of audio chunks.
 
         Returns:
-            numpy.ndarray: 1D waveform array (float32) on CPU.
+            Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
+                Yields audio chunks for each generation step if ``streaming=True``,
+                otherwise yields a single array containing the final audio.
         """
         if not text.strip() or not isinstance(text, str):
             raise ValueError("target text must be a non-empty string")
@@ -155,7 +171,7 @@ class VoxCPM:
                 self.text_normalizer = TextNormalizer()
             text = self.text_normalizer.normalize(text)
 
-            wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
+            generate_result = self.tts_model._generate_with_prompt_cache(
                 target_text=text,
                 prompt_cache=fixed_prompt_cache,
                 min_len=2,
@@ -165,9 +181,11 @@ class VoxCPM:
                 retry_badcase=retry_badcase,
                 retry_badcase_max_times=retry_badcase_max_times,
                 retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+                streaming=streaming,
             )
-            return wav.squeeze(0).cpu().numpy()
+            for wav, _, _ in generate_result:
+                yield wav.squeeze(0).cpu().numpy()
         finally:
             if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
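The new public entry points in this file are thin facades over a single private generator: `generate` drains exactly one item with `next(...)`, while `generate_streaming` returns the generator itself. A self-contained sketch of that pattern with hypothetical names (not code from this repository):

```python
from typing import Generator, List

def _count_up(limit: int, streaming: bool = False) -> Generator[List[int], None, None]:
    """Yield one chunk per step when streaming, else a single final result."""
    acc: List[int] = []
    for i in range(limit):
        acc.append(i)
        if streaming:
            yield [i]              # per-step chunk
    if not streaming:
        yield acc                  # single final result

def count_up(limit: int) -> List[int]:
    # Non-streaming facade: drain exactly one item from the generator.
    return next(_count_up(limit, streaming=False))

def count_up_streaming(limit: int) -> Generator[List[int], None, None]:
    return _count_up(limit, streaming=True)

print(count_up(3))                  # [0, 1, 2]
print(list(count_up_streaming(3)))  # [[0], [1], [2]]
```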

View File

@@ -19,11 +19,12 @@ limitations under the License.
""" """
import os import os
from typing import Dict, Optional, Tuple, Union from typing import Tuple, Union, Generator, List
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchaudio import torchaudio
import warnings
from einops import rearrange from einops import rearrange
from pydantic import BaseModel from pydantic import BaseModel
from tqdm import tqdm from tqdm import tqdm
@@ -147,8 +148,10 @@ class VoxCPMModel(nn.Module):
         self.sample_rate = audio_vae.sample_rate
 
-    def optimize(self):
+    def optimize(self, disable: bool = False):
         try:
+            if disable:
+                raise ValueError("Optimization disabled by user")
             if self.device != "cuda":
                 raise ValueError("VoxCPMModel can only be optimized on CUDA device")
             try:
@@ -169,8 +172,14 @@ class VoxCPMModel(nn.Module):
         return self
 
+    def generate(self, *args, **kwargs) -> torch.Tensor:
+        return next(self._generate(*args, streaming=False, **kwargs))
+
+    def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
+        return self._generate(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def generate(
+    def _generate(
         self,
         target_text: str,
         prompt_text: str = "",
@@ -182,7 +191,11 @@ class VoxCPMModel(nn.Module):
         retry_badcase: bool = False,
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0,  # setting acceptable ratio of audio length to text length (for badcase detection)
-    ):
+        streaming: bool = False,
+    ) -> Generator[torch.Tensor, None, None]:
+        if retry_badcase and streaming:
+            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+            retry_badcase = False
         if len(prompt_wav_path) == 0:
             text = target_text
             text_token = torch.LongTensor(self.text_tokenizer(text))
@@ -265,7 +278,7 @@ class VoxCPMModel(nn.Module):
         retry_badcase_times = 0
         while retry_badcase_times < retry_badcase_max_times:
-            latent_pred, pred_audio_feat = self.inference(
+            inference_result = self._inference(
                 text_token,
                 text_mask,
                 audio_feat,
@@ -274,20 +287,31 @@ class VoxCPMModel(nn.Module):
                 max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                 inference_timesteps=inference_timesteps,
                 cfg_value=cfg_value,
+                streaming=streaming,
             )
-            if retry_badcase:
-                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
-                    retry_badcase_times += 1
-                    continue
-                else:
-                    break
+            if streaming:
+                patch_len = self.patch_size * self.chunk_size
+                for latent_pred, _ in inference_result:
+                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                    yield decode_audio
+                break
             else:
-                break
-        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
-        decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
-        return decode_audio
+                latent_pred, pred_audio_feat = next(inference_result)
+                if retry_badcase:
+                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
+                        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                        retry_badcase_times += 1
+                        continue
+                    else:
+                        break
+                else:
+                    break
+
+        if not streaming:
+            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+            decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
+            yield decode_audio
 
     @torch.inference_mode()
     def build_prompt_cache(
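The streaming branch above decodes a short window of recent latent frames each step and keeps only the newest `patch_len` samples, so the emitted chunks butt together cleanly. A toy sketch of that slicing idea (not the library's code: the fake decoder and the `patch_size`/`chunk_size` values are invented, and a real VAE decoder would only approximate this equality):

```python
import numpy as np

patch_size, chunk_size = 4, 2
patch_len = patch_size * chunk_size  # samples contributed by one generation step

def fake_decode(latents: np.ndarray) -> np.ndarray:
    # Stand-in for audio_vae.decode: each latent frame expands into patch_len samples.
    return np.repeat(latents, patch_len)

latent_history, stream_chunks = [], []
for latent in np.arange(6, dtype=np.float32):        # pretend 6 generation steps
    latent_history.append(latent)
    context = np.array(latent_history[-3:])           # last three frames give decoding context
    stream_chunks.append(fake_decode(context)[-patch_len:])  # keep only the newest patch

# Concatenated streaming chunks match a one-shot decode of the full latent sequence.
assert np.array_equal(np.concatenate(stream_chunks), fake_decode(np.array(latent_history)))
```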
@@ -377,8 +401,16 @@ class VoxCPMModel(nn.Module):
         return merged_cache
 
+    def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
+
+    def generate_with_prompt_cache_streaming(
+        self, *args, **kwargs
+    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
+        return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def generate_with_prompt_cache(
+    def _generate_with_prompt_cache(
         self,
         target_text: str,
         prompt_cache: dict,
@@ -389,7 +421,8 @@ class VoxCPMModel(nn.Module):
         retry_badcase: bool = False,
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0,
-    ):
+        streaming: bool = False,
+    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """
         Generate audio using pre-built prompt cache.
@@ -403,10 +436,17 @@ class VoxCPMModel(nn.Module):
             retry_badcase: Whether to retry on bad cases
             retry_badcase_max_times: Maximum retry attempts
             retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
+            streaming: Whether to return a generator of audio chunks
 
         Returns:
-            tuple: (decoded audio tensor, new text tokens, new audio features)
+            Generator of Tuple containing:
+                - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
+                - Tensor of new text tokens
+                - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
         """
+        if retry_badcase and streaming:
+            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+            retry_badcase = False
         # get prompt from cache
         if prompt_cache is None:
             prompt_text_token = torch.empty(0, dtype=torch.int32)
@@ -451,7 +491,7 @@ class VoxCPMModel(nn.Module):
         target_text_length = len(self.text_tokenizer(target_text))
         retry_badcase_times = 0
         while retry_badcase_times < retry_badcase_max_times:
-            latent_pred, pred_audio_feat = self.inference(
+            inference_result = self._inference(
                 text_token,
                 text_mask,
                 audio_feat,
@@ -460,27 +500,48 @@ class VoxCPMModel(nn.Module):
                 max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                 inference_timesteps=inference_timesteps,
                 cfg_value=cfg_value,
+                streaming=streaming,
             )
-            if retry_badcase:
-                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
-                    retry_badcase_times += 1
-                    continue
-                else:
-                    break
+            if streaming:
+                patch_len = self.patch_size * self.chunk_size
+                for latent_pred, pred_audio_feat in inference_result:
+                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                    yield (
+                        decode_audio,
+                        target_text_token,
+                        pred_audio_feat
+                    )
+                break
             else:
-                break
-        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
-        decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
-        return (
-            decode_audio,
-            target_text_token,
-            pred_audio_feat
-        )
+                latent_pred, pred_audio_feat = next(inference_result)
+                if retry_badcase:
+                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
+                        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                        retry_badcase_times += 1
+                        continue
+                    else:
+                        break
+                else:
+                    break
+
+        if not streaming:
+            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+            decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
+            yield (
+                decode_audio,
+                target_text_token,
+                pred_audio_feat
+            )
+
+    def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+        return next(self._inference(*args, streaming=False, **kwargs))
+
+    def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
+        return self._inference(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def inference(
+    def _inference(
         self,
         text: torch.Tensor,
         text_mask: torch.Tensor,
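For completeness, here is how a caller might consume the `generate_with_prompt_cache_streaming` wrapper introduced above. This is a hedged usage sketch, not code from the repository: it assumes `model` is a loaded `VoxCPMModel`, that `cache` was built with `build_prompt_cache`, and that the remaining generation arguments can be left at their defaults.

```python
import torch

# Assumption: model is a loaded VoxCPMModel and cache came from model.build_prompt_cache(...).
audio_chunks = []
for decode_audio, target_text_token, pred_audio_feat in model.generate_with_prompt_cache_streaming(
    target_text="Hello from the streaming prompt-cache API.",
    prompt_cache=cache,
):
    audio_chunks.append(decode_audio)  # one (1, patch_len) tensor per generation step
    # pred_audio_feat is the list of per-step audio features generated so far

wav = torch.cat(audio_chunks, dim=-1).squeeze(0).numpy()  # 1D waveform at model.sample_rate
```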
@@ -490,7 +551,8 @@ class VoxCPMModel(nn.Module):
         max_len: int = 2000,
         inference_timesteps: int = 10,
         cfg_value: float = 2.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        streaming: bool = False,
+    ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """Core inference method for audio generation.
 
         This is the main inference loop that generates audio features
@@ -505,11 +567,12 @@ class VoxCPMModel(nn.Module):
             max_len: Maximum generation length
             inference_timesteps: Number of diffusion steps
             cfg_value: Classifier-free guidance value
+            streaming: Whether to yield each step's latent feature or just the final result
 
         Returns:
-            Tuple containing:
-                - Predicted latent features
-                - Predicted audio feature sequence
+            Generator of Tuple containing:
+                - Predicted latent feature at the current step if ``streaming=True``, else final latent features
+                - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
         """
 
         B, T, P, D = feat.shape
@@ -567,6 +630,12 @@ class VoxCPMModel(nn.Module):
             pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
             prefix_feat_cond = pred_feat
 
+            if streaming:
+                # return the last three predicted latent features to provide enough context for smooth decoding
+                pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
+                feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+                yield feat_pred, pred_feat_seq
+
             stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
             if i > min_len and stop_flag == 1:
                 break
@@ -581,13 +650,14 @@ class VoxCPMModel(nn.Module):
             lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
         ).clone()
 
-        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
-        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
-        return feat_pred, pred_feat_seq.squeeze(0).cpu()
+        if not streaming:
+            pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
+            feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+            yield feat_pred, pred_feat_seq.squeeze(0).cpu()
 
     @classmethod
-    def from_local(cls, path: str):
+    def from_local(cls, path: str, optimize: bool = True):
         config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
         tokenizer = LlamaTokenizerFast.from_pretrained(path)
@@ -613,4 +683,4 @@ class VoxCPMModel(nn.Module):
         for kw, val in vae_state_dict.items():
             model_state_dict[f"audio_vae.{kw}"] = val
         model.load_state_dict(model_state_dict, strict=True)
-        return model.to(model.device).eval().optimize()
+        return model.to(model.device).eval().optimize(disable=not optimize)
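Taken together, the `optimize` plumbing above lets callers skip torch.compile end to end, for example while debugging. A hedged usage sketch (the extra keyword is forwarded through `from_pretrained(**kwargs)` to the constructor and on to `from_local`; the model id is the one from the README example):

```python
from voxcpm import VoxCPM

# Skip torch.compile, e.g. while debugging or on a non-CUDA device;
# optimize=False flows from_pretrained -> VoxCPM(...) -> VoxCPMModel.from_local(..., optimize=False)
# -> model.optimize(disable=True).
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B", optimize=False)
wav = model.generate(text="Quick check without compilation.")
```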