"""
|
|
VoxCPM: A Tokenizer-free speech generation model
|
|
|
|
This module contains the main VoxCPM model implementation, including configuration classes
|
|
and the core VoxCPMModel for text-to-speech generation.
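
Example (a minimal usage sketch; the checkpoint path is illustrative):

    model = VoxCPMModel.from_local("/path/to/voxcpm-checkpoint")
    wav = model.generate(
        target_text="Hello from VoxCPM.",
        prompt_text="",        # optional transcript of the reference audio
        prompt_wav_path="",    # optional reference audio for voice cloning
        inference_timesteps=10,
        cfg_value=2.0,
    )
    # ``wav`` is a CPU waveform tensor at ``model.sample_rate``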

Copyright 2025 OpenBMB

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchaudio
from einops import rearrange
from pydantic import BaseModel
from tqdm import tqdm
from transformers import LlamaTokenizerFast

from ..modules.audiovae import AudioVAE
from ..modules.layers import ScalarQuantizationLayer
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens


class VoxCPMEncoderConfig(BaseModel):
    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    kv_channels: Optional[int] = None


class VoxCPMDitConfig(BaseModel):
    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    kv_channels: Optional[int] = None

    cfm_config: CfmConfig


class VoxCPMConfig(BaseModel):
    lm_config: MiniCPM4Config
    patch_size: int = 2
    feat_dim: int = 64
    residual_lm_num_layers: int = 6
    scalar_quantization_latent_dim: int = 256
    scalar_quantization_scale: int = 9

    encoder_config: VoxCPMEncoderConfig
    dit_config: VoxCPMDitConfig

    max_length: int = 4096
    device: str = "cuda"
    dtype: str = "bfloat16"


class VoxCPMModel(nn.Module):
    def __init__(
        self,
        config: VoxCPMConfig,
        tokenizer: LlamaTokenizerFast,
        audio_vae: AudioVAE,
    ):
        super().__init__()
        self.config = config
        self.feat_dim = config.feat_dim
        self.patch_size = config.patch_size
        self.device = config.device
        if not torch.cuda.is_available():
            self.device = "cpu"

        # Text-Semantic LM
        self.base_lm = MiniCPMModel(config.lm_config)
        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))

        self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        self.audio_start_token = 101
        self.audio_end_token = 102

        # Residual Acoustic LM
        residual_lm_config = config.lm_config.model_copy(deep=True)
        residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
        residual_lm_config.vocab_size = 0
        self.residual_lm = MiniCPMModel(residual_lm_config)
        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))

        # Local Encoder
        encoder_config = config.lm_config.model_copy(deep=True)
        encoder_config.hidden_size = config.encoder_config.hidden_dim
        encoder_config.intermediate_size = config.encoder_config.ffn_dim
        encoder_config.num_attention_heads = config.encoder_config.num_heads
        encoder_config.num_hidden_layers = config.encoder_config.num_layers
        encoder_config.kv_channels = config.encoder_config.kv_channels
        encoder_config.vocab_size = 0
        self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)

        # Local DiT
        decoder_config = config.lm_config.model_copy(deep=True)
        decoder_config.hidden_size = config.dit_config.hidden_dim
        decoder_config.intermediate_size = config.dit_config.ffn_dim
        decoder_config.num_attention_heads = config.dit_config.num_heads
        decoder_config.num_hidden_layers = config.dit_config.num_layers
        decoder_config.kv_channels = config.dit_config.kv_channels
        decoder_config.vocab_size = 0
        self.feat_decoder = UnifiedCFM(
            in_channels=config.feat_dim,
            cfm_params=config.dit_config.cfm_config,
            estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
        )

        # Projection layers
        self.fsq_layer = ScalarQuantizationLayer(
            config.lm_config.hidden_size,
            config.lm_config.hidden_size,
            config.scalar_quantization_latent_dim,
            config.scalar_quantization_scale
        )
        self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
        self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
        self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)

        # Stop Predictor
        self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
        self.stop_actn = nn.SiLU()
        self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)

        # Audio VAE
        self.audio_vae = audio_vae
        self.chunk_size = audio_vae.chunk_size
        self.sample_rate = audio_vae.sample_rate

    def optimize(self):
        try:
            if self.device != "cuda":
                raise ValueError("VoxCPMModel can only be optimized on CUDA device")
            try:
                import triton  # availability check only
            except ImportError:
                raise ValueError("triton is not installed")
            self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
            self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
        except Exception as e:
            print(e)
            print("VoxCPMModel cannot be optimized by torch.compile, using original forward_step functions")
            self.base_lm.forward_step = self.base_lm.forward_step
            self.residual_lm.forward_step = self.residual_lm.forward_step
            self.feat_encoder_step = self.feat_encoder
            self.feat_decoder.estimator = self.feat_decoder.estimator
        return self

    @torch.inference_mode()
    def generate(
        self,
        target_text: str,
        prompt_text: str = "",
        prompt_wav_path: str = "",
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,  # acceptable ratio of audio length to text length (for badcase detection)
    ):
        """Synthesize speech for ``target_text``, optionally cloning the voice of
        ``prompt_wav_path`` (with its transcript ``prompt_text``).

        Returns the decoded waveform tensor on CPU.
        """
        if len(prompt_wav_path) == 0:
            text = target_text
            text_token = torch.LongTensor(self.text_tokenizer(text))
            text_token = torch.cat(
                [
                    text_token,
                    torch.tensor(
                        [self.audio_start_token],
                        dtype=torch.int32,
                        device=text_token.device,
                    ),
                ],
                dim=-1,
            )
            text_length = text_token.shape[0]

            audio_feat = torch.zeros(
                (text_length, self.patch_size, self.audio_vae.latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            )
            text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
            audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)

        else:
            text = prompt_text + target_text
            text_token = torch.LongTensor(self.text_tokenizer(text))
            text_token = torch.cat(
                [
                    text_token,
                    torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
                ],
                dim=-1,
            )
            text_length = text_token.shape[0]

            audio, sr = torchaudio.load(prompt_wav_path)
            if audio.size(0) > 1:
                audio = audio.mean(dim=0, keepdim=True)

            if sr != self.sample_rate:
                audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

            patch_len = self.patch_size * self.chunk_size

            if audio.size(1) % patch_len != 0:
                audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))

            # (B, D, T)
            audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()

            audio_feat = audio_feat.view(
                self.audio_vae.latent_dim,
                -1,
                self.patch_size,
            ).permute(1, 2, 0)
            audio_feat = audio_feat[:-1, ...]  # trick: remove the last padding token
            audio_length = audio_feat.size(0)
            text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
            text_token = torch.cat([text_token, text_pad_token])
            audio_pad_feat = torch.zeros(
                (text_length, self.patch_size, self.audio_vae.latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            )
            audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
            text_mask = (
                torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
            )
            audio_mask = (
                torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
            )

        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
        audio_mask = audio_mask.unsqueeze(0).to(self.device)

        target_text_length = len(self.text_tokenizer(target_text))

        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
            latent_pred, pred_audio_feat = self.inference(
                text_token,
                text_mask,
                audio_feat,
                audio_mask,
                min_len=min_len,
                max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
            )
            if retry_badcase:
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break
        return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()

    @torch.inference_mode()
    def build_prompt_cache(
        self,
        prompt_text: str,
        prompt_wav_path: str,
    ):
        """
        Build prompt cache for subsequent fast generation.

        Args:
            prompt_text: prompt text (required)
            prompt_wav_path: prompt audio path (required)

        Returns:
            prompt_cache: dict with text tokens and audio features
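
        Example (a minimal sketch; the transcript and wav path are illustrative):
            cache = model.build_prompt_cache("reference transcript", "reference.wav")
            wav, new_tokens, new_feats = model.generate_with_prompt_cache("Target text.", cache)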
        """
        if not prompt_text or not prompt_wav_path:
            raise ValueError("prompt_text and prompt_wav_path are required")

        # build text tokens
        text_token = torch.LongTensor(self.text_tokenizer(prompt_text))

        # load audio
        audio, sr = torchaudio.load(prompt_wav_path)
        if audio.size(0) > 1:
            audio = audio.mean(dim=0, keepdim=True)

        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

        patch_len = self.patch_size * self.chunk_size

        if audio.size(1) % patch_len != 0:
            audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))

        # extract audio features
        audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()

        audio_feat = audio_feat.view(
            self.audio_vae.latent_dim,
            -1,
            self.patch_size,
        ).permute(1, 2, 0)  # (T, P, D)
        audio_feat = audio_feat[:-1, ...]  # trick: remove the last padding token

        # build prompt cache
        prompt_cache = {
            "text_token": text_token,
            "audio_feat": audio_feat,
        }

        return prompt_cache

    def merge_prompt_cache(
        self,
        original_cache: dict,
        new_text_token: torch.Tensor,
        new_audio_feat: torch.Tensor,
    ):
        """
        Merge original prompt cache with newly generated content to stabilize voice.

        Args:
            original_cache: original prompt cache
            new_text_token: newly generated text tokens
            new_audio_feat: newly generated audio features

        Returns:
            merged_cache: merged cache
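
        Example (a sketch of chunked generation that rolls the cache forward; the
        variable names are illustrative):
            cache = model.build_prompt_cache(prompt_text, prompt_wav_path)
            for chunk in text_chunks:
                wav, toks, feats = model.generate_with_prompt_cache(chunk, cache)
                cache = model.merge_prompt_cache(cache, toks, feats)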
        """
        if original_cache is None:
            return {
                "text_token": new_text_token,
                "audio_feat": new_audio_feat,
            }
        original_text_token = original_cache["text_token"]
        original_audio_feat = original_cache["audio_feat"]
        merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
        merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)

        # build new cache
        merged_cache = {
            "text_token": merged_text_token,
            "audio_feat": merged_audio_feat,
        }

        return merged_cache

    @torch.inference_mode()
    def generate_with_prompt_cache(
        self,
        target_text: str,
        prompt_cache: dict,
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
    ):
        """
        Generate audio using a pre-built prompt cache.

        Args:
            target_text: Text to convert to speech
            prompt_cache: Cache built by build_prompt_cache (can be None)
            min_len: Minimum audio length to avoid very short audio
            max_len: Maximum audio length
            inference_timesteps: Number of diffusion sampling steps
            cfg_value: Classifier-free guidance value
            retry_badcase: Whether to retry on bad cases
            retry_badcase_max_times: Maximum retry attempts
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio

        Returns:
            tuple: (decoded audio tensor, new text tokens, new audio features)
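
        Example (a minimal sketch; ``cache`` comes from ``build_prompt_cache`` and the
        output file name is illustrative):
            wav, new_tokens, new_feats = model.generate_with_prompt_cache("Hello there.", cache)
            torchaudio.save("out.wav", wav, model.sample_rate)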
        """
        # get prompt from cache
        if prompt_cache is None:
            prompt_text_token = torch.empty(0, dtype=torch.int32)
            prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
        else:
            prompt_text_token = prompt_cache["text_token"]
            prompt_audio_feat = prompt_cache["audio_feat"]

        # build target text tokens
        target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
        text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
        text_token = torch.cat(
            [
                text_token,
                torch.tensor(
                    [self.audio_start_token],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ],
            dim=-1,
        )

        audio_length = prompt_audio_feat.size(0)
        text_length = text_token.shape[0]
        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        audio_pad_feat = torch.zeros(
            (text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([text_token, text_pad_token])
        audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
        text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
        audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)

        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
        audio_mask = audio_mask.unsqueeze(0).to(self.device)

        # run inference
        target_text_length = len(self.text_tokenizer(target_text))
        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
            latent_pred, pred_audio_feat = self.inference(
                text_token,
                text_mask,
                audio_feat,
                audio_mask,
                min_len=min_len,
                max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
            )
            if retry_badcase:
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()

        return (
            decode_audio,
            target_text_token,
            pred_audio_feat
        )

    @torch.inference_mode()
    def inference(
        self,
        text: torch.Tensor,
        text_mask: torch.Tensor,
        feat: torch.Tensor,
        feat_mask: torch.Tensor,
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Core inference method for audio generation.

        This is the main inference loop that generates audio features
        using the language model and diffusion transformer.

        Args:
            text: Input text tokens
            text_mask: Mask for text tokens
            feat: Input audio features
            feat_mask: Mask for audio features
            min_len: Minimum generation length
            max_len: Maximum generation length
            inference_timesteps: Number of diffusion steps
            cfg_value: Classifier-free guidance value

        Returns:
            Tuple containing:
                - Predicted latent features
                - Predicted audio feature sequence
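
        Shape sketch (inferred from the implementation below; assumes batch size 1,
        as used by the callers in this module): ``feat`` is
        [B, T, patch_size, feat_dim]; the first returned tensor is
        [B, feat_dim, T_gen * patch_size - 2] (first and last frames trimmed),
        and the second is [T_gen, patch_size, feat_dim] on CPU.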
        """
        B, T, P, D = feat.shape

        feat_embed = self.feat_encoder(feat)  # [b, t, h_feat]
        feat_embed = self.enc_to_lm_proj(feat_embed)

        if self.config.lm_config.use_mup:
            scale_emb = self.config.lm_config.scale_emb
        else:
            scale_emb = 1.0

        text_embed = self.base_lm.embed_tokens(text) * scale_emb
        combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed

        prefix_feat_cond = feat[:, -1, ...]  # b, p, d
        pred_feat_seq = []  # b, t, p, d
        curr_embed = None

        enc_outputs, kv_cache_tuple = self.base_lm(
            inputs_embeds=combined_embed,
            is_causal=True,
        )
        self.base_lm.kv_cache.fill_caches(kv_cache_tuple)

        enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
        lm_hidden = enc_outputs[:, -1, :]

        residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
            inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
            is_causal=True,
        )
        self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
        residual_hidden = residual_enc_outputs[:, -1, :]

        for i in tqdm(range(max_len)):
            dit_hidden_1 = self.lm_to_dit_proj(lm_hidden)  # [b, h_dit]
            dit_hidden_2 = self.res_to_dit_proj(residual_hidden)  # [b, h_dit]
            dit_hidden = dit_hidden_1 + dit_hidden_2  # [b, h_dit]

            pred_feat = self.feat_decoder(
                mu=dit_hidden,
                patch_size=self.patch_size,
                cond=prefix_feat_cond.transpose(1, 2).contiguous(),
                n_timesteps=inference_timesteps,
                cfg_value=cfg_value,
            ).transpose(
                1, 2
            )  # [b, p, d]

            curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1))  # b, 1, c
            curr_embed = self.enc_to_lm_proj(curr_embed)

            pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
            prefix_feat_cond = pred_feat

            stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
            if i > min_len and stop_flag == 1:
                break

            lm_hidden = self.base_lm.forward_step(
                curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
            ).clone()

            lm_hidden = self.fsq_layer(lm_hidden)
            residual_hidden = self.residual_lm.forward_step(
                lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
            ).clone()

        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d

        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
        feat_pred = feat_pred[..., 1:-1]  # trick: remove the first and last token
        return feat_pred, pred_feat_seq.squeeze(0).cpu()

    @classmethod
    def from_local(cls, path: str):
        with open(os.path.join(path, "config.json")) as f:
            config = VoxCPMConfig.model_validate_json(f.read())

        tokenizer = LlamaTokenizerFast.from_pretrained(path)

        audio_vae = AudioVAE()
        vae_state_dict = torch.load(
            os.path.join(path, "audiovae.pth"),
            map_location="cpu",
            weights_only=True,
        )["state_dict"]

        model = cls(config, tokenizer, audio_vae)
        lm_dtype = get_dtype(config.dtype)
        model = model.to(lm_dtype)
        model.audio_vae = model.audio_vae.to(torch.float32)

        model_state_dict = torch.load(
            os.path.join(path, "pytorch_model.bin"),
            map_location="cpu",
            weights_only=True,
        )["state_dict"]

        for kw, val in vae_state_dict.items():
            model_state_dict[f"audio_vae.{kw}"] = val
        model.load_state_dict(model_state_dict, strict=True)
        return model.to(model.device).eval().optimize()