init
src/voxcpm/model/voxcpm.py | 605 (new file)
@@ -0,0 +1,605 @@
"""
VoxCPM: A Tokenizer-free speech generation model

This module contains the main VoxCPM model implementation, including configuration classes
and the core VoxCPMModel for text-to-speech generation.

Copyright 2025 OpenBMB

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torchaudio
from einops import rearrange
from pydantic import BaseModel
from tqdm import tqdm
from transformers import LlamaTokenizerFast

from ..modules.audiovae import AudioVAE
from ..modules.layers import ScalarQuantizationLayer
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens


class VoxCPMEncoderConfig(BaseModel):
    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    kv_channels: Optional[int] = None


class VoxCPMDitConfig(BaseModel):
    hidden_dim: int = 1024
    ffn_dim: int = 4096
    num_heads: int = 16
    num_layers: int = 4
    kv_channels: Optional[int] = None

    cfm_config: CfmConfig


class VoxCPMConfig(BaseModel):
    lm_config: MiniCPM4Config
    patch_size: int = 2
    feat_dim: int = 64
    residual_lm_num_layers: int = 6
    scalar_quantization_latent_dim: int = 256
    scalar_quantization_scale: int = 9

    encoder_config: VoxCPMEncoderConfig
    dit_config: VoxCPMDitConfig

    max_length: int = 4096
    device: str = "cuda"
    dtype: str = "bfloat16"


class VoxCPMModel(nn.Module):
    def __init__(
        self,
        config: VoxCPMConfig,
        tokenizer: LlamaTokenizerFast,
        audio_vae: AudioVAE,
    ):
        super().__init__()
        self.config = config
        self.feat_dim = config.feat_dim
        self.patch_size = config.patch_size
        self.device = config.device
        if not torch.cuda.is_available():
            self.device = "cpu"

        # Text-Semantic LM
        self.base_lm = MiniCPMModel(config.lm_config)
        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))

        self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        self.audio_start_token = 101
        self.audio_end_token = 102

        # Residual Acoustic LM
        residual_lm_config = config.lm_config.model_copy(deep=True)
        residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
        residual_lm_config.vocab_size = 0
        self.residual_lm = MiniCPMModel(residual_lm_config)
        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))

        # Local Encoder
        encoder_config = config.lm_config.model_copy(deep=True)
        encoder_config.hidden_size = config.encoder_config.hidden_dim
        encoder_config.intermediate_size = config.encoder_config.ffn_dim
        encoder_config.num_attention_heads = config.encoder_config.num_heads
        encoder_config.num_hidden_layers = config.encoder_config.num_layers
        encoder_config.kv_channels = config.encoder_config.kv_channels
        encoder_config.vocab_size = 0
        self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)
        # Eager default for the per-step encoder call; optimize() may replace this
        # with a torch.compile'd version on CUDA.
        self.feat_encoder_step = self.feat_encoder

        # Local DiT
        decoder_config = config.lm_config.model_copy(deep=True)
        decoder_config.hidden_size = config.dit_config.hidden_dim
        decoder_config.intermediate_size = config.dit_config.ffn_dim
        decoder_config.num_attention_heads = config.dit_config.num_heads
        decoder_config.num_hidden_layers = config.dit_config.num_layers
        decoder_config.kv_channels = config.dit_config.kv_channels
        decoder_config.vocab_size = 0
        self.feat_decoder = UnifiedCFM(
            in_channels=config.feat_dim,
            cfm_params=config.dit_config.cfm_config,
            estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
        )

        # Projection layers
        self.fsq_layer = ScalarQuantizationLayer(
            config.lm_config.hidden_size,
            config.lm_config.hidden_size,
            config.scalar_quantization_latent_dim,
            config.scalar_quantization_scale,
        )
        self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
        self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
        self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)

        # Stop Predictor
        self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
        self.stop_actn = nn.SiLU()
        self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)

        # Audio VAE
        self.audio_vae = audio_vae
        self.chunk_size = audio_vae.chunk_size
        self.sample_rate = audio_vae.sample_rate

    def optimize(self):
        """Compile the hot inference paths with torch.compile when running on CUDA."""
        if self.device == "cuda":
            self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
            self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
        else:
            # On CPU, keep the eager modules.
            self.feat_encoder_step = self.feat_encoder
        return self

    @torch.inference_mode()
    def generate(
        self,
        target_text: str,
        prompt_text: str = "",
        prompt_wav_path: str = "",
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,  # acceptable ratio of audio length to text length (for badcase detection)
    ):
        if len(prompt_wav_path) == 0:
            text = target_text
            text_token = torch.LongTensor(self.text_tokenizer(text))
            text_token = torch.cat(
                [
                    text_token,
                    torch.tensor(
                        [self.audio_start_token],
                        dtype=torch.int32,
                        device=text_token.device,
                    ),
                ],
                dim=-1,
            )
            text_length = text_token.shape[0]

            audio_feat = torch.zeros(
                (text_length, self.patch_size, self.audio_vae.latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            )
            text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
            audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)

        else:
            text = prompt_text + target_text
            text_token = torch.LongTensor(self.text_tokenizer(text))
            text_token = torch.cat(
                [
                    text_token,
                    torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
                ],
                dim=-1,
            )
            text_length = text_token.shape[0]

            audio, sr = torchaudio.load(prompt_wav_path)
            if audio.size(0) > 1:
                audio = audio.mean(dim=0, keepdim=True)

            if sr != self.sample_rate:
                audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

            patch_len = self.patch_size * self.chunk_size

            if audio.size(1) % patch_len != 0:
                audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))

            # (B, D, T)
            audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()

            audio_feat = audio_feat.view(
                self.audio_vae.latent_dim,
                -1,
                self.patch_size,
            ).permute(1, 2, 0)
            audio_feat = audio_feat[:-1, ...]  # trick: remove the last padding token
            audio_length = audio_feat.size(0)
            text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
            text_token = torch.cat([text_token, text_pad_token])
            audio_pad_feat = torch.zeros(
                (text_length, self.patch_size, self.audio_vae.latent_dim),
                dtype=torch.float32,
                device=text_token.device,
            )
            audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
            text_mask = (
                torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
            )
            audio_mask = (
                torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
            )

        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
        audio_mask = audio_mask.unsqueeze(0).to(self.device)

        target_text_length = len(self.text_tokenizer(target_text))

        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
            latent_pred, pred_audio_feat = self.inference(
                text_token,
                text_mask,
                audio_feat,
                audio_mask,
                min_len=min_len,
                max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
            )
            if retry_badcase:
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(f"Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break
        return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
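
    # Illustrative note on the badcase heuristic above (not part of the model logic):
    # each generation step emits `patch_size` latent frames, and the padding/view
    # logic in encode suggests one latent frame per `chunk_size` waveform samples,
    # so, under that assumption,
    #     seconds_per_step ~= patch_size * chunk_size / sample_rate
    # and the retry branch caps generation at roughly
    #     target_text_length * retry_badcase_ratio_threshold + 10 steps.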

    @torch.inference_mode()
    def build_prompt_cache(
        self,
        prompt_text: str,
        prompt_wav_path: str,
    ):
        """
        Build prompt cache for subsequent fast generation.

        Args:
            prompt_text: prompt text (required)
            prompt_wav_path: prompt audio path (required)

        Returns:
            prompt_cache: dict with text tokens and audio features
        """
        if not prompt_text or not prompt_wav_path:
            raise ValueError("prompt_text and prompt_wav_path are required")

        # build text tokens
        text_token = torch.LongTensor(self.text_tokenizer(prompt_text))

        # load audio
        audio, sr = torchaudio.load(prompt_wav_path)
        if audio.size(0) > 1:
            audio = audio.mean(dim=0, keepdim=True)

        if sr != self.sample_rate:
            audio = torchaudio.functional.resample(audio, sr, self.sample_rate)

        patch_len = self.patch_size * self.chunk_size

        if audio.size(1) % patch_len != 0:
            audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))

        # extract audio features (use self.device rather than assuming CUDA, matching generate())
        audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()

        audio_feat = audio_feat.view(
            self.audio_vae.latent_dim,
            -1,
            self.patch_size,
        ).permute(1, 2, 0)  # (T, P, D)
        audio_feat = audio_feat[:-1, ...]  # trick: remove the last padding token

        # build prompt cache
        prompt_cache = {
            "text_token": text_token,
            "audio_feat": audio_feat,
        }

        return prompt_cache

    def merge_prompt_cache(
        self,
        original_cache: dict,
        new_text_token: torch.Tensor,
        new_audio_feat: torch.Tensor,
    ):
        """
        Merge the original prompt cache with newly generated content to stabilize the voice.

        Args:
            original_cache: original prompt cache
            new_text_token: newly generated text tokens
            new_audio_feat: newly generated audio features

        Returns:
            merged_cache: merged cache
        """
        if original_cache is None:
            return {
                "text_token": new_text_token,
                "audio_feat": new_audio_feat,
            }
        original_text_token = original_cache["text_token"]
        original_audio_feat = original_cache["audio_feat"]
        merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
        merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)

        # build new cache
        merged_cache = {
            "text_token": merged_text_token,
            "audio_feat": merged_audio_feat,
        }

        return merged_cache

    @torch.inference_mode()
    def generate_with_prompt_cache(
        self,
        target_text: str,
        prompt_cache: dict,
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
    ):
        """
        Generate audio using a pre-built prompt cache.

        Args:
            target_text: Text to convert to speech
            prompt_cache: Cache built by build_prompt_cache (can be None)
            min_len: Minimum audio length to avoid very short audio
            max_len: Maximum audio length
            inference_timesteps: Number of diffusion sampling steps
            cfg_value: Classifier-free guidance value
            retry_badcase: Whether to retry on bad cases
            retry_badcase_max_times: Maximum retry attempts
            retry_badcase_ratio_threshold: Threshold for the audio-to-text length ratio

        Returns:
            tuple: (decoded audio tensor, new text tokens, new audio features)
        """
        # get prompt from cache
        if prompt_cache is None:
            prompt_text_token = torch.empty(0, dtype=torch.int32)
            prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
        else:
            prompt_text_token = prompt_cache["text_token"]
            prompt_audio_feat = prompt_cache["audio_feat"]

        # build target text tokens
        target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
        text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
        text_token = torch.cat(
            [
                text_token,
                torch.tensor(
                    [self.audio_start_token],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ],
            dim=-1,
        )

        audio_length = prompt_audio_feat.size(0)
        text_length = text_token.shape[0]
        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        audio_pad_feat = torch.zeros(
            (text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
            dtype=torch.float32,
            device=text_token.device,
        )
        text_token = torch.cat([text_token, text_pad_token])
        audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
        text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
        audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)

        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
        audio_mask = audio_mask.unsqueeze(0).to(self.device)

        # run inference
        target_text_length = len(self.text_tokenizer(target_text))
        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
            latent_pred, pred_audio_feat = self.inference(
                text_token,
                text_mask,
                audio_feat,
                audio_mask,
                min_len=min_len,
                max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
            )
            if retry_badcase:
                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
                    print(f"Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
                    retry_badcase_times += 1
                    continue
                else:
                    break
            else:
                break
        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()

        return (
            decode_audio,
            target_text_token,
            pred_audio_feat,
        )
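
    # Minimal sketch of the prompt-cache workflow (illustrative only; the texts and
    # variable names are placeholders, not part of the API):
    #
    #     cache = model.build_prompt_cache("prompt transcript", "prompt.wav")
    #     for sentence in ["First sentence.", "Second sentence."]:
    #         wav, new_text_token, new_audio_feat = model.generate_with_prompt_cache(
    #             target_text=sentence, prompt_cache=cache
    #         )
    #         # Optionally fold the new output back in to keep the voice stable
    #         # across consecutive chunks:
    #         cache = model.merge_prompt_cache(cache, new_text_token, new_audio_feat)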

    @torch.inference_mode()
    def inference(
        self,
        text: torch.Tensor,
        text_mask: torch.Tensor,
        feat: torch.Tensor,
        feat_mask: torch.Tensor,
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Core inference method for audio generation.

        This is the main inference loop that generates audio features
        using the language model and diffusion transformer.

        Args:
            text: Input text tokens
            text_mask: Mask for text tokens
            feat: Input audio features
            feat_mask: Mask for audio features
            min_len: Minimum generation length
            max_len: Maximum generation length
            inference_timesteps: Number of diffusion steps
            cfg_value: Classifier-free guidance value

        Returns:
            Tuple containing:
                - Predicted latent features
                - Predicted audio feature sequence
        """
        B, T, P, D = feat.shape

        feat_embed = self.feat_encoder(feat)  # [b, t, h_feat]
        feat_embed = self.enc_to_lm_proj(feat_embed)

        if self.config.lm_config.use_mup:
            scale_emb = self.config.lm_config.scale_emb
        else:
            scale_emb = 1.0

        text_embed = self.base_lm.embed_tokens(text) * scale_emb
        combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed

        prefix_feat_cond = feat[:, -1, ...]  # b, p, d
        pred_feat_seq = []  # b, t, p, d
        curr_embed = None

        enc_outputs, kv_cache_tuple = self.base_lm(
            inputs_embeds=combined_embed,
            is_causal=True,
        )
        self.base_lm.kv_cache.fill_caches(kv_cache_tuple)

        enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
        lm_hidden = enc_outputs[:, -1, :]

        residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
            inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
            is_causal=True,
        )
        self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
        residual_hidden = residual_enc_outputs[:, -1, :]

        for i in tqdm(range(max_len)):
            dit_hidden_1 = self.lm_to_dit_proj(lm_hidden)  # [b, h_dit]
            dit_hidden_2 = self.res_to_dit_proj(residual_hidden)  # [b, h_dit]
            dit_hidden = dit_hidden_1 + dit_hidden_2  # [b, h_dit]

            pred_feat = self.feat_decoder(
                mu=dit_hidden,
                patch_size=self.patch_size,
                cond=prefix_feat_cond.transpose(1, 2).contiguous(),
                n_timesteps=inference_timesteps,
                cfg_value=cfg_value,
            ).transpose(1, 2)  # [b, p, d]

            curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1))  # b, 1, c
            curr_embed = self.enc_to_lm_proj(curr_embed)

            pred_feat_seq.append(pred_feat.unsqueeze(1))  # b, 1, p, d
            prefix_feat_cond = pred_feat

            stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
            if i > min_len and stop_flag == 1:
                break

            lm_hidden = self.base_lm.forward_step(
                curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
            ).clone()

            lm_hidden = self.fsq_layer(lm_hidden)
            residual_hidden = self.residual_lm.forward_step(
                lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
            ).clone()

        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d

        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
        feat_pred = feat_pred[..., 1:-1]  # trick: remove the first and last token
        return feat_pred, pred_feat_seq.squeeze(0).cpu()

    @classmethod
    def from_local(cls, path: str):
        with open(os.path.join(path, "config.json")) as f:
            config = VoxCPMConfig.model_validate_json(f.read())

        tokenizer = LlamaTokenizerFast.from_pretrained(path)

        audio_vae = AudioVAE()
        vae_state_dict = torch.load(
            os.path.join(path, "audiovae.pth"),
            map_location="cpu",
            weights_only=True,
        )["state_dict"]

        model = cls(config, tokenizer, audio_vae)
        lm_dtype = get_dtype(config.dtype)
        model = model.to(lm_dtype)
        model.audio_vae = model.audio_vae.to(torch.float32)

        model_state_dict = torch.load(
            os.path.join(path, "pytorch_model.bin"),
            map_location="cpu",
            weights_only=True,
        )["state_dict"]

        for kw, val in vae_state_dict.items():
            model_state_dict[f"audio_vae.{kw}"] = val
        model.load_state_dict(model_state_dict, strict=True)
        return model.to(model.device).eval().optimize()
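

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library): assumes a
# local checkpoint directory containing config.json, pytorch_model.bin,
# audiovae.pth and the tokenizer files, as expected by `from_local` above.
# The path and texts below are placeholders.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    model = VoxCPMModel.from_local("path/to/voxcpm_checkpoint")  # hypothetical path
    wav = model.generate(
        target_text="Hello from VoxCPM.",
        prompt_text="",       # optional: transcript of the voice prompt
        prompt_wav_path="",   # optional: path to the voice prompt audio
        inference_timesteps=10,
        cfg_value=2.0,
    )
    # `generate` returns a waveform tensor decoded by the AudioVAE; save it at the
    # model's native sample rate.
    torchaudio.save("output.wav", wav, model.sample_rate)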