Remove segment text logic
This commit is contained in:
2
app.py
2
app.py
@@ -206,7 +206,7 @@ def create_demo_interface(demo: VoxCPMDemo):
|
|||||||
prompt_wav = gr.Audio(
|
prompt_wav = gr.Audio(
|
||||||
sources=["upload", 'microphone'],
|
sources=["upload", 'microphone'],
|
||||||
type="filepath",
|
type="filepath",
|
||||||
label="Prompt Speech",
|
label="Prompt Speech (Optional, or let VoxCPM improvise)",
|
||||||
value="./examples/example.wav",
|
value="./examples/example.wav",
|
||||||
)
|
)
|
||||||
DoDenoisePromptAudio = gr.Checkbox(
|
DoDenoisePromptAudio = gr.Checkbox(
|
||||||
|
|||||||
@@ -120,9 +120,16 @@ class VoxCPM:
|
|||||||
Returns:
|
Returns:
|
||||||
numpy.ndarray: 1D waveform array (float32) on CPU.
|
numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||||
"""
|
"""
|
||||||
texts = text.split("\n")
|
if not text.strip() or not isinstance(text, str):
|
||||||
texts = [t.strip() for t in texts if t.strip()]
|
raise ValueError("target text must be a non-empty string")
|
||||||
final_wav = []
|
|
||||||
|
if prompt_wav_path is not None:
|
||||||
|
if not os.path.exists(prompt_wav_path):
|
||||||
|
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
|
||||||
|
|
||||||
|
if (prompt_wav_path is None) != (prompt_text is None):
|
||||||
|
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
|
||||||
|
|
||||||
temp_prompt_wav_path = None
|
temp_prompt_wav_path = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -139,35 +146,25 @@ class VoxCPM:
|
|||||||
else:
|
else:
|
||||||
fixed_prompt_cache = None # will be built from the first inference
|
fixed_prompt_cache = None # will be built from the first inference
|
||||||
|
|
||||||
for sub_text in texts:
|
if normalize:
|
||||||
if sub_text.strip() == "":
|
if self.text_normalizer is None:
|
||||||
continue
|
from .utils.text_normalize import TextNormalizer
|
||||||
print("sub_text:", sub_text)
|
self.text_normalizer = TextNormalizer()
|
||||||
if normalize:
|
text = self.text_normalizer.normalize(text)
|
||||||
if self.text_normalizer is None:
|
|
||||||
from .utils.text_normalize import TextNormalizer
|
|
||||||
self.text_normalizer = TextNormalizer()
|
|
||||||
sub_text = self.text_normalizer.normalize(sub_text)
|
|
||||||
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
|
|
||||||
target_text=sub_text,
|
|
||||||
prompt_cache=fixed_prompt_cache,
|
|
||||||
min_len=2,
|
|
||||||
max_len=max_length,
|
|
||||||
inference_timesteps=inference_timesteps,
|
|
||||||
cfg_value=cfg_value,
|
|
||||||
retry_badcase=retry_badcase,
|
|
||||||
retry_badcase_max_times=retry_badcase_max_times,
|
|
||||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
|
||||||
)
|
|
||||||
if fixed_prompt_cache is None:
|
|
||||||
fixed_prompt_cache = self.tts_model.merge_prompt_cache(
|
|
||||||
original_cache=None,
|
|
||||||
new_text_token=target_text_token,
|
|
||||||
new_audio_feat=generated_audio_feat
|
|
||||||
)
|
|
||||||
final_wav.append(wav)
|
|
||||||
|
|
||||||
return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
|
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
|
||||||
|
target_text=text,
|
||||||
|
prompt_cache=fixed_prompt_cache,
|
||||||
|
min_len=2,
|
||||||
|
max_len=max_length,
|
||||||
|
inference_timesteps=inference_timesteps,
|
||||||
|
cfg_value=cfg_value,
|
||||||
|
retry_badcase=retry_badcase,
|
||||||
|
retry_badcase_max_times=retry_badcase_max_times,
|
||||||
|
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
return wav.squeeze(0).cpu().numpy()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
|
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
|
||||||
|
|||||||
@@ -151,11 +151,16 @@ class VoxCPMModel(nn.Module):
|
|||||||
try:
|
try:
|
||||||
if self.device != "cuda":
|
if self.device != "cuda":
|
||||||
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
||||||
|
try:
|
||||||
|
import triton
|
||||||
|
except:
|
||||||
|
raise ValueError("triton is not installed")
|
||||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||||
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||||
except:
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
|
print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
|
||||||
self.base_lm.forward_step = self.base_lm.forward_step
|
self.base_lm.forward_step = self.base_lm.forward_step
|
||||||
self.residual_lm.forward_step = self.residual_lm.forward_step
|
self.residual_lm.forward_step = self.residual_lm.forward_step
|
||||||
@@ -317,7 +322,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||||
|
|
||||||
# extract audio features
|
# extract audio features
|
||||||
audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
|
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
|
||||||
|
|
||||||
audio_feat = audio_feat.view(
|
audio_feat = audio_feat.view(
|
||||||
self.audio_vae.latent_dim,
|
self.audio_vae.latent_dim,
|
||||||
|
|||||||
Reference in New Issue
Block a user