Add a streaming API for VoxCPM

This commit is contained in:
AbrahamSanders
2025-09-19 16:56:11 -04:00
parent 5f56d5ff5d
commit 5c5da0dbe6
3 changed files with 144 additions and 53 deletions

View File

@@ -1,8 +1,8 @@
import torch
import torchaudio
import os
import re
import tempfile
import numpy as np
from typing import Generator
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel
@@ -11,6 +11,7 @@ class VoxCPM:
voxcpm_model_path : str,
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser : bool = True,
optimize: bool = True,
):
"""Initialize VoxCPM TTS pipeline.
@@ -21,9 +22,10 @@ class VoxCPM:
zipenhancer_model_path: ModelScope acoustic noise suppression model
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
"""
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
self.text_normalizer = None
if enable_denoiser and zipenhancer_model_path is not None:
from .zipenhancer import ZipEnhancer
@@ -43,6 +45,7 @@ class VoxCPM:
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
**kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -54,6 +57,8 @@ class VoxCPM:
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
Kwargs:
Additional keyword arguments passed to the ``VoxCPM`` constructor.
Returns:
VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
@@ -82,9 +87,16 @@ class VoxCPM:
voxcpm_model_path=local_path,
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
**kwargs,
)
def generate(self,
def generate(self, *args, **kwargs) -> np.ndarray:
return next(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs)
def _generate(self,
text : str,
prompt_wav_path : str = None,
prompt_text : str = None,
@@ -96,7 +108,8 @@ class VoxCPM:
retry_badcase : bool = True,
retry_badcase_max_times : int = 3,
retry_badcase_ratio_threshold : float = 6.0,
):
streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
"""Synthesize speech for the given text and return a single waveform.
This method optionally builds and reuses a prompt cache. If an external
@@ -118,8 +131,11 @@ class VoxCPM:
retry_badcase: Whether to retry badcase.
retry_badcase_max_times: Maximum number of times to retry badcase.
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
streaming: Whether to return a generator of audio chunks.
Returns:
numpy.ndarray: 1D waveform array (float32) on CPU.
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
Yields audio chunks for each generations step if ``streaming=True``,
otherwise yields a single array containing the final audio.
"""
if not text.strip() or not isinstance(text, str):
raise ValueError("target text must be a non-empty string")
@@ -155,7 +171,7 @@ class VoxCPM:
self.text_normalizer = TextNormalizer()
text = self.text_normalizer.normalize(text)
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
generate_result = self.tts_model._generate_with_prompt_cache(
target_text=text,
prompt_cache=fixed_prompt_cache,
min_len=2,
@@ -165,9 +181,11 @@ class VoxCPM:
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
streaming=streaming,
)
return wav.squeeze(0).cpu().numpy()
for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy()
finally:
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):