mirror of https://github.com/OpenBMB/VoxCPM
synced 2025-12-12 11:58:11 +00:00
Add a streaming API for VoxCPM
.gitignore (vendored, new file, +3)

@@ -0,0 +1,3 @@
+launch.json
+__pycache__
+voxcpm.egg-info
@@ -1,8 +1,8 @@
-import torch
-import torchaudio
 import os
 import re
 import tempfile
+import numpy as np
+from typing import Generator
 from huggingface_hub import snapshot_download
 from .model.voxcpm import VoxCPMModel

@@ -11,6 +11,7 @@ class VoxCPM:
            voxcpm_model_path : str,
            zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
            enable_denoiser : bool = True,
+           optimize: bool = True,
         ):
         """Initialize VoxCPM TTS pipeline.

@@ -21,9 +22,10 @@ class VoxCPM:
             zipenhancer_model_path: ModelScope acoustic noise suppression model
                 id or local path. If None, denoiser will not be initialized.
             enable_denoiser: Whether to initialize the denoiser pipeline.
+            optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
         """
         print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
-        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
+        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
         self.text_normalizer = None
         if enable_denoiser and zipenhancer_model_path is not None:
             from .zipenhancer import ZipEnhancer
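With the new `optimize` flag threaded through the constructor, compilation can be switched off when stepping through the model. A minimal usage sketch, assuming the package exports `VoxCPM` at the top level; the checkpoint path is a placeholder:

    from voxcpm import VoxCPM

    # optimize=False skips torch.compile, which keeps breakpoints and stack
    # traces inside the model readable during debugging.
    model = VoxCPM(
        voxcpm_model_path="/path/to/VoxCPM-0.5B",  # placeholder local path
        optimize=False,
    )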
@@ -43,6 +45,7 @@ class VoxCPM:
         zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
         cache_dir: str = None,
         local_files_only: bool = False,
+        **kwargs,
     ):
         """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.

@@ -54,6 +57,8 @@ class VoxCPM:
             cache_dir: Custom cache directory for the snapshot.
             local_files_only: If True, only use local files and do not attempt
                 to download.
+        Kwargs:
+            Additional keyword arguments passed to the ``VoxCPM`` constructor.

         Returns:
             VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
@@ -82,9 +87,16 @@ class VoxCPM:
             voxcpm_model_path=local_path,
             zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
             enable_denoiser=load_denoiser,
+            **kwargs,
         )

-    def generate(self,
+    def generate(self, *args, **kwargs) -> np.ndarray:
+        return next(self._generate(*args, streaming=False, **kwargs))
+
+    def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
+        return self._generate(*args, streaming=True, **kwargs)
+
+    def _generate(self,
                  text : str,
                  prompt_wav_path : str = None,
                  prompt_text : str = None,
@@ -96,7 +108,8 @@ class VoxCPM:
                  retry_badcase : bool = True,
                  retry_badcase_max_times : int = 3,
                  retry_badcase_ratio_threshold : float = 6.0,
-                 ):
+                 streaming: bool = False,
+                 ) -> Generator[np.ndarray, None, None]:
         """Synthesize speech for the given text and return a single waveform.

         This method optionally builds and reuses a prompt cache. If an external
@@ -118,8 +131,11 @@ class VoxCPM:
             retry_badcase: Whether to retry badcase.
             retry_badcase_max_times: Maximum number of times to retry badcase.
             retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
+            streaming: Whether to return a generator of audio chunks.
         Returns:
-            numpy.ndarray: 1D waveform array (float32) on CPU.
+            Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
+                Yields audio chunks for each generation step if ``streaming=True``,
+                otherwise yields a single array containing the final audio.
         """
         if not text.strip() or not isinstance(text, str):
             raise ValueError("target text must be a non-empty string")
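The docstring above captures the new contract: the public `generate` still returns one final waveform, while `generate_streaming` yields float32 chunks as they are produced. A consumption sketch under stated assumptions: the Hub model id, the `soundfile` dependency, and reading the sample rate via `model.tts_model.sample_rate` are illustrative, not part of this commit:

    from voxcpm import VoxCPM  # assumes the package exports VoxCPM at top level
    import numpy as np
    import soundfile as sf     # assumed WAV writer, not a VoxCPM dependency

    model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")  # hypothetical model id

    # Blocking path: one final 1-D float32 waveform.
    wav = model.generate(text="Hello from VoxCPM.")

    # Streaming path: consume chunks as the model emits them.
    chunks = list(model.generate_streaming(text="Hello from VoxCPM."))
    sf.write("out.wav", np.concatenate(chunks), model.tts_model.sample_rate)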
@@ -155,7 +171,7 @@ class VoxCPM:
                     self.text_normalizer = TextNormalizer()
                 text = self.text_normalizer.normalize(text)

-                wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
+                generate_result = self.tts_model._generate_with_prompt_cache(
                     target_text=text,
                     prompt_cache=fixed_prompt_cache,
                     min_len=2,
@@ -165,9 +181,11 @@ class VoxCPM:
                     retry_badcase=retry_badcase,
                     retry_badcase_max_times=retry_badcase_max_times,
                     retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+                    streaming=streaming,
                 )

-                return wav.squeeze(0).cpu().numpy()
+                for wav, _, _ in generate_result:
+                    yield wav.squeeze(0).cpu().numpy()

         finally:
             if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):

@@ -19,11 +19,12 @@ limitations under the License.
 """

 import os
-from typing import Dict, Optional, Tuple, Union
+from typing import Tuple, Union, Generator, List

 import torch
 import torch.nn as nn
 import torchaudio
+import warnings
 from einops import rearrange
 from pydantic import BaseModel
 from tqdm import tqdm
@@ -147,8 +148,10 @@ class VoxCPMModel(nn.Module):
         self.sample_rate = audio_vae.sample_rate


-    def optimize(self):
+    def optimize(self, disable: bool = False):
         try:
+            if disable:
+                raise ValueError("Optimization disabled by user")
             if self.device != "cuda":
                 raise ValueError("VoxCPMModel can only be optimized on CUDA device")
             try:
@@ -169,8 +172,14 @@ class VoxCPMModel(nn.Module):
         return self


+    def generate(self, *args, **kwargs) -> torch.Tensor:
+        return next(self._generate(*args, streaming=False, **kwargs))
+
+    def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
+        return self._generate(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def generate(
+    def _generate(
         self,
         target_text: str,
         prompt_text: str = "",
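The same two-wrapper pattern now repeats at every level of the stack (`generate`/`generate_streaming` here, plus the prompt-cache and inference variants further below): a single private generator implements both modes, the blocking wrapper drains it with `next(...)`, and the streaming wrapper hands the generator straight to the caller. A self-contained sketch of the pattern with a toy function, not code from this repository:

    from typing import Generator

    def _count(streaming: bool = False) -> Generator[int, None, None]:
        total = 0
        for step in range(1, 4):
            total += step
            if streaming:
                yield total    # intermediate value each step
        if not streaming:
            yield total        # single final value

    def count() -> int:
        return next(_count(streaming=False))   # drains the lone final yield

    def count_streaming() -> Generator[int, None, None]:
        return _count(streaming=True)          # caller iterates step by step

    assert count() == 6
    assert list(count_streaming()) == [1, 3, 6]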
@@ -182,7 +191,11 @@ class VoxCPMModel(nn.Module):
         retry_badcase: bool = False,
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0,  # setting acceptable ratio of audio length to text length (for badcase detection)
-    ):
+        streaming: bool = False,
+    ) -> Generator[torch.Tensor, None, None]:
+        if retry_badcase and streaming:
+            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+            retry_badcase = False
         if len(prompt_wav_path) == 0:
             text = target_text
             text_token = torch.LongTensor(self.text_tokenizer(text))
@@ -265,7 +278,7 @@ class VoxCPMModel(nn.Module):

         retry_badcase_times = 0
         while retry_badcase_times < retry_badcase_max_times:
-            latent_pred, pred_audio_feat = self.inference(
+            inference_result = self._inference(
                 text_token,
                 text_mask,
                 audio_feat,
@@ -274,7 +287,17 @@ class VoxCPMModel(nn.Module):
                 max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                 inference_timesteps=inference_timesteps,
                 cfg_value=cfg_value,
+                streaming=streaming,
             )
+            if streaming:
+                patch_len = self.patch_size * self.chunk_size
+                for latent_pred, _ in inference_result:
+                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                    yield decode_audio
+                break
+            else:
+                latent_pred, pred_audio_feat = next(inference_result)
             if retry_badcase:
                 if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                     print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
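In the streaming branch above, each step receives a short window of latents (the last three generation steps, as `_inference` yields them further below) so the VAE decoder has context, then keeps only the newest `patch_len` samples; consecutive chunks therefore never overlap. A toy illustration of that trimming, with all sizes invented:

    import torch

    patch_size, chunk_size = 2, 320
    patch_len = patch_size * chunk_size               # samples contributed per step

    # Pretend the VAE just decoded a 3-patch context window into raw audio:
    decoded = torch.randn(1, 1, 3 * patch_len)        # (batch, channel, samples)
    new_chunk = decoded[..., -patch_len:].squeeze(1)  # keep only the newest patch
    assert new_chunk.shape == (1, patch_len)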
@@ -285,9 +308,10 @@ class VoxCPMModel(nn.Module):
                 else:
                     break

+        if not streaming:
             decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
             decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
-        return decode_audio
+            yield decode_audio

     @torch.inference_mode()
     def build_prompt_cache(
@@ -377,8 +401,16 @@ class VoxCPMModel(nn.Module):

         return merged_cache

+    def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
+
+    def generate_with_prompt_cache_streaming(
+        self, *args, **kwargs
+    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
+        return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def generate_with_prompt_cache(
+    def _generate_with_prompt_cache(
         self,
         target_text: str,
         prompt_cache: dict,
@@ -389,7 +421,8 @@ class VoxCPMModel(nn.Module):
         retry_badcase: bool = False,
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0,
-    ):
+        streaming: bool = False,
+    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """
         Generate audio using pre-built prompt cache.

@@ -403,10 +436,17 @@ class VoxCPMModel(nn.Module):
             retry_badcase: Whether to retry on bad cases
             retry_badcase_max_times: Maximum retry attempts
             retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
+            streaming: Whether to return a generator of audio chunks

         Returns:
-            tuple: (decoded audio tensor, new text tokens, new audio features)
+            Generator of Tuple containing:
+                - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
+                - Tensor of new text tokens
+                - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
         """
+        if retry_badcase and streaming:
+            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+            retry_badcase = False
         # get prompt from cache
         if prompt_cache is None:
             prompt_text_token = torch.empty(0, dtype=torch.int32)
@@ -451,7 +491,7 @@ class VoxCPMModel(nn.Module):
         target_text_length = len(self.text_tokenizer(target_text))
         retry_badcase_times = 0
         while retry_badcase_times < retry_badcase_max_times:
-            latent_pred, pred_audio_feat = self.inference(
+            inference_result = self._inference(
                 text_token,
                 text_mask,
                 audio_feat,
@@ -460,7 +500,21 @@ class VoxCPMModel(nn.Module):
                 max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                 inference_timesteps=inference_timesteps,
                 cfg_value=cfg_value,
+                streaming=streaming,
             )
+            if streaming:
+                patch_len = self.patch_size * self.chunk_size
+                for latent_pred, pred_audio_feat in inference_result:
+                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                    yield (
+                        decode_audio,
+                        target_text_token,
+                        pred_audio_feat
+                    )
+                break
+            else:
+                latent_pred, pred_audio_feat = next(inference_result)
             if retry_badcase:
                 if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                     print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
@@ -470,17 +524,24 @@ class VoxCPMModel(nn.Module):
                     break
                 else:
                     break
+        if not streaming:
             decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
             decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio

-        return (
+            yield (
                 decode_audio,
                 target_text_token,
                 pred_audio_feat
             )

+    def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+        return next(self._inference(*args, streaming=False, **kwargs))
+
+    def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
+        return self._inference(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def inference(
+    def _inference(
         self,
         text: torch.Tensor,
         text_mask: torch.Tensor,
@@ -490,7 +551,8 @@ class VoxCPMModel(nn.Module):
         max_len: int = 2000,
         inference_timesteps: int = 10,
         cfg_value: float = 2.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        streaming: bool = False,
+    ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """Core inference method for audio generation.

         This is the main inference loop that generates audio features
@@ -505,11 +567,12 @@ class VoxCPMModel(nn.Module):
             max_len: Maximum generation length
             inference_timesteps: Number of diffusion steps
             cfg_value: Classifier-free guidance value
+            streaming: Whether to yield each step's latent feature or just the final result

         Returns:
-            Tuple containing:
-                - Predicted latent features
-                - Predicted audio feature sequence
+            Generator of Tuple containing:
+                - Predicted latent feature at the current step if ``streaming=True``, else final latent features
+                - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
         """
         B, T, P, D = feat.shape

@@ -567,6 +630,12 @@ class VoxCPMModel(nn.Module):
             pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
             prefix_feat_cond = pred_feat

+            if streaming:
+                # return the last three predicted latent features to provide enough context for smooth decoding
+                pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
+                feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+                yield feat_pred, pred_feat_seq
+
             stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
             if i > min_len and stop_flag == 1:
                 break
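Each streamed step thus yields a rolling window of the three most recent latent patches rather than a single patch; the consumers shown earlier decode the window and keep only the newest `patch_len` samples. A toy shape-check of that windowing, with all sizes invented:

    import torch
    from einops import rearrange

    B, P, D = 1, 2, 4  # batch, patch size, latent dim
    pred_feat_seq = [torch.randn(B, 1, P, D) for _ in range(5)]  # one entry per step

    window = torch.cat(pred_feat_seq[-3:], dim=1)  # last 3 steps -> (B, 3, P, D)
    feat_pred = rearrange(window, "b t p d -> b d (t p)")
    assert feat_pred.shape == (B, D, 3 * P)        # decoder sees 3 patches of context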
@@ -581,13 +650,14 @@ class VoxCPMModel(nn.Module):
                 lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
             ).clone()

+        if not streaming:
             pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d

             feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
-        return feat_pred, pred_feat_seq.squeeze(0).cpu()
+            yield feat_pred, pred_feat_seq.squeeze(0).cpu()

     @classmethod
-    def from_local(cls, path: str):
+    def from_local(cls, path: str, optimize: bool = True):
         config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())

         tokenizer = LlamaTokenizerFast.from_pretrained(path)
@@ -613,4 +683,4 @@ class VoxCPMModel(nn.Module):
         for kw, val in vae_state_dict.items():
             model_state_dict[f"audio_vae.{kw}"] = val
         model.load_state_dict(model_state_dict, strict=True)
-        return model.to(model.device).eval().optimize()
+        return model.to(model.device).eval().optimize(disable=not optimize)
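This closes the chain for the debugging switch: `VoxCPM(optimize=False)` forwards to `VoxCPMModel.from_local(path, optimize=False)`, which calls `.optimize(disable=True)`; the `disable` check raises, which is presumably caught by the surrounding try/except (not shown in these hunks) so the uncompiled model is returned. A model-level sketch, with the checkpoint path as a placeholder:

    from voxcpm.model.voxcpm import VoxCPMModel

    # Load the bare model without torch.compile; this is also effectively the
    # path taken on non-CUDA devices, where optimize() already falls back.
    model = VoxCPMModel.from_local("/path/to/VoxCPM-0.5B", optimize=False)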