From 5c5da0dbe68e739e8ad09244f7f03ac0f9580ac7 Mon Sep 17 00:00:00 2001 From: AbrahamSanders Date: Fri, 19 Sep 2025 16:56:11 -0400 Subject: [PATCH 1/2] Add a streaming API for VoxCPM --- .gitignore | 3 + src/voxcpm/core.py | 34 ++++++-- src/voxcpm/model/voxcpm.py | 160 ++++++++++++++++++++++++++----------- 3 files changed, 144 insertions(+), 53 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f685e73 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +launch.json +__pycache__ +voxcpm.egg-info \ No newline at end of file diff --git a/src/voxcpm/core.py b/src/voxcpm/core.py index 3b88b55..e8d22fe 100644 --- a/src/voxcpm/core.py +++ b/src/voxcpm/core.py @@ -1,8 +1,8 @@ -import torch -import torchaudio import os import re import tempfile +import numpy as np +from typing import Generator from huggingface_hub import snapshot_download from .model.voxcpm import VoxCPMModel @@ -11,6 +11,7 @@ class VoxCPM: voxcpm_model_path : str, zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base", enable_denoiser : bool = True, + optimize: bool = True, ): """Initialize VoxCPM TTS pipeline. @@ -21,9 +22,10 @@ class VoxCPM: zipenhancer_model_path: ModelScope acoustic noise suppression model id or local path. If None, denoiser will not be initialized. enable_denoiser: Whether to initialize the denoiser pipeline. + optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging. """ print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}") - self.tts_model = VoxCPMModel.from_local(voxcpm_model_path) + self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize) self.text_normalizer = None if enable_denoiser and zipenhancer_model_path is not None: from .zipenhancer import ZipEnhancer @@ -43,6 +45,7 @@ class VoxCPM: zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base", cache_dir: str = None, local_files_only: bool = False, + **kwargs, ): """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot. @@ -54,6 +57,8 @@ class VoxCPM: cache_dir: Custom cache directory for the snapshot. local_files_only: If True, only use local files and do not attempt to download. + Kwargs: + Additional keyword arguments passed to the ``VoxCPM`` constructor. Returns: VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to @@ -82,9 +87,16 @@ class VoxCPM: voxcpm_model_path=local_path, zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None, enable_denoiser=load_denoiser, + **kwargs, ) - def generate(self, + def generate(self, *args, **kwargs) -> np.ndarray: + return next(self._generate(*args, streaming=False, **kwargs)) + + def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]: + return self._generate(*args, streaming=True, **kwargs) + + def _generate(self, text : str, prompt_wav_path : str = None, prompt_text : str = None, @@ -96,7 +108,8 @@ class VoxCPM: retry_badcase : bool = True, retry_badcase_max_times : int = 3, retry_badcase_ratio_threshold : float = 6.0, - ): + streaming: bool = False, + ) -> Generator[np.ndarray, None, None]: """Synthesize speech for the given text and return a single waveform. This method optionally builds and reuses a prompt cache. If an external @@ -118,8 +131,11 @@ class VoxCPM: retry_badcase: Whether to retry badcase. retry_badcase_max_times: Maximum number of times to retry badcase. 
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio. + streaming: Whether to return a generator of audio chunks. Returns: - numpy.ndarray: 1D waveform array (float32) on CPU. + Generator of numpy.ndarray: 1D waveform array (float32) on CPU. + Yields audio chunks for each generations step if ``streaming=True``, + otherwise yields a single array containing the final audio. """ if not text.strip() or not isinstance(text, str): raise ValueError("target text must be a non-empty string") @@ -155,7 +171,7 @@ class VoxCPM: self.text_normalizer = TextNormalizer() text = self.text_normalizer.normalize(text) - wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache( + generate_result = self.tts_model._generate_with_prompt_cache( target_text=text, prompt_cache=fixed_prompt_cache, min_len=2, @@ -165,9 +181,11 @@ class VoxCPM: retry_badcase=retry_badcase, retry_badcase_max_times=retry_badcase_max_times, retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, + streaming=streaming, ) - return wav.squeeze(0).cpu().numpy() + for wav, _, _ in generate_result: + yield wav.squeeze(0).cpu().numpy() finally: if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path): diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index 1f5fdec..89b895b 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -19,11 +19,12 @@ limitations under the License. """ import os -from typing import Dict, Optional, Tuple, Union +from typing import Tuple, Union, Generator, List import torch import torch.nn as nn import torchaudio +import warnings from einops import rearrange from pydantic import BaseModel from tqdm import tqdm @@ -147,8 +148,10 @@ class VoxCPMModel(nn.Module): self.sample_rate = audio_vae.sample_rate - def optimize(self): + def optimize(self, disable: bool = False): try: + if disable: + raise ValueError("Optimization disabled by user") if self.device != "cuda": raise ValueError("VoxCPMModel can only be optimized on CUDA device") try: @@ -169,8 +172,14 @@ class VoxCPMModel(nn.Module): return self + def generate(self, *args, **kwargs) -> torch.Tensor: + return next(self._generate(*args, streaming=False, **kwargs)) + + def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]: + return self._generate(*args, streaming=True, **kwargs) + @torch.inference_mode() - def generate( + def _generate( self, target_text: str, prompt_text: str = "", @@ -182,7 +191,11 @@ class VoxCPMModel(nn.Module): retry_badcase: bool = False, retry_badcase_max_times: int = 3, retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection) - ): + streaming: bool = False, + ) -> Generator[torch.Tensor, None, None]: + if retry_badcase and streaming: + warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.") + retry_badcase = False if len(prompt_wav_path) == 0: text = target_text text_token = torch.LongTensor(self.text_tokenizer(text)) @@ -265,7 +278,7 @@ class VoxCPMModel(nn.Module): retry_badcase_times = 0 while retry_badcase_times < retry_badcase_max_times: - latent_pred, pred_audio_feat = self.inference( + inference_result = self._inference( text_token, text_mask, audio_feat, @@ -274,20 +287,31 @@ class VoxCPMModel(nn.Module): max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len, inference_timesteps=inference_timesteps, cfg_value=cfg_value, + streaming=streaming, ) - if 
retry_badcase: - if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: - print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...") - retry_badcase_times += 1 - continue - else: - break + if streaming: + patch_len = self.patch_size * self.chunk_size + for latent_pred, _ in inference_result: + decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) + decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() + yield decode_audio + break else: - break + latent_pred, pred_audio_feat = next(inference_result) + if retry_badcase: + if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: + print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...") + retry_badcase_times += 1 + continue + else: + break + else: + break - decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() - decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio - return decode_audio + if not streaming: + decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() + decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio + yield decode_audio @torch.inference_mode() def build_prompt_cache( @@ -376,9 +400,17 @@ class VoxCPMModel(nn.Module): } return merged_cache - + + def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs)) + + def generate_with_prompt_cache_streaming( + self, *args, **kwargs + ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]: + return self._generate_with_prompt_cache(*args, streaming=True, **kwargs) + @torch.inference_mode() - def generate_with_prompt_cache( + def _generate_with_prompt_cache( self, target_text: str, prompt_cache: dict, @@ -389,7 +421,8 @@ class VoxCPMModel(nn.Module): retry_badcase: bool = False, retry_badcase_max_times: int = 3, retry_badcase_ratio_threshold: float = 6.0, - ): + streaming: bool = False, + ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]: """ Generate audio using pre-built prompt cache. 
@@ -403,10 +436,17 @@ class VoxCPMModel(nn.Module): retry_badcase: Whether to retry on bad cases retry_badcase_max_times: Maximum retry attempts retry_badcase_ratio_threshold: Threshold for audio-to-text ratio + streaming: Whether to return a generator of audio chunks Returns: - tuple: (decoded audio tensor, new text tokens, new audio features) + Generator of Tuple containing: + - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor + - Tensor of new text tokens + - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor """ + if retry_badcase and streaming: + warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.") + retry_badcase = False # get prompt from cache if prompt_cache is None: prompt_text_token = torch.empty(0, dtype=torch.int32) @@ -451,7 +491,7 @@ class VoxCPMModel(nn.Module): target_text_length = len(self.text_tokenizer(target_text)) retry_badcase_times = 0 while retry_badcase_times < retry_badcase_max_times: - latent_pred, pred_audio_feat = self.inference( + inference_result = self._inference( text_token, text_mask, audio_feat, @@ -460,27 +500,48 @@ class VoxCPMModel(nn.Module): max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len, inference_timesteps=inference_timesteps, cfg_value=cfg_value, + streaming=streaming, ) - if retry_badcase: - if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: - print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...") - retry_badcase_times += 1 - continue + if streaming: + patch_len = self.patch_size * self.chunk_size + for latent_pred, pred_audio_feat in inference_result: + decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)) + decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu() + yield ( + decode_audio, + target_text_token, + pred_audio_feat + ) + break + else: + latent_pred, pred_audio_feat = next(inference_result) + if retry_badcase: + if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: + print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...") + retry_badcase_times += 1 + continue + else: + break else: break - else: - break - decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() - decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio + if not streaming: + decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() + decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio - return ( - decode_audio, - target_text_token, - pred_audio_feat - ) + yield ( + decode_audio, + target_text_token, + pred_audio_feat + ) + + def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + return next(self._inference(*args, streaming=False, **kwargs)) + + def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]: + return self._inference(*args, streaming=True, **kwargs) @torch.inference_mode() - def inference( + def _inference( self, text: torch.Tensor, text_mask: torch.Tensor, @@ -490,7 +551,8 @@ class VoxCPMModel(nn.Module): max_len: int = 2000, inference_timesteps: int = 10, cfg_value: float = 2.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: + streaming: bool = False, + ) -> Generator[Tuple[torch.Tensor, 
Union[torch.Tensor, List[torch.Tensor]]], None, None]: """Core inference method for audio generation. This is the main inference loop that generates audio features @@ -505,11 +567,12 @@ class VoxCPMModel(nn.Module): max_len: Maximum generation length inference_timesteps: Number of diffusion steps cfg_value: Classifier-free guidance value + streaming: Whether to yield each step latent feature or just the final result Returns: - Tuple containing: - - Predicted latent features - - Predicted audio feature sequence + Generator of Tuple containing: + - Predicted latent feature at the current step if ``streaming=True``, else final latent features + - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor """ B, T, P, D = feat.shape @@ -566,6 +629,12 @@ class VoxCPMModel(nn.Module): pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d prefix_feat_cond = pred_feat + + if streaming: + # return the last three predicted latent features to provide enough context for smooth decoding + pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1) + feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size) + yield feat_pred, pred_feat_seq stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item() if i > min_len and stop_flag == 1: @@ -581,13 +650,14 @@ class VoxCPMModel(nn.Module): lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device) ).clone() - pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d + if not streaming: + pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d - feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size) - return feat_pred, pred_feat_seq.squeeze(0).cpu() + feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size) + yield feat_pred, pred_feat_seq.squeeze(0).cpu() @classmethod - def from_local(cls, path: str): + def from_local(cls, path: str, optimize: bool = True): config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read()) tokenizer = LlamaTokenizerFast.from_pretrained(path) @@ -613,4 +683,4 @@ class VoxCPMModel(nn.Module): for kw, val in vae_state_dict.items(): model_state_dict[f"audio_vae.{kw}"] = val model.load_state_dict(model_state_dict, strict=True) - return model.to(model.device).eval().optimize() + return model.to(model.device).eval().optimize(disable=not optimize) From 89f4d917a0d77eeabdfd40f433cd9d07ae7140b1 Mon Sep 17 00:00:00 2001 From: AbrahamSanders Date: Fri, 19 Sep 2025 17:09:30 -0400 Subject: [PATCH 2/2] Update readme with streaming example --- README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/README.md b/README.md index f81bc57..cf9d9cf 100644 --- a/README.md +++ b/README.md @@ -62,10 +62,12 @@ By default, when you first run the script, the model will be downloaded automati ### 2. 
Basic Usage ```python import soundfile as sf +import numpy as np from voxcpm import VoxCPM model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B") +# Non-streaming wav = model.generate( text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.", prompt_wav_path=None, # optional: path to a prompt speech for voice cloning @@ -81,6 +83,18 @@ wav = model.generate( sf.write("output.wav", wav, 16000) print("saved: output.wav") + +# Streaming +chunks = [] +for chunk in model.generate_streaming( + text = "Streaming text to speech is easy with VoxCPM!", + # supports same args as above +): + chunks.append(chunk) +wav = np.concatenate(chunks) + +sf.write("output_streaming.wav", wav, 16000) +print("saved: output_streaming.wav") ``` ### 3. CLI Usage
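For consumers that want audio out as soon as it is produced, the streamed chunks can also be written incrementally rather than buffered and concatenated at the end. Below is a minimal sketch building on the streaming README example above, assuming each chunk is a 1-D float32 NumPy array at 16 kHz as described in the `generate_streaming` docstring; the output filename is illustrative.

```python
import soundfile as sf
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")

# Write each chunk as soon as it arrives instead of holding the full waveform in memory.
# Assumes 1-D float32 chunks at 16 kHz, per the streaming API introduced in this patch.
with sf.SoundFile("output_incremental.wav", mode="w",
                  samplerate=16000, channels=1, subtype="PCM_16") as f:
    for chunk in model.generate_streaming(
        text="Streaming text to speech is easy with VoxCPM!",
    ):
        f.write(chunk)

print("saved: output_incremental.wav")
```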