zengguoyang
2025-09-16 11:46:47 +08:00
commit 272b8ffbf6
31 changed files with 3473 additions and 0 deletions

View File

View File

@@ -0,0 +1 @@
from .audio_vae import AudioVAE

View File

@@ -0,0 +1,359 @@
import math
from typing import List, Union
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class CausalConv1d(nn.Conv1d):
def __init__(self, *args, padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
def forward(self, x):
x_pad = F.pad(x, (self.__padding * 2, 0))
return super().forward(x_pad)
class CausalTransposeConv1d(nn.ConvTranspose1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
def WNCausalConv1d(*args, **kwargs):
return weight_norm(CausalConv1d(*args, **kwargs))
def WNCausalTransposeConv1d(*args, **kwargs):
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class CausalResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
super().__init__()
pad = ((kernel - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNCausalConv1d(
dim,
dim,
kernel_size=kernel,
dilation=dilation,
padding=pad,
groups=groups,
),
Snake1d(dim),
WNCausalConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
assert pad == 0
if pad > 0:
x = x[..., pad:-pad]
return x + y
class CausalEncoderBlock(nn.Module):
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
super().__init__()
input_dim = input_dim or output_dim // 2
self.block = nn.Sequential(
CausalResidualUnit(input_dim, dilation=1, groups=groups),
CausalResidualUnit(input_dim, dilation=3, groups=groups),
CausalResidualUnit(input_dim, dilation=9, groups=groups),
Snake1d(input_dim),
WNCausalConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
)
def forward(self, x):
return self.block(x)
class CausalEncoder(nn.Module):
def __init__(
self,
d_model: int = 64,
latent_dim: int = 32,
strides: list = [2, 4, 8, 8],
depthwise: bool = False,
):
super().__init__()
# Create first convolution
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
groups = d_model // 2 if depthwise else 1
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
groups = d_model if depthwise else 1  # note: not used below
# Create two convolutions, one for mu and one for logvar
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
# Wrap block into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
hidden_state = self.block(x)
return {
"hidden_state": hidden_state,
"mu": self.fc_mu(hidden_state),
"logvar": self.fc_logvar(hidden_state),
}
class NoiseBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
def forward(self, x):
B, C, T = x.shape
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
h = self.linear(x)
n = noise * h
x = x + n
return x
class CausalDecoderBlock(nn.Module):
def __init__(
self,
input_dim: int = 16,
output_dim: int = 8,
stride: int = 1,
groups=1,
use_noise_block: bool = False,
):
super().__init__()
layers = [
Snake1d(input_dim),
WNCausalTransposeConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
]
if use_noise_block:
layers.append(NoiseBlock(output_dim))
layers.extend(
[
CausalResidualUnit(output_dim, dilation=1, groups=groups),
CausalResidualUnit(output_dim, dilation=3, groups=groups),
CausalResidualUnit(output_dim, dilation=9, groups=groups),
]
)
self.block = nn.Sequential(*layers)
def forward(self, x):
return self.block(x)
class TransposeLastTwoDim(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, -1, -2)
class CausalDecoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
depthwise: bool = False,
d_out: int = 1,
use_noise_block: bool = False,
):
super().__init__()
# Add first conv layer
if depthwise:
layers = [
WNCausalConv1d(
input_channel,
input_channel,
kernel_size=7,
padding=3,
groups=input_channel,
),
WNCausalConv1d(input_channel, channels, kernel_size=1),
]
else:
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
groups = output_dim if depthwise else 1
layers += [
CausalDecoderBlock(
input_dim,
output_dim,
stride,
groups=groups,
use_noise_block=use_noise_block,
)
]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class AudioVAE(nn.Module):
"""
Args:
"""
def __init__(
self,
encoder_dim: int = 128,
encoder_rates: List[int] = [2, 5, 8, 8],
latent_dim: int = 64,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 5, 2],
depthwise: bool = True,
sample_rate: int = 16000,
use_noise_block: bool = False,
):
super().__init__()
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.depthwise = depthwise
self.use_noise_block = use_noise_block
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = CausalEncoder(
encoder_dim,
latent_dim,
encoder_rates,
depthwise=depthwise,
)
self.decoder = CausalDecoder(
latent_dim,
decoder_dim,
decoder_rates,
depthwise=depthwise,
use_noise_block=use_noise_block,
)
self.sample_rate = sample_rate
self.chunk_size = math.prod(encoder_rates)
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
pad_to = self.hop_length
length = audio_data.shape[-1]
right_pad = math.ceil(length / pad_to) * pad_to - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def decode(self, z: torch.Tensor):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
return self.decoder(z)
def encode(self, audio_data: torch.Tensor, sample_rate: int):
"""
Args:
audio_data: Tensor[B x 1 x T]
sample_rate: int
Returns:
z: Tensor[B x D x T_frames] - posterior mean of the latent (T_frames = ceil(T / hop_length))
"""
if audio_data.ndim == 2:
audio_data = audio_data.unsqueeze(1)
audio_data = self.preprocess(audio_data, sample_rate)
return self.encoder(audio_data)["mu"]
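# --- Illustrative usage sketch (not part of the file above; import path is an assumption) ---
# Encodes one second of 16 kHz audio and decodes it back. With the default strides
# [2, 5, 8, 8] the hop length is 640 samples, so 16000 samples map to 25 latent frames.
import torch
from audio_vae import AudioVAE  # assumed import path

vae = AudioVAE().eval()
audio = torch.randn(1, 1, 16000)                 # [B, 1, T] waveform
with torch.no_grad():
    z = vae.encode(audio, sample_rate=16000)     # [1, 64, 25] -> [B, latent_dim, T // hop_length]
    recon = vae.decode(z)                        # [1, 1, 16000] reconstructed waveform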

View File

@@ -0,0 +1 @@
from .scalar_quantization_layer import ScalarQuantizationLayer

View File

@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
class ScalarQuantizationLayer(nn.Module):
def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.latent_dim = latent_dim
self.scale = scale
self.in_proj = nn.Linear(in_dim, latent_dim)
self.out_proj = nn.Linear(latent_dim, out_dim)
def forward(self, hidden):
hidden = self.in_proj(hidden)
hidden = torch.tanh(hidden)
if self.training:
quantized = torch.round(hidden * self.scale) / self.scale
hidden = hidden + (quantized - hidden).detach()
else:
hidden = torch.round(hidden * self.scale) / self.scale
return self.out_proj(hidden)
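# --- Illustrative sketch (not part of the file above) of the straight-through estimator ---
# In training, values are snapped to a grid of 2 * scale + 1 = 19 levels in [-1, 1]
# (round(tanh(h) * 9) / 9) while gradients bypass the rounding; in eval the rounding
# is applied directly. The import path is an assumption.
import torch
from scalar_quantization_layer import ScalarQuantizationLayer  # assumed import path

layer = ScalarQuantizationLayer(in_dim=32, out_dim=32, latent_dim=64, scale=9)
x = torch.randn(2, 10, 32, requires_grad=True)

layer.train()
layer(x).sum().backward()
print(x.grad is not None)       # True: gradients flow through the rounding

layer.eval()
with torch.no_grad():
    y = layer(x)                # hard-quantized path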

View File

@@ -0,0 +1,2 @@
from .unified_cfm import UnifiedCFM, CfmConfig
from .local_dit import VoxCPMLocDiT

View File

@@ -0,0 +1,114 @@
import torch
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
import torch.nn as nn
import math
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
out_dim: int = None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class VoxCPMLocDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
config: MiniCPM4Config,
in_channels: int = 64,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.config = config
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
self.time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
self.delta_time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
self.decoder = MiniCPMModel(config)
def forward(
self,
x: torch.Tensor,
mu: torch.Tensor,
t: torch.Tensor,
cond: torch.Tensor,
dt: torch.Tensor,
):
"""
Forward pass of DiT.
x: (N, C, T) tensor of noisy latent inputs
mu: (N, hidden_size) conditioning embedding, added to the timestep embedding
t: (N,) tensor of diffusion timesteps
cond: (N, C, T') tensor of prefix conditions
dt: (N,) delta timestep for mean-velocity mode; the caller passes zeros when that mode is disabled
"""
x = self.in_proj(x.transpose(1, 2).contiguous())
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
prefix = cond.size(1)
t = self.time_embeddings(t).to(x.dtype)
t = self.time_mlp(t)
dt = self.time_embeddings(dt).to(x.dtype)
dt = self.delta_time_mlp(dt)
t = t + dt
x = torch.cat([(mu + t).unsqueeze(1), cond, x], dim=1)
hidden, _ = self.decoder(x, is_causal=False)
hidden = hidden[:, prefix + 1 :, :]
hidden = self.out_proj(hidden)
return hidden.transpose(1, 2).contiguous()
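# --- Illustrative shape walkthrough (not part of the file above) ---
# The config below is a small placeholder, not the real hyperparameters, and the import
# paths are assumptions. The DiT prepends one (mu + t) token and the `cond` prefix tokens,
# runs the non-causal MiniCPM backbone, and keeps only the last T outputs.
import torch
from local_dit import VoxCPMLocDiT          # assumed import path
from minicpm4 import MiniCPM4Config         # assumed import path

cfg = MiniCPM4Config(
    bos_token_id=0, eos_token_id=1, vocab_size=0,          # vocab_size must be 0 here
    hidden_size=64, intermediate_size=128,
    num_attention_heads=4, num_key_value_heads=2, num_hidden_layers=2,
    max_position_embeddings=256, rms_norm_eps=1e-5, rope_theta=10000.0,
    rope_scaling=dict(type="longrope", long_factor=[1.0] * 8, short_factor=[1.0] * 8,
                      original_max_position_embeddings=256),
    scale_emb=1.0, dim_model_base=64, scale_depth=1.0,
)
dit = VoxCPMLocDiT(cfg, in_channels=64).eval()
x = torch.randn(2, 64, 10)      # noisy latents (N, C, T)
mu = torch.randn(2, 64)         # conditioning embedding (N, hidden_size)
cond = torch.randn(2, 64, 4)    # prefix condition (N, C, T')
t = torch.rand(2)               # diffusion timesteps
dt = torch.zeros(2)
with torch.no_grad():
    v = dit(x, mu, t, cond, dt)
print(v.shape)                  # torch.Size([2, 64, 10])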

View File

@@ -0,0 +1,137 @@
import torch
from typing import List
from .local_dit import VoxCPMLocDiT
import math
from pydantic import BaseModel
class CfmConfig(BaseModel):
sigma_min: float = 1e-06
solver: str = "euler"
t_scheduler: str = "log-norm"
class UnifiedCFM(torch.nn.Module):
def __init__(
self,
in_channels,
cfm_params: CfmConfig,
estimator: VoxCPMLocDiT,
mean_mode: bool = False,
):
super().__init__()
self.solver = cfm_params.solver
self.sigma_min = cfm_params.sigma_min
self.t_scheduler = cfm_params.t_scheduler
self.in_channels = in_channels
self.mean_mode = mean_mode
# Just change the architecture of the estimator here
self.estimator = estimator
@torch.inference_mode()
def forward(
self,
mu: torch.Tensor,
n_timesteps: int,
patch_size: int,
cond: torch.Tensor,
temperature: float = 1.0,
cfg_value: float = 1.0,
sway_sampling_coef: float = 1.0,
use_cfg_zero_star: bool = True,
):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
n_timesteps (int): number of diffusion steps
cond: Not used but kept for future purposes
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
b, c = mu.shape
t = patch_size
z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
# Sway sampling strategy
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)
def optimized_scale(self, positive_flat, negative_flat):
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
st_star = dot_product / squared_norm
return st_star
def solve_euler(
self,
x: torch.Tensor,
t_span: torch.Tensor,
mu: torch.Tensor,
cond: torch.Tensor,
cfg_value: float = 1.0,
use_cfg_zero_star: bool = True,
):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
cond: Not used but kept for future purposes
cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
"""
t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
sol = []
zero_init_steps = max(1, int(len(t_span) * 0.04))
for step in range(1, len(t_span)):
if use_cfg_zero_star and step <= zero_init_steps:
dphi_dt = 0.
else:
# Classifier-Free Guidance inference introduced in VoiceBox
b = x.size(0)
x_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
x_in[:b], x_in[b:] = x, x
mu_in[:b] = mu
t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
dt_in[:b], dt_in[b:] = dt.unsqueeze(0), dt.unsqueeze(0)
# dt is only meaningful in mean-velocity mode; zero it out otherwise
if not self.mean_mode:
dt_in = torch.zeros_like(dt_in)
cond_in[:b], cond_in[b:] = cond, cond
dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
if use_cfg_zero_star:
positive_flat = dphi_dt.view(b, -1)
negative_flat = cfg_dphi_dt.view(b, -1)
st_star = self.optimized_scale(positive_flat, negative_flat)
st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
else:
st_star = 1.0
dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)
x = x - dt * dphi_dt
t = t - dt
sol.append(x)
if step < len(t_span) - 1:
dt = t - t_span[step + 1]
return sol[-1]
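# --- Illustrative sketch (not part of the file above) of the sway-sampled schedule ---
# A uniform 1 -> 0 grid is warped by t + coef * (cos(pi/2 * t) - 1 + t); the endpoints
# stay at 1 and 0, but for a positive coefficient the spacing becomes finer near t = 1
# (the high-noise end), so more Euler steps are spent early in the integration.
import torch

n_timesteps, coef = 10, 1.0
t = torch.linspace(1, 0, n_timesteps + 1)
t_sway = t + coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
print(t_sway[:3], t_sway[-3:])   # starts at 1.0, ends at 0.0, non-uniform in between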

View File

@@ -0,0 +1 @@
from .local_encoder import VoxCPMLocEnc

View File

@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
from einops import rearrange
class VoxCPMLocEnc(nn.Module):
def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
super().__init__()
self.config = config
self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)
assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
self.encoder = MiniCPMModel(config)
def forward(self, x):
"""
x: [B, T, P, D] - T patches of P latent frames each
Returns: [B, T, hidden_size] - per-patch embedding read from the prepended special token
"""
B, T, P, D = x.shape
x = self.in_proj(x)
special_tokens = self.special_token.expand(B, T, 1, -1)
x = torch.cat([special_tokens, x], dim=2)
x = rearrange(x, "b t p c -> (b t) p c")
outputs, _ = self.encoder(x, is_causal=False)
cls_output = outputs[:, 0, :]
return rearrange(cls_output, "(b t) c -> b t c", b=B)

View File

@@ -0,0 +1,3 @@
from .config import MiniCPM4Config
from .model import MiniCPMModel
from .cache import StaticKVCache

View File

@@ -0,0 +1,47 @@
from typing import List, Tuple
import torch
class StaticKVCache:
def __init__(
self,
num_layers: int,
num_kv_heads: int,
dim_kv_head: int,
batch_size: int,
device: torch.device,
dtype: torch.dtype,
max_length: int = 8192,
):
self.max_length = max_length
self.num_layers = num_layers
self.kv_cache = torch.zeros(
2,
num_layers,
batch_size,
num_kv_heads,
max_length,
dim_kv_head,
device=device,
dtype=dtype,
)
self.current_length = 0
def get_layer_cache(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
return self.kv_cache[0, layer_idx], self.kv_cache[1, layer_idx]
def step(self) -> int:
if self.current_length >= self.max_length:
raise ValueError("KV cache is full")
ret = self.current_length
self.current_length += 1
return ret
def fill_caches(self, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]):
self.current_length = kv_caches[0][0].size(2)
self.kv_cache.zero_()
for i in range(self.num_layers):
self.kv_cache[0, i, :, :, : self.current_length, :] = kv_caches[i][0]
self.kv_cache[1, i, :, :, : self.current_length, :] = kv_caches[i][1]
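# --- Illustrative sketch (not part of the file above) of the cache lifecycle ---
# Fill the cache from a prefill pass, then reserve one slot per decode step and write
# the new key/value into the per-layer views. The import path is an assumption.
import torch
from cache import StaticKVCache  # assumed import path

cache = StaticKVCache(num_layers=2, num_kv_heads=2, dim_kv_head=16, batch_size=1,
                      device=torch.device("cpu"), dtype=torch.float32, max_length=32)

# Pretend a prefill produced 5 cached positions per layer: (B, kv_heads, T, head_dim).
prefill = [(torch.randn(1, 2, 5, 16), torch.randn(1, 2, 5, 16)) for _ in range(2)]
cache.fill_caches(prefill)

pos = cache.step()                              # index of the next free slot (5 here)
k_cache, v_cache = cache.get_layer_cache(0)
k_cache[:, :, pos, :] = torch.randn(1, 2, 16)   # write the new token's key for layer 0
v_cache[:, :, pos, :] = torch.randn(1, 2, 16)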

View File

@@ -0,0 +1,29 @@
from pydantic import BaseModel
from typing import List, Optional
class RopeScalingConfig(BaseModel):
type: str
long_factor: List[float]
short_factor: List[float]
original_max_position_embeddings: int
class MiniCPM4Config(BaseModel):
bos_token_id: int
eos_token_id: int
hidden_size: int
intermediate_size: int
max_position_embeddings: int
num_attention_heads: int
num_hidden_layers: int
num_key_value_heads: int
rms_norm_eps: float
rope_scaling: RopeScalingConfig
vocab_size: int
use_mup: bool = True
scale_emb: float
dim_model_base: int
scale_depth: float
rope_theta: float
kv_channels: Optional[int] = None

View File

@@ -0,0 +1,411 @@
from .config import MiniCPM4Config
import torch
import torch.nn as nn
from typing import List, Tuple
import math
from .cache import StaticKVCache
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
old_dtype = hidden.dtype
variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
return hidden * weight
class MiniCPMRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MiniCPMRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
"""
Args:
q: Tensor(batch_size, num_heads, seq_len, head_dim)
k: Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
cos: Tensor(seq_len, head_dim)
sin: Tensor(seq_len, head_dim)
Returns:
Tensor(batch_size, num_heads, seq_len, head_dim), Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
"""
orig_dtype = q.dtype
q = q.to(torch.float32)
k = k.to(torch.float32)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
class MiniCPMLongRoPE(nn.Module):
"""MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, config: MiniCPM4Config):
super().__init__()
self.config = config
self.dim = config.kv_channels if config.kv_channels else config.hidden_size // config.num_attention_heads
self.base = config.rope_theta
self.max_position_embeddings = config.max_position_embeddings
self.short_factor = config.rope_scaling.short_factor
self.long_factor = config.rope_scaling.long_factor
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
self.scaling_factor = math.sqrt(
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
)
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.max_seq_len_cached = 0
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
self._set_cos_sin_cache(
seq_len=self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
"""设置cos和sin缓存"""
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
if seq_len > self.original_max_position_embeddings:
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
else:
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
freqs = torch.mul(
torch.outer(t, 1.0 / ext_factors).to(device=device),
self.inv_freq.to(device=device).to(dtype)
)
# Build the rotary embeddings
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
self.sin_cached = emb.sin().to(dtype) * self.scaling_factor
def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
position_ids: Tensor(seq_len) or Tensor(batch_size, seq_len)
Returns:
Tensor(seq_len, head_dim), Tensor(seq_len, head_dim)
"""
cos = self.cos_cached[position_ids]
sin = self.sin_cached[position_ids]
return cos, sin
class MiniCPMAttention(nn.Module):
def __init__(self, config: MiniCPM4Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = 10000.0
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
position_emb: Tuple[torch.Tensor, torch.Tensor],
is_causal: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
is_causal=is_causal,
enable_gqa=True,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
past_key_value = (key_states, value_states)
return attn_output, past_key_value
def forward_step(
self,
hidden_states: torch.Tensor,
position_emb: Tuple[torch.Tensor, torch.Tensor],
position_id: int,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
bsz, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, 1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
key_cache, value_cache = kv_cache
key_cache[:, :, position_id, :] = key_states
value_cache[:, :, position_id, :] = value_states
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_cache,
value_cache,
attn_mask=attn_mask,
enable_gqa=True,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
return attn_output
class MiniCPMMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.SiLU()
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class MiniCPMDecoderLayer(nn.Module):
def __init__(self, config: MiniCPM4Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MiniCPMAttention(config=config, layer_idx=layer_idx)
self.mlp = MiniCPMMLP(config)
self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.scale_depth = config.scale_depth
self.num_hidden_layers = config.num_hidden_layers
self.use_mup = config.use_mup
def forward(
self,
hidden_states: torch.Tensor,
position_emb: Tuple[torch.Tensor, torch.Tensor],
is_causal: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
position_emb (`Tuple[torch.Tensor, torch.Tensor]`): rotary (cos, sin) tensors of shape `(seq_len, head_dim)`
is_causal (`bool`): whether the attention mask is causal
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, present_key_value = self.self_attn(
hidden_states=hidden_states,
position_emb=position_emb,
is_causal=is_causal,
)
if self.use_mup:
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
else:
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.use_mup:
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
else:
hidden_states = residual + hidden_states
return hidden_states, present_key_value
def forward_step(
self,
hidden_states: torch.Tensor,
position_emb: Tuple[torch.Tensor, torch.Tensor],
position_id: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn.forward_step(
hidden_states=hidden_states,
position_emb=position_emb,
position_id=position_id,
kv_cache=kv_cache,
)
if self.use_mup:
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
else:
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.use_mup:
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
else:
hidden_states = residual + hidden_states
return hidden_states
class MiniCPMModel(nn.Module):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
Args:
config: MiniCPM4Config
"""
def __init__(self, config: MiniCPM4Config):
super().__init__()
self.vocab_size = config.vocab_size
self.config = config
if config.vocab_size > 0:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
else:
self.embed_tokens = nn.Identity()
self.layers = nn.ModuleList(
[MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rope_emb = MiniCPMLongRoPE(config)
self.kv_cache = None
def forward(
self,
inputs_embeds: torch.Tensor,
is_causal: bool = True,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""
Args:
inputs_embeds: Tensor(batch_size, seq_length, hidden_size)
is_causal: bool, whether the attention mask is causal
Returns:
hidden_states: Tensor(batch_size, seq_length, hidden_size)
next_decoder_cache: list of per-layer (key, value) pairs, each of shape (batch_size, num_key_value_heads, seq_length, head_dim)
"""
position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device)
position_emb = self.rope_emb(position_ids)
hidden_states = inputs_embeds
next_decoder_cache = []
for decoder_layer in self.layers:
hidden_states, this_cache = decoder_layer(
hidden_states,
position_emb,
is_causal,
)
next_decoder_cache.append(this_cache)
hidden_states = self.norm(hidden_states)
return hidden_states, next_decoder_cache
def forward_step(
self,
inputs_embeds: torch.Tensor,
position_id: torch.Tensor,
) -> torch.Tensor:
"""
Args:
inputs_embeds: Tensor(batch_size, hidden_size)
position_id: Tensor holding the current position index in the KV cache
Returns:
hidden_states: Tensor(batch_size, hidden_size)
"""
assert self.kv_cache is not None, "KV cache is not set up; call setup_cache first"
position_emb = self.rope_emb(position_id)
hidden_states = inputs_embeds
for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer.forward_step(
hidden_states,
position_emb,
position_id,
self.kv_cache.get_layer_cache(i),
)
hidden_states = self.norm(hidden_states)
return hidden_states
def setup_cache(self, batch_size: int, max_length: int, device, dtype: torch.dtype):
self.kv_cache = StaticKVCache(
num_layers=self.config.num_hidden_layers,
num_kv_heads=self.config.num_key_value_heads,
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
batch_size=batch_size,
device=device,
dtype=dtype,
max_length=max_length,
)
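# --- Illustrative sketch (not part of the file above): prefill plus one decode step ---
# The tiny config is a placeholder and the import paths are assumptions. With vocab_size = 0,
# embed_tokens is an Identity, so the model consumes embeddings directly; the position index
# is passed as a shape-(1,) tensor so the cache write broadcasts cleanly.
import torch
from minicpm4 import MiniCPMModel, MiniCPM4Config   # assumed import paths

cfg = MiniCPM4Config(
    bos_token_id=0, eos_token_id=1, vocab_size=0,
    hidden_size=64, intermediate_size=128,
    num_attention_heads=4, num_key_value_heads=2, num_hidden_layers=2,
    max_position_embeddings=256, rms_norm_eps=1e-5, rope_theta=10000.0,
    rope_scaling=dict(type="longrope", long_factor=[1.0] * 8, short_factor=[1.0] * 8,
                      original_max_position_embeddings=256),
    scale_emb=1.0, dim_model_base=64, scale_depth=1.0,
)
model = MiniCPMModel(cfg).eval()

with torch.no_grad():
    prefix = torch.randn(1, 5, cfg.hidden_size)          # prefill over 5 positions
    _, kv = model(prefix, is_causal=True)

    model.setup_cache(batch_size=1, max_length=32, device=prefix.device, dtype=prefix.dtype)
    model.kv_cache.fill_caches(kv)                       # current_length becomes 5

    step_in = torch.randn(1, cfg.hidden_size)            # one new embedding
    pos = model.kv_cache.step()                          # reserves slot 5
    out = model.forward_step(step_in, torch.tensor([pos]))
    print(out.shape)                                     # torch.Size([1, 64])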