Mirror of https://github.com/OpenBMB/VoxCPM, synced 2025-12-14 04:18:12 +00:00

Commit: init

src/voxcpm/modules/__init__.py (new file, 0 lines)

src/voxcpm/modules/audiovae/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .audio_vae import AudioVAE

src/voxcpm/modules/audiovae/audio_vae.py (new file, 359 lines)
@@ -0,0 +1,359 @@
import math
from typing import List, Union

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class CausalConv1d(nn.Conv1d):
    def __init__(self, *args, padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.__padding = padding

    def forward(self, x):
        # Pad only on the left so the convolution never sees future samples.
        x_pad = F.pad(x, (self.__padding * 2, 0))
        return super().forward(x_pad)


class CausalTransposeConv1d(nn.ConvTranspose1d):
    def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
        super().__init__(*args, **kwargs)
        self.__padding = padding
        self.__output_padding = output_padding

    def forward(self, x):
        # Trim the trailing samples so the upsampled output stays causal.
        return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]


def WNCausalConv1d(*args, **kwargs):
    return weight_norm(CausalConv1d(*args, **kwargs))


def WNCausalTransposeConv1d(*args, **kwargs):
    return weight_norm(CausalTransposeConv1d(*args, **kwargs))


# Scripting this brings a ~1.4x model speed-up
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class CausalResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2  # assumes the default kernel size of 7
        self.block = nn.Sequential(
            Snake1d(dim),
            WNCausalConv1d(
                dim,
                dim,
                kernel_size=kernel,
                dilation=dilation,
                padding=pad,
                groups=groups,
            ),
            Snake1d(dim),
            WNCausalConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        assert pad == 0
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class CausalEncoderBlock(nn.Module):
    def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
        super().__init__()
        input_dim = input_dim or output_dim // 2
        self.block = nn.Sequential(
            CausalResidualUnit(input_dim, dilation=1, groups=groups),
            CausalResidualUnit(input_dim, dilation=3, groups=groups),
            CausalResidualUnit(input_dim, dilation=9, groups=groups),
            Snake1d(input_dim),
            WNCausalConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class CausalEncoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        latent_dim: int = 32,
        strides: list = [2, 4, 8, 8],
        depthwise: bool = False,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            groups = d_model // 2 if depthwise else 1
            self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]

        groups = d_model if depthwise else 1

        # Create two convolutions, for mu and logvar
        self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
        self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        hidden_state = self.block(x)
        return {
            "hidden_state": hidden_state,
            "mu": self.fc_mu(hidden_state),
            "logvar": self.fc_logvar(hidden_state),
        }


class NoiseBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        B, C, T = x.shape
        noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
        h = self.linear(x)
        n = noise * h
        x = x + n
        return x


class CausalDecoderBlock(nn.Module):
    def __init__(
        self,
        input_dim: int = 16,
        output_dim: int = 8,
        stride: int = 1,
        groups=1,
        use_noise_block: bool = False,
    ):
        super().__init__()
        layers = [
            Snake1d(input_dim),
            WNCausalTransposeConv1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
        ]
        if use_noise_block:
            layers.append(NoiseBlock(output_dim))
        layers.extend(
            [
                CausalResidualUnit(output_dim, dilation=1, groups=groups),
                CausalResidualUnit(output_dim, dilation=3, groups=groups),
                CausalResidualUnit(output_dim, dilation=9, groups=groups),
            ]
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class TransposeLastTwoDim(torch.nn.Module):
    def forward(self, x):
        return torch.transpose(x, -1, -2)


class CausalDecoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        depthwise: bool = False,
        d_out: int = 1,
        use_noise_block: bool = False,
    ):
        super().__init__()

        # Add first conv layer
        if depthwise:
            layers = [
                WNCausalConv1d(
                    input_channel,
                    input_channel,
                    kernel_size=7,
                    padding=3,
                    groups=input_channel,
                ),
                WNCausalConv1d(input_channel, channels, kernel_size=1),
            ]
        else:
            layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            groups = output_dim if depthwise else 1
            layers += [
                CausalDecoderBlock(
                    input_dim,
                    output_dim,
                    stride,
                    groups=groups,
                    use_noise_block=use_noise_block,
                )
            ]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class AudioVAE(nn.Module):
    """Causal convolutional VAE over raw waveforms.

    Args:
        encoder_dim: base channel width of the encoder.
        encoder_rates: downsampling stride of each encoder stage; their product is the hop length.
        latent_dim: dimensionality of the latent sequence.
        decoder_dim: base channel width of the decoder.
        decoder_rates: upsampling stride of each decoder stage.
        depthwise: use grouped (depthwise) convolutions inside the residual blocks.
        sample_rate: expected input sample rate in Hz.
        use_noise_block: inject learned noise after each decoder upsampling stage.
    """

    def __init__(
        self,
        encoder_dim: int = 128,
        encoder_rates: List[int] = [2, 5, 8, 8],
        latent_dim: int = 64,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 5, 2],
        depthwise: bool = True,
        sample_rate: int = 16000,
        use_noise_block: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.depthwise = depthwise

        self.use_noise_block = use_noise_block

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim
        self.hop_length = np.prod(encoder_rates)
        self.encoder = CausalEncoder(
            encoder_dim,
            latent_dim,
            encoder_rates,
            depthwise=depthwise,
        )

        self.decoder = CausalDecoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            depthwise=depthwise,
            use_noise_block=use_noise_block,
        )
        self.sample_rate = sample_rate
        self.chunk_size = math.prod(encoder_rates)

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate
        pad_to = self.hop_length
        length = audio_data.shape[-1]
        right_pad = math.ceil(length / pad_to) * pad_to - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data

        Parameters
        ----------
        z : Tensor[B x D x T]
            Continuous latent representation of the input

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio data.
        """
        return self.decoder(z)

    def encode(self, audio_data: torch.Tensor, sample_rate: int):
        """
        Args:
            audio_data: Tensor[B x 1 x T]
            sample_rate: int
        Returns:
            z: Tensor[B x D x T]
        """
        if audio_data.ndim == 2:
            audio_data = audio_data.unsqueeze(1)

        audio_data = self.preprocess(audio_data, sample_rate)
        return self.encoder(audio_data)["mu"]
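
Usage sketch for AudioVAE: a minimal, illustrative round trip with randomly initialized weights and deliberately small toy dimensions (real use would load a pretrained checkpoint; the import path assumes the src/ layout above is importable as the voxcpm package):

import torch

from voxcpm.modules.audiovae import AudioVAE

# Toy sizes chosen only to keep the example light; the defaults are much larger.
vae = AudioVAE(
    encoder_dim=32,
    encoder_rates=[2, 5, 8, 8],
    latent_dim=16,
    decoder_dim=256,
    decoder_rates=[8, 8, 5, 2],
).eval()

wav = torch.randn(1, 1, 16000)  # one second of fake 16 kHz audio
with torch.no_grad():
    z = vae.encode(wav, sample_rate=16000)  # (1, 16, 25): hop length = 2*5*8*8 = 640
    recon = vae.decode(z)                   # (1, 1, 16000): 25 frames * 640 samples
print(z.shape, recon.shape)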

src/voxcpm/modules/layers/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .scalar_quantization_layer import ScalarQuantizationLayer

src/voxcpm/modules/layers/scalar_quantization_layer.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn


class ScalarQuantizationLayer(nn.Module):
    def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.scale = scale

        self.in_proj = nn.Linear(in_dim, latent_dim)
        self.out_proj = nn.Linear(latent_dim, out_dim)

    def forward(self, hidden):
        hidden = self.in_proj(hidden)
        hidden = torch.tanh(hidden)

        if self.training:
            # Straight-through estimator: quantized values in the forward pass,
            # gradients flow through the unquantized activations.
            quantized = torch.round(hidden * self.scale) / self.scale
            hidden = hidden + (quantized - hidden).detach()
        else:
            hidden = torch.round(hidden * self.scale) / self.scale

        return self.out_proj(hidden)
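
A small sketch of what the layer above does at inference time: the tanh-compressed bottleneck is snapped to a uniform grid of 2 * scale + 1 = 19 levels, so each bottleneck channel takes at most 19 distinct values. The dimensions below are arbitrary, and the import assumes the voxcpm package is importable:

import torch

from voxcpm.modules.layers import ScalarQuantizationLayer

sq = ScalarQuantizationLayer(in_dim=32, out_dim=32, latent_dim=4, scale=9).eval()
h = torch.randn(2, 10, 32)  # (batch, frames, features)
with torch.no_grad():
    out = sq(h)  # same trailing dim, routed through the 4-dim quantized bottleneck
    grid = torch.round(torch.tanh(sq.in_proj(h)) * sq.scale) / sq.scale
print(out.shape, grid.unique().numel())  # torch.Size([2, 10, 32]) and at most 19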

src/voxcpm/modules/locdit/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .unified_cfm import UnifiedCFM, CfmConfig
from .local_dit import VoxCPMLocDiT

src/voxcpm/modules/locdit/local_dit.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import torch
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
import torch.nn as nn
import math


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        out_dim: int = None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act = nn.SiLU()
        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim

        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)

    def forward(self, sample):
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample


class VoxCPMLocDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        config: MiniCPM4Config,
        in_channels: int = 64,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.config = config

        self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)

        self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
        self.time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )
        self.delta_time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )

        assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
        self.decoder = MiniCPMModel(config)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond: torch.Tensor,
        dt: torch.Tensor,
    ):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of inputs
        mu: (N, C) tensor of hidden embedding
        t: (N,) tensor of diffusion timesteps
        cond: (N, C, T') tensor of prefix conditions
        dt: (N,) used for mean velocity (may be supported in the future...)
        """
        x = self.in_proj(x.transpose(1, 2).contiguous())

        cond = self.cond_proj(cond.transpose(1, 2).contiguous())
        prefix = cond.size(1)

        t = self.time_embeddings(t).to(x.dtype)
        t = self.time_mlp(t)
        dt = self.time_embeddings(dt).to(x.dtype)
        dt = self.delta_time_mlp(dt)
        t = t + dt

        # Sequence layout: [conditioning token (mu + t), prefix frames, noisy latent frames]
        x = torch.cat([(mu + t).unsqueeze(1), cond, x], dim=1)
        hidden, _ = self.decoder(x, is_causal=False)
        hidden = hidden[:, prefix + 1 :, :]
        hidden = self.out_proj(hidden)

        return hidden.transpose(1, 2).contiguous()
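
A quick, illustrative check of the sinusoidal timestep embedding defined above; the values are arbitrary and the import path assumes the voxcpm package layout from this diff:

import torch

from voxcpm.modules.locdit.local_dit import SinusoidalPosEmb

emb = SinusoidalPosEmb(dim=64)
t = torch.tensor([0.0, 0.5, 1.0])  # flow-matching timesteps in [0, 1]
out = emb(t)                       # (3, 64): sin half then cos half, timesteps scaled by `scale`=1000
print(out.shape, out.min().item(), out.max().item())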

src/voxcpm/modules/locdit/unified_cfm.py (new file, 137 lines)
@@ -0,0 +1,137 @@
import torch
from typing import List
from .local_dit import VoxCPMLocDiT
import math
from pydantic import BaseModel


class CfmConfig(BaseModel):
    sigma_min: float = 1e-06
    solver: str = "euler"
    t_scheduler: str = "log-norm"


class UnifiedCFM(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        cfm_params: CfmConfig,
        estimator: VoxCPMLocDiT,
        mean_mode: bool = False,
    ):
        super().__init__()
        self.solver = cfm_params.solver
        self.sigma_min = cfm_params.sigma_min
        self.t_scheduler = cfm_params.t_scheduler
        self.in_channels = in_channels
        self.mean_mode = mean_mode

        # Just change the architecture of the estimator here
        self.estimator = estimator

    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,
        n_timesteps: int,
        patch_size: int,
        cond: torch.Tensor,
        temperature: float = 1.0,
        cfg_value: float = 1.0,
        sway_sampling_coef: float = 1.0,
        use_cfg_zero_star: bool = True,
    ):
        """Sample a latent patch by integrating the flow-matching ODE from noise.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats)
            n_timesteps (int): number of diffusion steps
            patch_size (int): number of latent frames to generate
            cond (torch.Tensor): prefix condition passed to the estimator
                shape: (batch_size, in_channels, patch_size)
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            cfg_value (float, optional): classifier-free guidance scale. Defaults to 1.0.
            sway_sampling_coef (float, optional): coefficient of the sway sampling schedule.
            use_cfg_zero_star (bool, optional): apply CFG-Zero* scaling and zero-init steps.

        Returns:
            sample: generated latent sequence
                shape: (batch_size, in_channels, patch_size)
        """
        b, c = mu.shape
        t = patch_size
        z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature

        t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        # Sway sampling strategy
        t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)

        return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)

    def optimized_scale(self, positive_flat, negative_flat):
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

        st_star = dot_product / squared_norm
        return st_star

    def solve_euler(
        self,
        x: torch.Tensor,
        t_span: torch.Tensor,
        mu: torch.Tensor,
        cond: torch.Tensor,
        cfg_value: float = 1.0,
        use_cfg_zero_star: bool = True,
    ):
        """
        Fixed-step Euler solver for the flow ODE.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats)
            cond (torch.Tensor): prefix condition passed to the estimator
            cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
        """
        t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]

        sol = []
        zero_init_steps = max(1, int(len(t_span) * 0.04))
        for step in range(1, len(t_span)):
            if use_cfg_zero_star and step <= zero_init_steps:
                dphi_dt = 0.
            else:
                # Classifier-Free Guidance inference introduced in VoiceBox
                b = x.size(0)
                x_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
                mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
                t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
                x_in[:b], x_in[b:] = x, x
                mu_in[:b] = mu
                t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
                dt_in[:b], dt_in[b:] = dt.unsqueeze(0), dt.unsqueeze(0)
                # not used now
                if not self.mean_mode:
                    dt_in = torch.zeros_like(dt_in)
                cond_in[:b], cond_in[b:] = cond, cond

                dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)

                if use_cfg_zero_star:
                    positive_flat = dphi_dt.view(b, -1)
                    negative_flat = cfg_dphi_dt.view(b, -1)
                    st_star = self.optimized_scale(positive_flat, negative_flat)
                    st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
                else:
                    st_star = 1.0

                dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)

            x = x - dt * dphi_dt
            t = t - dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t - t_span[step + 1]

        return sol[-1]
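
To see what the sway-sampling line in `forward` does to the integration schedule, here is the same transform in isolation (toy values, plain PyTorch): with a positive coefficient the 1 -> 0 endpoints are preserved, but the early steps (near pure noise, t = 1) become finer and the late steps coarser:

import torch

n_timesteps, sway_sampling_coef = 10, 1.0
t_span = torch.linspace(1, 0, n_timesteps + 1)
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
print(t_span)          # starts at 1.0, ends at 0.0, non-uniform in between
print(-t_span.diff())  # per-step sizes used by solve_euler (small first, large last)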

src/voxcpm/modules/locenc/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .local_encoder import VoxCPMLocEnc

src/voxcpm/modules/locenc/local_encoder.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
from einops import rearrange


class VoxCPMLocEnc(nn.Module):
    def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
        super().__init__()
        self.config = config
        self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
        self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)

        assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
        self.encoder = MiniCPMModel(config)

    def forward(self, x):
        """
        x: [B, T, P, D]
        Returns: [B, T, hidden_size]
        """
        B, T, P, D = x.shape

        x = self.in_proj(x)
        # Prepend a learned summary token to each patch and encode patches independently.
        special_tokens = self.special_token.expand(B, T, 1, -1)
        x = torch.cat([special_tokens, x], dim=2)
        x = rearrange(x, "b t p c -> (b t) p c")
        outputs, _ = self.encoder(x, is_causal=False)
        cls_output = outputs[:, 0, :]

        return rearrange(cls_output, "(b t) c -> b t c", b=B)

src/voxcpm/modules/minicpm4/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .config import MiniCPM4Config
from .model import MiniCPMModel
from .cache import StaticKVCache

src/voxcpm/modules/minicpm4/cache.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from typing import List, Tuple
import torch


class StaticKVCache:
    def __init__(
        self,
        num_layers: int,
        num_kv_heads: int,
        dim_kv_head: int,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
        max_length: int = 8192,
    ):
        self.max_length = max_length
        self.num_layers = num_layers

        # Layout: [key/value, layer, batch, kv_head, position, head_dim]
        self.kv_cache = torch.zeros(
            2,
            num_layers,
            batch_size,
            num_kv_heads,
            max_length,
            dim_kv_head,
            device=device,
            dtype=dtype,
        )
        self.current_length = 0

    def get_layer_cache(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.kv_cache[0, layer_idx], self.kv_cache[1, layer_idx]

    def step(self) -> int:
        if self.current_length >= self.max_length:
            raise ValueError("KV cache is full")

        ret = self.current_length
        self.current_length += 1
        return ret

    def fill_caches(self, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]):
        self.current_length = kv_caches[0][0].size(2)
        self.kv_cache.zero_()
        for i in range(self.num_layers):
            self.kv_cache[0, i, :, :, : self.current_length, :] = kv_caches[i][0]
            self.kv_cache[1, i, :, :, : self.current_length, :] = kv_caches[i][1]
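
A minimal sketch of the cache above on its own, with toy sizes: the single preallocated buffer follows the layout noted in the code, and `step()` hands out consecutive write positions:

import torch

from voxcpm.modules.minicpm4 import StaticKVCache

cache = StaticKVCache(
    num_layers=2, num_kv_heads=2, dim_kv_head=16,
    batch_size=1, device=torch.device("cpu"), dtype=torch.float32, max_length=8,
)
k0, v0 = cache.get_layer_cache(0)  # views into the shared buffer
print(k0.shape)                    # torch.Size([1, 2, 8, 16])
print(cache.step(), cache.step())  # 0 1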

src/voxcpm/modules/minicpm4/config.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from pydantic import BaseModel
from typing import List, Optional


class RopeScalingConfig(BaseModel):
    type: str
    long_factor: List[float]
    short_factor: List[float]
    original_max_position_embeddings: int


class MiniCPM4Config(BaseModel):
    bos_token_id: int
    eos_token_id: int
    hidden_size: int
    intermediate_size: int
    max_position_embeddings: int
    num_attention_heads: int
    num_hidden_layers: int
    num_key_value_heads: int
    rms_norm_eps: float
    rope_scaling: RopeScalingConfig
    vocab_size: int
    use_mup: bool = True
    scale_emb: float
    dim_model_base: int
    scale_depth: float
    rope_theta: float
    kv_channels: Optional[int] = None  # falls back to hidden_size // num_attention_heads when unset
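
An illustrative instantiation of the config above with made-up, deliberately tiny values (a real checkpoint ships its own numbers). The long/short factors are given head_dim // 2 = 8 entries so they broadcast against the RoPE inverse frequencies in model.py:

from voxcpm.modules.minicpm4 import MiniCPM4Config
from voxcpm.modules.minicpm4.config import RopeScalingConfig

cfg = MiniCPM4Config(
    bos_token_id=1,
    eos_token_id=2,
    hidden_size=64,
    intermediate_size=128,
    max_position_embeddings=256,
    num_attention_heads=4,
    num_hidden_layers=2,
    num_key_value_heads=2,
    rms_norm_eps=1e-6,
    rope_scaling=RopeScalingConfig(
        type="longrope",
        long_factor=[1.0] * 8,   # head_dim // 2 entries (64 // 4 // 2 = 8)
        short_factor=[1.0] * 8,
        original_max_position_embeddings=128,
    ),
    vocab_size=0,                # 0 means MiniCPMModel consumes embeddings directly
    scale_emb=1.0,
    dim_model_base=64,
    scale_depth=1.4,
    rope_theta=10000.0,
)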

src/voxcpm/modules/minicpm4/model.py (new file, 411 lines)
@@ -0,0 +1,411 @@
from .config import MiniCPM4Config
import torch
import torch.nn as nn
from typing import List, Tuple
import math
from .cache import StaticKVCache


def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
    old_dtype = hidden.dtype
    variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
    return hidden * weight


class MiniCPMRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MiniCPMRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """
    Args:
        q: Tensor(batch_size, num_heads, seq_len, head_dim)
        k: Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
        cos: Tensor(seq_len, head_dim)
        sin: Tensor(seq_len, head_dim)
    Returns:
        Tensor(batch_size, num_heads, seq_len, head_dim), Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
    """
    orig_dtype = q.dtype
    q = q.to(torch.float32)
    k = k.to(torch.float32)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


class MiniCPMLongRoPE(nn.Module):
    """Rotary position embedding with LongRoPE-style long/short-factor scaling."""

    def __init__(self, config: MiniCPM4Config):
        super().__init__()
        self.config = config
        self.dim = config.kv_channels if config.kv_channels else config.hidden_size // config.num_attention_heads
        self.base = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings

        self.short_factor = config.rope_scaling.short_factor
        self.long_factor = config.rope_scaling.long_factor
        self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings

        scale = (self.max_position_embeddings / self.original_max_position_embeddings)
        self.scaling_factor = math.sqrt(
            1 + math.log(scale) / math.log(self.original_max_position_embeddings)
        )
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.max_seq_len_cached = 0

        self.register_buffer("cos_cached", torch.empty(0), persistent=False)
        self.register_buffer("sin_cached", torch.empty(0), persistent=False)

        self._set_cos_sin_cache(
            seq_len=self.max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Set the cos and sin caches."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)

        freqs = torch.mul(
            torch.outer(t, 1.0 / ext_factors).to(device=device),
            self.inv_freq.to(device=device).to(dtype)
        )

        # Create the embeddings
        emb = torch.cat((freqs, freqs), dim=-1)

        self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
        self.sin_cached = emb.sin().to(dtype) * self.scaling_factor

    def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            position_ids: Tensor(seq_len) or Tensor(batch_size, seq_len)
        Returns:
            Tensor(seq_len, head_dim), Tensor(seq_len, head_dim)
        """
        cos = self.cos_cached[position_ids]
        sin = self.sin_cached[position_ids]

        return cos, sin


class MiniCPMAttention(nn.Module):
    def __init__(self, config: MiniCPM4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = 10000.0

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_emb

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            is_causal=is_causal,
            enable_gqa=True,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        past_key_value = (key_states, value_states)
        return attn_output, past_key_value

    def forward_step(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        position_id: int,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        bsz, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, 1, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_emb

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        key_cache, value_cache = kv_cache

        # Write the current step's key/value into its slot of the static cache.
        key_cache[:, :, position_id, :] = key_states
        value_cache[:, :, position_id, :] = value_states

        # Attend only to positions written so far.
        attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_cache,
            value_cache,
            attn_mask=attn_mask,
            enable_gqa=True,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output


class MiniCPMMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class MiniCPMDecoderLayer(nn.Module):
    def __init__(self, config: MiniCPM4Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MiniCPMAttention(config=config, layer_idx=layer_idx)

        self.mlp = MiniCPMMLP(config)
        self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.scale_depth = config.scale_depth
        self.num_hidden_layers = config.num_hidden_layers
        self.use_mup = config.use_mup

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            position_emb (`Tuple[torch.Tensor, torch.Tensor]`): precomputed cos/sin rotary embeddings
            is_causal (`bool`): whether the attention mask is causal
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            position_emb=position_emb,
            is_causal=is_causal,
        )

        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        return hidden_states, present_key_value

    def forward_step(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        position_id: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states = self.self_attn.forward_step(
            hidden_states=hidden_states,
            position_emb=position_emb,
            position_id=position_id,
            kv_cache=kv_cache,
        )

        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        return hidden_states


class MiniCPMModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]

    Args:
        config: MiniCPM4Config
    """

    def __init__(self, config: MiniCPM4Config):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.config = config

        if config.vocab_size > 0:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        else:
            self.embed_tokens = nn.Identity()

        self.layers = nn.ModuleList(
            [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rope_emb = MiniCPMLongRoPE(config)

        self.kv_cache = None

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            inputs_embeds: Tensor(batch_size, seq_length, hidden_size)
            is_causal: bool, whether the attention mask is causal
        Returns:
            hidden_states: Tensor(batch_size, seq_length, hidden_size)
            next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)]
        """
        position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device)
        position_emb = self.rope_emb(position_ids)
        hidden_states = inputs_embeds

        next_decoder_cache = []

        for decoder_layer in self.layers:
            hidden_states, this_cache = decoder_layer(
                hidden_states,
                position_emb,
                is_causal,
            )
            next_decoder_cache.append(this_cache)
        hidden_states = self.norm(hidden_states)
        return hidden_states, next_decoder_cache

    def forward_step(
        self,
        inputs_embeds: torch.Tensor,
        position_id: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            inputs_embeds: Tensor(batch_size, hidden_size)
            position_id: position index of the current decoding step
        Returns:
            hidden_states: Tensor(batch_size, hidden_size)
        """
        assert self.kv_cache is not None, "KV cache is not set up"

        position_emb = self.rope_emb(position_id)
        hidden_states = inputs_embeds

        for i, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer.forward_step(
                hidden_states,
                position_emb,
                position_id,
                self.kv_cache.get_layer_cache(i),
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def setup_cache(self, batch_size: int, max_length: int, device, dtype: torch.dtype):
        self.kv_cache = StaticKVCache(
            num_layers=self.config.num_hidden_layers,
            num_kv_heads=self.config.num_key_value_heads,
            dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
            batch_size=batch_size,
            device=device,
            dtype=dtype,
            max_length=max_length,
        )
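
Finally, a hedged end-to-end sketch of how the pieces above fit together for incremental decoding: a parallel prefill through `forward`, then single steps through `forward_step` against the static KV cache. The config uses the same toy numbers as the config.py sketch and the embeddings are random; this only illustrates shapes and the cache protocol, not the real VoxCPM inference path.

import torch

from voxcpm.modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from voxcpm.modules.minicpm4.config import RopeScalingConfig

cfg = MiniCPM4Config(
    bos_token_id=1, eos_token_id=2, hidden_size=64, intermediate_size=128,
    max_position_embeddings=256, num_attention_heads=4, num_hidden_layers=2,
    num_key_value_heads=2, rms_norm_eps=1e-6,
    rope_scaling=RopeScalingConfig(type="longrope", long_factor=[1.0] * 8,
                                   short_factor=[1.0] * 8,
                                   original_max_position_embeddings=128),
    vocab_size=0, scale_emb=1.0, dim_model_base=64, scale_depth=1.4, rope_theta=10000.0,
)

model = MiniCPMModel(cfg).eval()
prefix = torch.randn(1, 12, cfg.hidden_size)                 # 12 steps of input embeddings

with torch.no_grad():
    hidden, kv = model(prefix, is_causal=True)               # parallel prefill
    model.setup_cache(batch_size=1, max_length=64, device=prefix.device, dtype=prefix.dtype)
    model.kv_cache.fill_caches(kv)                           # copy prefill keys/values into the cache

    step_input = hidden[:, -1, :]                            # (1, hidden_size)
    for _ in range(4):
        pos = model.kv_cache.step()                          # next write position (12, 13, ...)
        step_input = model.forward_step(step_input, torch.tensor([pos]))
print(step_input.shape)                                      # torch.Size([1, 64])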