diff --git a/app.py b/app.py
index eacba7d..3f64801 100644
--- a/app.py
+++ b/app.py
@@ -206,7 +206,7 @@ def create_demo_interface(demo: VoxCPMDemo):
         prompt_wav = gr.Audio(
             sources=["upload", 'microphone'],
             type="filepath",
-            label="Prompt Speech",
+            label="Prompt Speech (Optional, or let VoxCPM improvise)",
             value="./examples/example.wav",
         )
         DoDenoisePromptAudio = gr.Checkbox(
diff --git a/src/voxcpm/core.py b/src/voxcpm/core.py
index 533497d..7ff1d08 100644
--- a/src/voxcpm/core.py
+++ b/src/voxcpm/core.py
@@ -120,10 +120,17 @@ class VoxCPM:
         Returns:
             numpy.ndarray: 1D waveform array (float32) on CPU.
         """
-        texts = text.split("\n")
-        texts = [t.strip() for t in texts if t.strip()]
-        final_wav = []
-        temp_prompt_wav_path = None
+        if not isinstance(text, str) or not text.strip():
+            raise ValueError("target text must be a non-empty string")
+
+        if prompt_wav_path is not None:
+            if not os.path.exists(prompt_wav_path):
+                raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
+
+        if (prompt_wav_path is None) != (prompt_text is None):
+            raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
+
+        temp_prompt_wav_path = None

         try:
             if prompt_wav_path is not None and prompt_text is not None:
@@ -139,35 +146,25 @@ class VoxCPM:
             else:
                 fixed_prompt_cache = None  # will be built from the first inference

-            for sub_text in texts:
-                if sub_text.strip() == "":
-                    continue
-                print("sub_text:", sub_text)
-                if normalize:
-                    if self.text_normalizer is None:
-                        from .utils.text_normalize import TextNormalizer
-                        self.text_normalizer = TextNormalizer()
-                    sub_text = self.text_normalizer.normalize(sub_text)
-                wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
-                    target_text=sub_text,
-                    prompt_cache=fixed_prompt_cache,
-                    min_len=2,
-                    max_len=max_length,
-                    inference_timesteps=inference_timesteps,
-                    cfg_value=cfg_value,
-                    retry_badcase=retry_badcase,
-                    retry_badcase_max_times=retry_badcase_max_times,
-                    retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
-                )
-                if fixed_prompt_cache is None:
-                    fixed_prompt_cache = self.tts_model.merge_prompt_cache(
-                        original_cache=None,
-                        new_text_token=target_text_token,
-                        new_audio_feat=generated_audio_feat
-                    )
-                final_wav.append(wav)
+            if normalize:
+                if self.text_normalizer is None:
+                    from .utils.text_normalize import TextNormalizer
+                    self.text_normalizer = TextNormalizer()
+                text = self.text_normalizer.normalize(text)
+
+            wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
+                target_text=text,
+                prompt_cache=fixed_prompt_cache,
+                min_len=2,
+                max_len=max_length,
+                inference_timesteps=inference_timesteps,
+                cfg_value=cfg_value,
+                retry_badcase=retry_badcase,
+                retry_badcase_max_times=retry_badcase_max_times,
+                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+            )

-            return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
+            return wav.squeeze(0).cpu().numpy()
         finally:
             if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py
index 7268704..3af0af9 100644
--- a/src/voxcpm/model/voxcpm.py
+++ b/src/voxcpm/model/voxcpm.py
@@ -151,11 +151,16 @@ class VoxCPMModel(nn.Module):
         try:
             if self.device != "cuda":
                 raise ValueError("VoxCPMModel can only be optimized on CUDA device")
+            try:
+                import triton
+            except ImportError:
+                raise ValueError("triton is not installed")
             self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
             self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
             self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
             self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
-        except:
+        except Exception as e:
+            print(e)
             print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
             self.base_lm.forward_step = self.base_lm.forward_step
             self.residual_lm.forward_step = self.residual_lm.forward_step
@@ -317,7 +322,7 @@ class VoxCPMModel(nn.Module):
             audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))

         # extract audio features
-        audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
+        audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
         audio_feat = audio_feat.view(
             self.audio_vae.latent_dim,
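Usage note for reviewers — a minimal sketch of the `generate` contract after this patch. The `VoxCPM.from_pretrained` loader, the model id, and the 16 kHz output rate are assumptions, not part of this diff; the paired-argument and non-empty-text behavior is what the new checks in `src/voxcpm/core.py` enforce.

```python
# Sketch only: loader name, model id, and sample rate are assumptions,
# not defined by this diff.
import soundfile as sf
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")

# Voice cloning: prompt_wav_path and prompt_text must now be passed together.
wav = model.generate(
    text="VoxCPM synthesizes the whole input as a single utterance now.",
    prompt_wav_path="./examples/example.wav",
    prompt_text="transcript of the prompt audio",
)
sf.write("output.wav", wav, 16000)  # generate() returns a 1D float32 array

# Passing only one of the two prompt arguments raises ValueError, and a
# missing prompt file raises FileNotFoundError before any model call.
```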