Add a streaming API for VoxCPM
@@ -19,11 +19,12 @@ limitations under the License.
 """

 import os
-from typing import Dict, Optional, Tuple, Union
+from typing import Tuple, Union, Generator, List

 import torch
 import torch.nn as nn
 import torchaudio
+import warnings
 from einops import rearrange
 from pydantic import BaseModel
 from tqdm import tqdm
@@ -147,8 +148,10 @@ class VoxCPMModel(nn.Module):
         self.sample_rate = audio_vae.sample_rate


-    def optimize(self):
+    def optimize(self, disable: bool = False):
         try:
+            if disable:
+                raise ValueError("Optimization disabled by user")
             if self.device != "cuda":
                 raise ValueError("VoxCPMModel can only be optimized on CUDA device")
             try:
@@ -169,8 +172,14 @@ class VoxCPMModel(nn.Module):
         return self


+    def generate(self, *args, **kwargs) -> torch.Tensor:
+        return next(self._generate(*args, streaming=False, **kwargs))
+
+    def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
+        return self._generate(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def generate(
+    def _generate(
         self,
         target_text: str,
         prompt_text: str = "",
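The new public entry points are thin wrappers: `generate` drains the single result from the internal generator, while `generate_streaming` hands the generator back to the caller so chunks can be consumed as they are decoded. A minimal consumption sketch follows; the import path, checkpoint path, and output filename are assumptions, not part of this commit.

```python
import torch
import torchaudio

from voxcpm.model.voxcpm import VoxCPMModel  # assumed import path for this module

model = VoxCPMModel.from_local("/path/to/voxcpm_checkpoint")  # placeholder checkpoint path

# Streaming: chunks are yielded as soon as they are decoded, so playback can start early.
chunks = []
for chunk in model.generate_streaming(target_text="Hello from VoxCPM."):
    chunks.append(chunk.cpu())  # each chunk is a short waveform tensor (batch dim first)

audio = torch.cat(chunks, dim=-1)  # stitch the chunks back into one waveform
torchaudio.save("streamed.wav", audio, model.sample_rate)
```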
@@ -182,7 +191,11 @@ class VoxCPMModel(nn.Module):
         retry_badcase: bool = False,
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0,  # setting acceptable ratio of audio length to text length (for badcase detection)
-    ):
+        streaming: bool = False,
+    ) -> Generator[torch.Tensor, None, None]:
+        if retry_badcase and streaming:
+            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+            retry_badcase = False
         if len(prompt_wav_path) == 0:
             text = target_text
             text_token = torch.LongTensor(self.text_tokenizer(text))
@@ -265,7 +278,7 @@ class VoxCPMModel(nn.Module):

         retry_badcase_times = 0
         while retry_badcase_times < retry_badcase_max_times:
-            latent_pred, pred_audio_feat = self.inference(
+            inference_result = self._inference(
                 text_token,
                 text_mask,
                 audio_feat,
@@ -274,20 +287,31 @@ class VoxCPMModel(nn.Module):
                 max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                 inference_timesteps=inference_timesteps,
                 cfg_value=cfg_value,
+                streaming=streaming,
             )
-            if retry_badcase:
-                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
-                    retry_badcase_times += 1
-                    continue
-                else:
-                    break
+            if streaming:
+                patch_len = self.patch_size * self.chunk_size
+                for latent_pred, _ in inference_result:
+                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                    yield decode_audio
+                break
             else:
-                break
+                latent_pred, pred_audio_feat = next(inference_result)
+                if retry_badcase:
+                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
+                        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                        retry_badcase_times += 1
+                        continue
+                    else:
+                        break
+                else:
+                    break

-        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
-        decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
-        return decode_audio
+        if not streaming:
+            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+            decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
+            yield decode_audio

     @torch.inference_mode()
     def build_prompt_cache(
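In the streaming branch the VAE decodes a short rolling window of latents each step but keeps only the trailing `patch_size * chunk_size` samples, so concatenating the yielded chunks reproduces the waveform without duplicated audio (the window itself comes from the `pred_feat_seq[-3:]` slice further down in `_inference`). A toy sketch of this overlap-then-trim idea, with a stand-in decoder rather than the real `audio_vae`:

```python
import torch

PATCH_LEN = 4  # stands in for self.patch_size * self.chunk_size

def fake_decode(latents: torch.Tensor) -> torch.Tensor:
    # Toy decoder: every latent frame expands to PATCH_LEN identical samples.
    return latents.repeat_interleave(PATCH_LEN, dim=-1)

latents = torch.arange(1, 6, dtype=torch.float32).unsqueeze(0)  # 5 latent frames

chunks = []
for i in range(latents.shape[-1]):
    window = latents[:, max(0, i - 2): i + 1]      # last up-to-3 frames give the decoder context
    chunk = fake_decode(window)[..., -PATCH_LEN:]  # keep only the newest frame's samples
    chunks.append(chunk)

streamed = torch.cat(chunks, dim=-1)
assert torch.equal(streamed, fake_decode(latents))  # chunks tile the waveform exactly once
```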
@@ -376,9 +400,17 @@ class VoxCPMModel(nn.Module):
         }

         return merged_cache

+    def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
+
+    def generate_with_prompt_cache_streaming(
+        self, *args, **kwargs
+    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
+        return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def generate_with_prompt_cache(
+    def _generate_with_prompt_cache(
         self,
         target_text: str,
         prompt_cache: dict,
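The same wrapper pattern sits on top of a prompt cache, so cloned-voice generation can stream as well. A hedged sketch, reusing `model` from the first sketch above; the `build_prompt_cache` argument names and file paths are assumptions, since that method's signature is not part of this hunk:

```python
import torch

# Build a reusable prompt cache once, then stream any number of target sentences with it.
cache = model.build_prompt_cache(
    prompt_text="Transcript of the reference recording.",  # assumed parameter name
    prompt_wav_path="/path/to/reference.wav",              # assumed parameter name, placeholder path
)

received = []
for chunk, new_text_tokens, audio_feat_so_far in model.generate_with_prompt_cache_streaming(
    target_text="Text to speak in the cloned voice.",
    prompt_cache=cache,
):
    received.append(chunk)  # audio_feat_so_far is a list of per-step features in streaming mode

audio = torch.cat(received, dim=-1)
```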
@@ -389,7 +421,8 @@ class VoxCPMModel(nn.Module):
         retry_badcase: bool = False,
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0,
-    ):
+        streaming: bool = False,
+    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """
         Generate audio using pre-built prompt cache.

@@ -403,10 +436,17 @@ class VoxCPMModel(nn.Module):
             retry_badcase: Whether to retry on bad cases
             retry_badcase_max_times: Maximum retry attempts
             retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
+            streaming: Whether to return a generator of audio chunks

         Returns:
-            tuple: (decoded audio tensor, new text tokens, new audio features)
+            Generator of Tuple containing:
+                - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
+                - Tensor of new text tokens
+                - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
         """
+        if retry_badcase and streaming:
+            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+            retry_badcase = False
         # get prompt from cache
         if prompt_cache is None:
             prompt_text_token = torch.empty(0, dtype=torch.int32)
@@ -451,7 +491,7 @@ class VoxCPMModel(nn.Module):
         target_text_length = len(self.text_tokenizer(target_text))
         retry_badcase_times = 0
         while retry_badcase_times < retry_badcase_max_times:
-            latent_pred, pred_audio_feat = self.inference(
+            inference_result = self._inference(
                 text_token,
                 text_mask,
                 audio_feat,
@@ -460,27 +500,48 @@ class VoxCPMModel(nn.Module):
                 max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                 inference_timesteps=inference_timesteps,
                 cfg_value=cfg_value,
+                streaming=streaming,
             )
-            if retry_badcase:
-                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
-                    retry_badcase_times += 1
-                    continue
-                else:
-                    break
+            if streaming:
+                patch_len = self.patch_size * self.chunk_size
+                for latent_pred, pred_audio_feat in inference_result:
+                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                    yield (
+                        decode_audio,
+                        target_text_token,
+                        pred_audio_feat
+                    )
+                break
             else:
-                break
-        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
-        decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
+                latent_pred, pred_audio_feat = next(inference_result)
+                if retry_badcase:
+                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
+                        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                        retry_badcase_times += 1
+                        continue
+                    else:
+                        break
+                else:
+                    break
+        if not streaming:
+            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+            decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio

-        return (
-            decode_audio,
-            target_text_token,
-            pred_audio_feat
-        )
+            yield (
+                decode_audio,
+                target_text_token,
+                pred_audio_feat
+            )

+    def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+        return next(self._inference(*args, streaming=False, **kwargs))
+
+    def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
+        return self._inference(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def inference(
+    def _inference(
         self,
         text: torch.Tensor,
         text_mask: torch.Tensor,
@@ -490,7 +551,8 @@ class VoxCPMModel(nn.Module):
         max_len: int = 2000,
         inference_timesteps: int = 10,
         cfg_value: float = 2.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        streaming: bool = False,
+    ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """Core inference method for audio generation.

         This is the main inference loop that generates audio features
@@ -505,11 +567,12 @@ class VoxCPMModel(nn.Module):
             max_len: Maximum generation length
             inference_timesteps: Number of diffusion steps
             cfg_value: Classifier-free guidance value
+            streaming: Whether to yield each step latent feature or just the final result

         Returns:
-            Tuple containing:
-                - Predicted latent features
-                - Predicted audio feature sequence
+            Generator of Tuple containing:
+                - Predicted latent feature at the current step if ``streaming=True``, else final latent features
+                - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
         """
         B, T, P, D = feat.shape

@@ -566,6 +629,12 @@ class VoxCPMModel(nn.Module):

             pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
             prefix_feat_cond = pred_feat

+            if streaming:
+                # return the last three predicted latent features to provide enough context for smooth decoding
+                pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
+                feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+                yield feat_pred, pred_feat_seq
+
             stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
             if i > min_len and stop_flag == 1:
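Each streaming step yields the last up-to-three predicted patches so the VAE decoder has left context, and the callers above then trim to the newest `patch_size * chunk_size` samples. Under assumed example values (not taken from this diff), that fixes the chunk size and latency per yield:

```python
# Example numbers only; the real values come from the loaded model / VAE config.
patch_size = 2          # assumed number of latent frames per LM step
vae_chunk_size = 640    # assumed waveform samples per latent frame
sample_rate = 16000     # assumed output sample rate

patch_len = patch_size * vae_chunk_size   # samples kept from each streaming yield
chunk_ms = 1000 * patch_len / sample_rate
print(f"{patch_len} samples per chunk ≈ {chunk_ms:.0f} ms of audio")  # 1280 samples ≈ 80 ms
```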
@@ -581,13 +650,14 @@ class VoxCPMModel(nn.Module):
                 lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
             ).clone()

-        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
+        if not streaming:
+            pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d

-        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
-        return feat_pred, pred_feat_seq.squeeze(0).cpu()
+            feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+            yield feat_pred, pred_feat_seq.squeeze(0).cpu()

     @classmethod
-    def from_local(cls, path: str):
+    def from_local(cls, path: str, optimize: bool = True):
         config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())

         tokenizer = LlamaTokenizerFast.from_pretrained(path)
@@ -613,4 +683,4 @@ class VoxCPMModel(nn.Module):
         for kw, val in vae_state_dict.items():
             model_state_dict[f"audio_vae.{kw}"] = val
         model.load_state_dict(model_state_dict, strict=True)
-        return model.to(model.device).eval().optimize()
+        return model.to(model.device).eval().optimize(disable=not optimize)
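Finally, `from_local` gains an `optimize` flag that is forwarded as `optimize(disable=not optimize)`, so the optimization pass can be skipped even on CUDA, for example to avoid warm-up cost during development. A short sketch with a placeholder checkpoint path:

```python
# Load without the optimization pass (e.g. to avoid warm-up cost while iterating).
model = VoxCPMModel.from_local("/path/to/voxcpm_checkpoint", optimize=False)

# Non-streaming generation still returns a single waveform tensor in one call.
audio = model.generate(target_text="One-shot, non-streaming synthesis.")
```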