Files
VoxCPM-use/src/voxcpm/model/voxcpm.py
2025-09-18 12:01:26 +08:00

614 lines
23 KiB
Python

"""
VoxCPM: A Tokenizer-free speech generation model
This module contains the main VoxCPM model implementation, including configuration classes
and the core VoxCPMModel for text-to-speech generation.
Copyright 2025 OpenBMB
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torchaudio
from einops import rearrange
from pydantic import BaseModel
from tqdm import tqdm
from transformers import LlamaTokenizerFast
from ..modules.audiovae import AudioVAE
from ..modules.layers import ScalarQuantizationLayer
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
class VoxCPMEncoderConfig(BaseModel):
    """Hyper-parameters of the local feature encoder.

    ``VoxCPMModel.__init__`` copies these values onto a deep copy of the
    language-model config (``hidden_size``, ``intermediate_size``,
    ``num_attention_heads``, ``num_hidden_layers``, ``kv_channels``) before
    building ``VoxCPMLocEnc``.
    """

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # The default is None, so the annotation has to admit None as well;
    # with a plain ``int`` an explicit ``"kv_channels": null`` in config.json
    # (or a dump/validate round trip) is rejected by validation.
    kv_channels: Optional[int] = None
class VoxCPMDitConfig(BaseModel):
    """Hyper-parameters of the local DiT decoder and its flow-matching wrapper.

    ``VoxCPMModel.__init__`` copies the transformer sizes onto a deep copy of
    the language-model config before building ``VoxCPMLocDiT`` and passes
    ``cfm_config`` to ``UnifiedCFM`` as ``cfm_params``.
    """

    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    # The default is None, so the annotation has to admit None as well;
    # with a plain ``int`` an explicit ``"kv_channels": null`` in config.json
    # (or a dump/validate round trip) is rejected by validation.
    kv_channels: Optional[int] = None

    # Required: there is no default, so it must be present in the config.
    cfm_config: CfmConfig
class VoxCPMConfig(BaseModel):
    """Top-level configuration consumed by ``VoxCPMModel``.

    ``VoxCPMModel.from_local`` parses this model from ``config.json``.
    """

    # Config of the base language model; deep copies of it are also used as
    # the template for the residual LM, the local encoder and the local DiT.
    lm_config: MiniCPM4Config
    # Size of the patch axis of the audio feature tensors the model builds.
    patch_size: int = 2
    # Passed to the local encoder as ``input_dim`` and to the decoder as
    # ``in_channels``.
    feat_dim: int = 64
    # Overrides ``num_hidden_layers`` on the residual LM's config copy.
    residual_lm_num_layers: int = 6
    # Both values are forwarded to ``ScalarQuantizationLayer``.
    scalar_quantization_latent_dim: int = 256
    scalar_quantization_scale: int = 9
    encoder_config: VoxCPMEncoderConfig
    dit_config: VoxCPMDitConfig
    # Passed to ``setup_cache`` of both language models.
    max_length: int = 4096
    # Requested device; the model substitutes "cpu" when CUDA is unavailable.
    device: str = "cuda"
    # Dtype name, resolved through ``get_dtype`` by the model.
    dtype: str = "bfloat16"
class VoxCPMModel(nn.Module):
    """Tokenizer-free speech generation model.

    Bundles a text-semantic language model, a residual acoustic language
    model, a local feature encoder, a local DiT decoder wrapped in
    ``UnifiedCFM``, a stop predictor and an audio VAE.
    """

    def __init__(
        self,
        config: VoxCPMConfig,
        tokenizer: LlamaTokenizerFast,
        audio_vae: AudioVAE,
    ):
        """Build every sub-module described by ``config``.

        Args:
            config: Model configuration.
            tokenizer: Text tokenizer; it is passed through
                ``mask_multichar_chinese_tokens`` before being stored.
            audio_vae: Audio VAE instance; its ``chunk_size`` and
                ``sample_rate`` are mirrored onto the model.
        """
        super().__init__()

        def _local_lm_config(local_cfg):
            # Derive a transformer config for a local module (encoder / DiT)
            # from the LM config, overriding the sizes and dropping the vocab.
            derived = config.lm_config.model_copy(deep=True)
            derived.hidden_size = local_cfg.hidden_dim
            derived.intermediate_size = local_cfg.ffn_dim
            derived.num_attention_heads = local_cfg.num_heads
            derived.num_hidden_layers = local_cfg.num_layers
            derived.kv_channels = local_cfg.kv_channels
            derived.vocab_size = 0
            return derived

        self.config = config
        self.feat_dim = config.feat_dim
        self.patch_size = config.patch_size
        # Without CUDA the configured device is ignored and "cpu" is used.
        self.device = config.device if torch.cuda.is_available() else "cpu"

        # Text-Semantic LM
        self.base_lm = MiniCPMModel(config.lm_config)
        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))

        self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        self.audio_start_token = 101
        self.audio_end_token = 102

        # Residual Acoustic LM: same template as the base LM, but with its own
        # depth and no vocabulary.
        residual_cfg = config.lm_config.model_copy(deep=True)
        residual_cfg.num_hidden_layers = config.residual_lm_num_layers
        residual_cfg.vocab_size = 0
        self.residual_lm = MiniCPMModel(residual_cfg)
        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))

        # Local Encoder
        self.feat_encoder = VoxCPMLocEnc(
            _local_lm_config(config.encoder_config),
            input_dim=config.feat_dim,
        )

        # Local DiT
        dit_estimator = VoxCPMLocDiT(
            _local_lm_config(config.dit_config),
            in_channels=config.feat_dim,
        )
        self.feat_decoder = UnifiedCFM(
            in_channels=config.feat_dim,
            cfm_params=config.dit_config.cfm_config,
            estimator=dit_estimator,
        )

        # Projection layers
        lm_width = config.lm_config.hidden_size
        dit_width = config.dit_config.hidden_dim
        self.fsq_layer = ScalarQuantizationLayer(
            lm_width,
            lm_width,
            config.scalar_quantization_latent_dim,
            config.scalar_quantization_scale,
        )
        self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, lm_width)
        self.lm_to_dit_proj = nn.Linear(lm_width, dit_width)
        self.res_to_dit_proj = nn.Linear(lm_width, dit_width)

        # Stop Predictor: projection -> SiLU -> 2-way head
        self.stop_proj = nn.Linear(lm_width, lm_width)
        self.stop_actn = nn.SiLU()
        self.stop_head = nn.Linear(lm_width, 2, bias=False)

        # Audio VAE
        self.audio_vae = audio_vae
        self.chunk_size = audio_vae.chunk_size
        self.sample_rate = audio_vae.sample_rate
def optimize(self):
    """Try to speed up the step functions with ``torch.compile``.

    Compiles the single-step forward of both language models, the feature
    encoder (stored as ``self.feat_encoder_step``) and the decoder's
    estimator. Compilation is only attempted on a CUDA device with triton
    installed; on any failure the eager callables are put back, a message is
    printed, and ``self.feat_encoder_step`` points to the plain
    ``self.feat_encoder``.

    Returns:
        ``self``, so the call can be chained.
    """
    # Remember the eager callables so that a failure part-way through the
    # compile sequence cannot leave the model half compiled.
    eager_base_step = self.base_lm.forward_step
    eager_residual_step = self.residual_lm.forward_step
    eager_estimator = self.feat_decoder.estimator
    try:
        # Compare the device *type* so that "cuda:0" etc. are accepted too.
        if torch.device(self.device).type != "cuda":
            raise ValueError("VoxCPMModel can only be optimized on CUDA device")
        try:
            import triton  # noqa: F401  (availability probe only)
        except ImportError as err:
            raise ValueError("triton is not installed") from err
        self.base_lm.forward_step = torch.compile(eager_base_step, mode="reduce-overhead", fullgraph=True)
        self.residual_lm.forward_step = torch.compile(eager_residual_step, mode="reduce-overhead", fullgraph=True)
        self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
        self.feat_decoder.estimator = torch.compile(eager_estimator, mode="reduce-overhead", fullgraph=True)
    except Exception as e:
        print(e)
        print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
        self.base_lm.forward_step = eager_base_step
        self.residual_lm.forward_step = eager_residual_step
        self.feat_encoder_step = self.feat_encoder
        self.feat_decoder.estimator = eager_estimator
    return self
@torch.inference_mode()
def generate(
    self,
    target_text: str,
    prompt_text: str = "",
    prompt_wav_path: str = "",
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    retry_badcase: bool = False,
    retry_badcase_max_times: int = 3,
    retry_badcase_ratio_threshold: float = 6.0,
):
    """Synthesize speech for ``target_text``, optionally continuing a prompt.

    Args:
        target_text: Text to convert to speech.
        prompt_text: Transcript of the prompt audio; only used when
            ``prompt_wav_path`` is given, in which case it is prepended to
            ``target_text`` before tokenization.
        prompt_wav_path: Path of the prompt audio. Empty (or None) means no
            prompt is used.
        min_len: Forwarded to ``inference`` as the minimum generation length.
        max_len: Maximum generation length. Ignored when ``retry_badcase``
            is set; the limit is then derived from the target text length
            and ``retry_badcase_ratio_threshold``.
        inference_timesteps: Number of diffusion sampling steps.
        cfg_value: Classifier-free guidance value.
        retry_badcase: Whether to re-run inference when the output is
            suspiciously long for the text.
        retry_badcase_max_times: Maximum number of inference attempts when
            ``retry_badcase`` is set (at least one attempt is always made).
        retry_badcase_ratio_threshold: Acceptable ratio of generated audio
            length to target text length (for bad-case detection).

    Returns:
        The decoded audio tensor, on CPU.
    """
    has_prompt = bool(prompt_wav_path)
    text = prompt_text + target_text if has_prompt else target_text

    text_token = torch.LongTensor(self.text_tokenizer(text))
    text_token = torch.cat(
        [
            text_token,
            torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
        ],
        dim=-1,
    )
    text_length = text_token.shape[0]
    latent_dim = self.audio_vae.latent_dim

    if has_prompt:
        audio, sr = torchaudio.load(prompt_wav_path)
        if audio.size(0) > 1:
            # down-mix multi-channel audio to mono
            audio = audio.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
        # right-pad so the waveform length is a multiple of one patch
        patch_len = self.patch_size * self.chunk_size
        remainder = audio.size(1) % patch_len
        if remainder != 0:
            audio = torch.nn.functional.pad(audio, (0, patch_len - remainder))
        # (B, D, T)
        prompt_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
        # (D, T', P) -> (T', P, D)
        prompt_feat = prompt_feat.view(latent_dim, -1, self.patch_size).permute(1, 2, 0)
        prompt_feat = prompt_feat[:-1, ...]  # trick: remove the last padding token
    else:
        # No prompt: an empty feature sequence lets both cases share the
        # layout code below.
        prompt_feat = torch.zeros(
            (0, self.patch_size, latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
    audio_length = prompt_feat.size(0)

    # Layout: [text tokens + audio-start | prompt audio patches]; each stream
    # is zero-padded over the other stream's positions and the masks select
    # which stream is valid at each position.
    text_token = torch.cat(
        [text_token, torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)]
    )
    audio_feat = torch.cat(
        [
            torch.zeros(
                (text_length, self.patch_size, latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            ),
            prompt_feat,
        ],
        dim=0,
    )
    text_mask = (
        torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
    )
    audio_mask = (
        torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
    )

    # Follow the dtype the projection weights actually have instead of
    # assuming bfloat16, so models loaded in another dtype keep working.
    model_dtype = self.enc_to_lm_proj.weight.dtype
    text_token = text_token.unsqueeze(0).to(self.device)
    text_mask = text_mask.unsqueeze(0).to(self.device)
    audio_feat = audio_feat.unsqueeze(0).to(self.device).to(model_dtype)
    audio_mask = audio_mask.unsqueeze(0).to(self.device)

    target_text_length = len(self.text_tokenizer(target_text))
    badcase_len = target_text_length * retry_badcase_ratio_threshold
    effective_max_len = int(badcase_len + 10) if retry_badcase else max_len

    # Always run at least once: with a non-positive retry budget the result
    # variables would otherwise never be assigned.
    attempts = max(1, retry_badcase_max_times) if retry_badcase else 1
    for _ in range(attempts):
        latent_pred, pred_audio_feat = self.inference(
            text_token,
            text_mask,
            audio_feat,
            audio_mask,
            min_len=min_len,
            max_len=effective_max_len,
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
        )
        if not retry_badcase or pred_audio_feat.shape[0] < badcase_len:
            break
        # max(..., 1) keeps the diagnostic from dividing by zero on empty text
        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / max(target_text_length, 1)}, retrying...")
    return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
@torch.inference_mode()
def build_prompt_cache(
    self,
    prompt_text: str,
    prompt_wav_path: str,
):
    """
    Pre-compute the prompt inputs so later generations can reuse them.

    Args:
        prompt_text: transcript of the prompt audio (required)
        prompt_wav_path: path of the prompt audio file (required)

    Returns:
        dict with the prompt's "text_token" tensor and its "audio_feat"
        tensor of latent patches

    Raises:
        ValueError: if either argument is empty
    """
    if not (prompt_text and prompt_wav_path):
        raise ValueError("prompt_text and prompt_wav_path are required")

    # text side
    token_ids = torch.LongTensor(self.text_tokenizer(prompt_text))

    # audio side: mono, at the VAE sample rate, padded to whole patches
    waveform, source_rate = torchaudio.load(prompt_wav_path)
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if source_rate != self.sample_rate:
        waveform = torchaudio.functional.resample(waveform, source_rate, self.sample_rate)
    samples_per_patch = self.patch_size * self.chunk_size
    leftover = waveform.size(1) % samples_per_patch
    if leftover != 0:
        waveform = torch.nn.functional.pad(waveform, (0, samples_per_patch - leftover))

    # encode, then regroup the latents: (D, T, P) view -> (T, P, D)
    latents = self.audio_vae.encode(waveform.to(self.device), self.sample_rate).cpu()
    patches = latents.view(self.audio_vae.latent_dim, -1, self.patch_size)
    patches = patches.permute(1, 2, 0)
    patches = patches[:-1, ...]  # trick: remove the last padding token

    return {
        "text_token": token_ids,
        "audio_feat": patches,
    }
def merge_prompt_cache(
    self,
    original_cache: dict,
    new_text_token: torch.Tensor,
    new_audio_feat: torch.Tensor,
):
    """
    Append freshly generated content to a prompt cache to stabilize voice.

    Args:
        original_cache: existing prompt cache, or None to start a new one
        new_text_token: newly generated text tokens
        new_audio_feat: newly generated audio features

    Returns:
        merged_cache: dict with "text_token" and "audio_feat"; when
        ``original_cache`` is None the new tensors are stored as they are,
        otherwise each is concatenated after the cached one along dim 0
    """
    additions = {
        "text_token": new_text_token,
        "audio_feat": new_audio_feat,
    }
    if original_cache is None:
        return additions
    return {
        key: torch.cat([original_cache[key], fresh], dim=0)
        for key, fresh in additions.items()
    }
@torch.inference_mode()
def generate_with_prompt_cache(
    self,
    target_text: str,
    prompt_cache: dict,
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
    retry_badcase: bool = False,
    retry_badcase_max_times: int = 3,
    retry_badcase_ratio_threshold: float = 6.0,
):
    """
    Generate audio using pre-built prompt cache.

    Args:
        target_text: Text to convert to speech
        prompt_cache: Cache built by build_prompt_cache (can be None)
        min_len: Minimum audio length to avoid very short audio
        max_len: Maximum audio length; ignored when retry_badcase is set,
            where the limit is derived from the target text length and
            retry_badcase_ratio_threshold
        inference_timesteps: Number of diffusion sampling steps
        cfg_value: Classifier-free guidance value
        retry_badcase: Whether to retry on bad cases
        retry_badcase_max_times: Maximum number of attempts when
            retry_badcase is set (at least one attempt is always made)
        retry_badcase_ratio_threshold: Threshold for audio-to-text ratio

    Returns:
        tuple: (decoded audio tensor, new text tokens, new audio features)
    """
    # get prompt from cache
    if prompt_cache is None:
        prompt_text_token = torch.empty(0, dtype=torch.int32)
        prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
    else:
        prompt_text_token = prompt_cache["text_token"]
        prompt_audio_feat = prompt_cache["audio_feat"]

    # build target text tokens
    target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
    text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
    text_token = torch.cat(
        [
            text_token,
            torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
        ],
        dim=-1,
    )
    audio_length = prompt_audio_feat.size(0)
    text_length = text_token.shape[0]

    # Layout: [text tokens + audio-start | prompt audio patches]; each stream
    # is zero-padded over the other stream's positions and the masks select
    # which stream is valid at each position.
    text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
    audio_pad_feat = torch.zeros(
        (text_length, self.patch_size, self.audio_vae.latent_dim),
        dtype=torch.float32,
        device=text_token.device,
    )
    text_token = torch.cat([text_token, text_pad_token])
    audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
    text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
    audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)

    # Follow the dtype the projection weights actually have instead of
    # assuming bfloat16, so models loaded in another dtype keep working.
    model_dtype = self.enc_to_lm_proj.weight.dtype
    text_token = text_token.unsqueeze(0).to(self.device)
    text_mask = text_mask.unsqueeze(0).to(self.device)
    audio_feat = audio_feat.unsqueeze(0).to(self.device).to(model_dtype)
    audio_mask = audio_mask.unsqueeze(0).to(self.device)

    # run inference
    # The target text is already tokenized above; reuse its length rather
    # than running the tokenizer a second time.
    target_text_length = target_text_token.shape[0]
    badcase_len = target_text_length * retry_badcase_ratio_threshold
    effective_max_len = int(badcase_len + 10) if retry_badcase else max_len

    # Always run at least once: with a non-positive retry budget the result
    # variables would otherwise never be assigned.
    attempts = max(1, retry_badcase_max_times) if retry_badcase else 1
    for _ in range(attempts):
        latent_pred, pred_audio_feat = self.inference(
            text_token,
            text_mask,
            audio_feat,
            audio_mask,
            min_len=min_len,
            max_len=effective_max_len,
            inference_timesteps=inference_timesteps,
            cfg_value=cfg_value,
        )
        if not retry_badcase or pred_audio_feat.shape[0] < badcase_len:
            break
        # max(..., 1) keeps the diagnostic from dividing by zero on empty text
        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / max(target_text_length, 1)}, retrying...")

    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
    return (
        decode_audio,
        target_text_token,
        pred_audio_feat,
    )
@torch.inference_mode()
def inference(
    self,
    text: torch.Tensor,
    text_mask: torch.Tensor,
    feat: torch.Tensor,
    feat_mask: torch.Tensor,
    min_len: int = 2,
    max_len: int = 2000,
    inference_timesteps: int = 10,
    cfg_value: float = 2.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Core inference method for audio generation.

    Runs both language models once over the whole prompt, then generates one
    audio patch per step with the diffusion decoder until the stop predictor
    fires or ``max_len`` steps have been taken.

    Args:
        text: Input text tokens
        text_mask: Mask for text tokens
        feat: Input audio features, 4-D (b, t, p, d)
        feat_mask: Mask for audio features
        min_len: Minimum generation length; the stop prediction is only
            honoured once the step index exceeds this value
        max_len: Maximum generation length (number of generation steps)
        inference_timesteps: Number of diffusion steps
        cfg_value: Classifier-free guidance value

    Returns:
        Tuple containing:
            - Predicted latent features, laid out as (b, d, t*p) with the
              first and the last frame removed
            - Predicted audio feature sequence with dim 0 squeezed, on CPU
    """
    B, _, _, _ = feat.shape  # also enforces that feat is 4-D

    # `feat_encoder_step` is only created by `optimize()`; fall back to the
    # eager encoder so a model built without that call can still run.
    encode_step = getattr(self, "feat_encoder_step", self.feat_encoder)

    feat_embed = self.feat_encoder(feat)  # [b, t, h_feat]
    feat_embed = self.enc_to_lm_proj(feat_embed)
    if self.config.lm_config.use_mup:
        scale_emb = self.config.lm_config.scale_emb
    else:
        scale_emb = 1.0
    text_embed = self.base_lm.embed_tokens(text) * scale_emb
    # the masks pick the text embedding or the audio embedding per position
    combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed

    prefix_feat_cond = feat[:, -1, ...]  # b, p, d
    pred_feat_seq = []  # b, t, p, d

    # Prefill: run both LMs over the full sequence and store their KV caches.
    enc_outputs, kv_cache_tuple = self.base_lm(
        inputs_embeds=combined_embed,
        is_causal=True,
    )
    self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
    # scalar quantization is applied on audio positions only
    enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
    lm_hidden = enc_outputs[:, -1, :]
    residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
        inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
        is_causal=True,
    )
    self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
    residual_hidden = residual_enc_outputs[:, -1, :]

    for i in tqdm(range(max_len)):
        dit_hidden_1 = self.lm_to_dit_proj(lm_hidden)  # [b, h_dit]
        dit_hidden_2 = self.res_to_dit_proj(residual_hidden)  # [b, h_dit]
        dit_hidden = dit_hidden_1 + dit_hidden_2  # [b, h_dit]
        pred_feat = self.feat_decoder(
            mu=dit_hidden,
            patch_size=self.patch_size,
            cond=prefix_feat_cond.transpose(1, 2).contiguous(),
            n_timesteps=inference_timesteps,
            cfg_value=cfg_value,
        ).transpose(
            1, 2
        )  # [b, p, d]
        curr_embed = encode_step(pred_feat.unsqueeze(1))  # b, 1, c
        curr_embed = self.enc_to_lm_proj(curr_embed)
        pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
        # the patch just generated conditions the next decoder call
        prefix_feat_cond = pred_feat

        # only the first batch element's stop prediction is consulted
        stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
        if i > min_len and stop_flag == 1:
            break

        lm_hidden = self.base_lm.forward_step(
            curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
        ).clone()
        lm_hidden = self.fsq_layer(lm_hidden)
        residual_hidden = self.residual_lm.forward_step(
            lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
        ).clone()

    pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
    feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
    feat_pred = feat_pred[..., 1:-1]  # trick: remove the first and last token
    return feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str):
    """Load a model from a local directory.

    The directory is expected to contain ``config.json``, the tokenizer
    files, ``audiovae.pth`` and ``pytorch_model.bin``.

    Args:
        path: Directory holding the files listed above.

    Returns:
        The model with weights loaded, moved to its device, switched to eval
        mode and passed through ``optimize()``.
    """
    # Read the config through a context manager so the handle is closed.
    with open(os.path.join(path, "config.json"), encoding="utf-8") as config_file:
        config = VoxCPMConfig.model_validate_json(config_file.read())

    tokenizer = LlamaTokenizerFast.from_pretrained(path)

    audio_vae = AudioVAE()
    vae_state_dict = torch.load(
        os.path.join(path, "audiovae.pth"),
        map_location="cpu",
        weights_only=True,
    )["state_dict"]

    model = cls(config, tokenizer, audio_vae)
    lm_dtype = get_dtype(config.dtype)
    model = model.to(lm_dtype)
    # the VAE is cast back to float32 after the model-wide dtype change
    model.audio_vae = model.audio_vae.to(torch.float32)

    model_state_dict = torch.load(
        os.path.join(path, "pytorch_model.bin"),
        map_location="cpu",
        weights_only=True,
    )["state_dict"]

    # The VAE weights come from a separate file; merge them under the
    # sub-module prefix so one strict load covers the whole model.
    for kw, val in vae_state_dict.items():
        model_state_dict[f"audio_vae.{kw}"] = val
    model.load_state_dict(model_state_dict, strict=True)
    return model.to(model.device).eval().optimize()