This commit is contained in:
zengguoyang
2025-09-16 11:46:47 +08:00
commit 272b8ffbf6
31 changed files with 3473 additions and 0 deletions

5
src/voxcpm/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from .core import VoxCPM
__all__ = [
"VoxCPM",
]

292
src/voxcpm/cli.py Normal file
View File

@@ -0,0 +1,292 @@
#!/usr/bin/env python3
"""
VoxCPM Command Line Interface
Unified CLI for voice cloning, direct TTS synthesis, and batch processing.
Usage examples:
# Direct synthesis (single sample)
voxcpm --text "Hello world" --output output.wav
# Voice cloning (with reference audio and text)
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output output.wav --denoise
# Batch processing (each line in the file is one sample)
voxcpm --input texts.txt --output-dir ./outputs/
"""
import argparse
import os
import sys
from pathlib import Path
from typing import Optional, List
import soundfile as sf
from voxcpm.core import VoxCPM
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
"""Validate that a file exists."""
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
return path
def validate_output_path(output_path: str) -> Path:
"""Validate the output path and create parent directories if needed."""
path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True)
return path
def load_model(args) -> VoxCPM:
"""Load VoxCPM model.
Prefer --model-path if provided; otherwise use from_pretrained (Hub).
"""
print("Loading VoxCPM model...")
# 兼容旧参数ZIPENHANCER_MODEL_PATH 环境变量作为默认
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
"ZIPENHANCER_MODEL_PATH", None
)
# Load from local path if provided
if getattr(args, "model_path", None):
try:
model = VoxCPM(
voxcpm_model_path=args.model_path,
zipenhancer_model_path=zipenhancer_path,
enable_denoiser=not getattr(args, "no_denoiser", False),
)
print("Model loaded (local).")
return model
except Exception as e:
print(f"Failed to load model (local): {e}")
sys.exit(1)
# Otherwise, try from_pretrained (Hub); exit on failure
try:
model = VoxCPM.from_pretrained(
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM-0.5B"),
load_denoiser=not getattr(args, "no_denoiser", False),
zipenhancer_model_id=zipenhancer_path,
cache_dir=getattr(args, "cache_dir", None),
local_files_only=getattr(args, "local_files_only", False),
)
print("Model loaded (from_pretrained).")
return model
except Exception as e:
print(f"Failed to load model (from_pretrained): {e}")
sys.exit(1)
def cmd_clone(args):
"""Voice cloning command."""
# Validate inputs
if not args.text:
print("Error: Please provide text to synthesize (--text)")
sys.exit(1)
if not args.prompt_audio:
print("Error: Voice cloning requires a reference audio (--prompt-audio)")
sys.exit(1)
if not args.prompt_text:
print("Error: Voice cloning requires a reference text (--prompt-text)")
sys.exit(1)
# Validate files
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
output_path = validate_output_path(args.output)
# Load model
model = load_model(args)
# Generate audio
print(f"Synthesizing text: {args.text}")
print(f"Reference audio: {prompt_audio_path}")
print(f"Reference text: {args.prompt_text}")
audio_array = model.generate(
text=args.text,
prompt_wav_path=str(prompt_audio_path),
prompt_text=args.prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
normalize=args.normalize,
denoise=args.denoise
)
# Save audio
sf.write(str(output_path), audio_array, 16000)
print(f"Saved audio to: {output_path}")
# Stats
duration = len(audio_array) / 16000
print(f"Duration: {duration:.2f}s")
def cmd_synthesize(args):
"""Direct TTS synthesis command."""
# Validate inputs
if not args.text:
print("Error: Please provide text to synthesize (--text)")
sys.exit(1)
# Validate output path
output_path = validate_output_path(args.output)
# Load model
model = load_model(args)
# Generate audio
print(f"Synthesizing text: {args.text}")
audio_array = model.generate(
text=args.text,
prompt_wav_path=None,
prompt_text=None,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
normalize=args.normalize,
denoise=False # 无参考音频时不需要降噪
)
# Save audio
sf.write(str(output_path), audio_array, 16000)
print(f"Saved audio to: {output_path}")
# Stats
duration = len(audio_array) / 16000
print(f"Duration: {duration:.2f}s")
def cmd_batch(args):
"""Batch synthesis command."""
# Validate input file
input_file = validate_file_exists(args.input, "input file")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
try:
with open(input_file, 'r', encoding='utf-8') as f:
texts = [line.strip() for line in f if line.strip()]
except Exception as e:
print(f"Failed to read input file: {e}")
sys.exit(1)
if not texts:
print("Error: Input file is empty or contains no valid lines")
sys.exit(1)
print(f"Found {len(texts)} lines to process")
model = load_model(args)
prompt_audio_path = None
if args.prompt_audio:
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
success_count = 0
for i, text in enumerate(texts, 1):
print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...")
try:
audio_array = model.generate(
text=text,
prompt_wav_path=prompt_audio_path,
prompt_text=args.prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
normalize=args.normalize,
denoise=args.denoise and prompt_audio_path is not None
)
output_file = output_dir / f"output_{i:03d}.wav"
sf.write(str(output_file), audio_array, 16000)
duration = len(audio_array) / 16000
print(f" Saved: {output_file} ({duration:.2f}s)")
success_count += 1
except Exception as e:
print(f" Failed: {e}")
continue
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded")
def _build_unified_parser():
"""Build unified argument parser (no subcommands, route by args)."""
parser = argparse.ArgumentParser(
description="VoxCPM CLI (single parser) - voice cloning, direct TTS, and batch processing",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Direct synthesis (single sample)
voxcpm --text "Hello world" --output out.wav
# Voice cloning (reference audio + text)
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output out.wav --denoise
# Batch processing
voxcpm --input texts.txt --output-dir ./outs
# Select model (from Hub)
voxcpm --text "Hello" --output out.wav --hf-model-id openbmb/VoxCPM-0.5B
"""
)
# Task selection (automatic routing by presence of args)
parser.add_argument("--input", "-i", help="Input text file (one line per sample)")
parser.add_argument("--output-dir", "-od", help="Output directory (for batch mode)")
parser.add_argument("--text", "-t", help="Text to synthesize (single-sample mode)")
parser.add_argument("--output", "-o", help="Output audio file path (single-sample mode)")
# Prompt audio (for voice cloning)
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
# Generation parameters
parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG guidance scale (default: 2.0)")
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (default: 10)")
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
# Model loading parameters
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM-0.5B", help="Hugging Face repo id (e.g., openbmb/VoxCPM-0.5B)")
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
return parser
def main():
"""Unified CLI entrypoint: route by provided arguments."""
parser = _build_unified_parser()
args = parser.parse_args()
# Routing: prefer batch → single (clone/direct)
if args.input:
if not args.output_dir:
print("Error: Batch mode requires --output-dir")
parser.print_help()
sys.exit(1)
return cmd_batch(args)
# Single-sample mode
if not args.text or not args.output:
print("Error: Single-sample mode requires --text and --output")
parser.print_help()
sys.exit(1)
# If prompt audio+text provided → voice cloning
if args.prompt_audio or args.prompt_text:
if not args.prompt_audio or not args.prompt_text:
print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
sys.exit(1)
return cmd_clone(args)
# Otherwise → direct synthesis
return cmd_synthesize(args)
if __name__ == "__main__":
main()

181
src/voxcpm/core.py Normal file
View File

@@ -0,0 +1,181 @@
import torch
import torchaudio
import os
import tempfile
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel
from .utils.text_normalize import TextNormalizer
class VoxCPM:
def __init__(self,
voxcpm_model_path : str,
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser : bool = True,
):
"""Initialize VoxCPM TTS pipeline.
Args:
voxcpm_model_path: Local filesystem path to the VoxCPM model assets
(weights, configs, etc.). Typically the directory returned by
a prior download step.
zipenhancer_model_path: ModelScope acoustic noise suppression model
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
"""
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
self.text_normalizer = TextNormalizer()
if enable_denoiser and zipenhancer_model_path is not None:
self.denoiser = pipeline(
Tasks.acoustic_noise_suppression,
model=zipenhancer_model_path)
else:
self.denoiser = None
print("Warm up VoxCPMModel...")
self.tts_model.generate(
target_text="Hello, this is the first test sentence."
)
@classmethod
def from_pretrained(cls,
hf_model_id: str = "openbmb/VoxCPM-0.5B",
load_denoiser: bool = True,
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
Args:
hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo").
load_denoiser: Whether to initialize the denoiser pipeline.
zipenhancer_model_id: Denoiser model id or path for ModelScope
acoustic noise suppression.
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
Returns:
VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
the downloaded snapshot directory.
Raises:
ValueError: If neither a valid ``hf_model_id`` nor a resolvable
``hf_model_id`` is provided.
"""
repo_id = hf_model_id
if not repo_id or repo_id.strip() == "":
raise ValueError("You must provide a valid hf_model_id")
local_path = snapshot_download(
repo_id=repo_id,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
return cls(
voxcpm_model_path=local_path,
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
)
def _normalize_loudness(self, wav_path: str):
audio, sr = torchaudio.load(wav_path)
loudness = torchaudio.functional.loudness(audio, sr)
normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
torchaudio.save(wav_path, normalized_audio, sr)
def generate(self,
text : str,
prompt_wav_path : str = None,
prompt_text : str = None,
cfg_value : float = 2.0,
inference_timesteps : int = 10,
max_length : int = 4096,
normalize : bool = True,
denoise : bool = True,
retry_badcase : bool = True,
retry_badcase_max_times : int = 3,
retry_badcase_ratio_threshold : float = 6.0,
):
"""Synthesize speech for the given text and return a single waveform.
This method optionally builds and reuses a prompt cache. If an external
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
used for all sub-sentences. Otherwise, the prompt cache is built from
the first generated result and reused for the remaining text chunks.
Args:
text: Input text. Can include newlines; each non-empty line is
treated as a sub-sentence.
prompt_wav_path: Path to a reference audio file for prompting.
prompt_text: Text content corresponding to the prompt audio.
cfg_value: Guidance scale for the generation model.
inference_timesteps: Number of inference steps.
max_length: Maximum token length during generation.
normalize: Whether to run text normalization before generation.
denoise: Whether to denoise the prompt audio if a denoiser is
available.
retry_badcase: Whether to retry badcase.
retry_badcase_max_times: Maximum number of times to retry badcase.
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
Returns:
numpy.ndarray: 1D waveform array (float32) on CPU.
"""
texts = text.split("\n")
texts = [t.strip() for t in texts if t.strip()]
final_wav = []
temp_prompt_wav_path = None
try:
if prompt_wav_path is not None and prompt_text is not None:
if denoise and self.denoiser is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
temp_prompt_wav_path = tmp_file.name
self.denoiser(prompt_wav_path, output_path=temp_prompt_wav_path)
self._normalize_loudness(temp_prompt_wav_path)
prompt_wav_path = temp_prompt_wav_path
fixed_prompt_cache = self.tts_model.build_prompt_cache(
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text
)
else:
fixed_prompt_cache = None # will be built from the first inference
for sub_text in texts:
if sub_text.strip() == "":
continue
print("sub_text:", sub_text)
if normalize:
sub_text = self.text_normalizer.normalize(sub_text)
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
target_text=sub_text,
prompt_cache=fixed_prompt_cache,
min_len=2,
max_len=max_length,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
)
if fixed_prompt_cache is None:
fixed_prompt_cache = self.tts_model.merge_prompt_cache(
original_cache=None,
new_text_token=target_text_token,
new_audio_feat=generated_audio_feat
)
final_wav.append(wav)
return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
finally:
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
try:
os.unlink(temp_prompt_wav_path)
except OSError:
pass

View File

@@ -0,0 +1,3 @@
from .voxcpm import VoxCPMModel
__all__ = ["VoxCPMModel"]

122
src/voxcpm/model/utils.py Normal file
View File

@@ -0,0 +1,122 @@
from typing import List
import torch
from transformers import PreTrainedTokenizer
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
This function creates a wrapper around the provided tokenizer that automatically
splits multi-character Chinese tokens into individual characters. This is useful
for ensuring consistent tokenization of Chinese text.
Args:
tokenizer: The base tokenizer to wrap
Returns:
A CharTokenizerWrapper instance that handles multi-character Chinese tokens
Example:
>>> from transformers import LlamaTokenizerFast
>>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
>>> wrapped_tokenizer = mask_multichar_chinese_tokens(tokenizer)
>>> tokens = wrapped_tokenizer("你好世界")
"""
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
multichar_tokens = {
token for token in tokenizer.vocab.keys()
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
}
class CharTokenizerWrapper:
"""Wrapper class for tokenizers that handles multi-character Chinese tokens.
This wrapper automatically splits multi-character Chinese tokens into
individual characters while preserving the original tokenizer's interface.
"""
def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
"""Initialize the wrapper with a base tokenizer.
Args:
base_tokenizer: The tokenizer to wrap
"""
self.tokenizer = base_tokenizer
self.multichar_tokens = multichar_tokens
def tokenize(self, text: str, **kwargs) -> List[str]:
"""Tokenize text and split multi-character Chinese tokens into single characters.
Args:
text: Input text to tokenize
**kwargs: Additional arguments passed to the base tokenizer
Returns:
List of processed tokens with multi-character Chinese tokens split
Example:
>>> wrapper = CharTokenizerWrapper(tokenizer)
>>> tokens = wrapper.tokenize("你好世界")
>>> # Returns ["", "", "", ""] instead of ["你好", "世界"]
"""
if not isinstance(text, str):
raise TypeError(f"Expected string input, got {type(text)}")
tokens = self.tokenizer.tokenize(text, **kwargs)
processed = []
for token in tokens:
# Remove possible subword prefix
clean_token = token.replace("", "")
if clean_token in self.multichar_tokens:
# Split multi-character token into single characters
chars = list(clean_token)
processed.extend(chars)
else:
processed.append(token)
return processed
def __call__(self, text: str, **kwargs) -> List[int]:
"""Call the tokenizer and return token IDs.
This method provides the same interface as the original tokenizer
but with multi-character Chinese token handling.
Args:
text: Input text to tokenize
**kwargs: Additional arguments passed to the base tokenizer
Returns:
List of token IDs
Raises:
TypeError: If input is not a string
ValueError: If tokenization fails
"""
try:
tokens = self.tokenize(text, **kwargs)
result = self.tokenizer.convert_tokens_to_ids(tokens)
return result
except Exception as e:
raise ValueError(f"Tokenization failed: {str(e)}") from e
return CharTokenizerWrapper(tokenizer)
def get_dtype(dtype: str):
if dtype == "bfloat16":
return torch.bfloat16
elif dtype == "bf16":
return torch.bfloat16
elif dtype == "float16":
return torch.float16
elif dtype == "fp16":
return torch.float16
elif dtype == "float32":
return torch.float32
elif dtype == "fp32":
return torch.float32
else:
raise ValueError(f"Unsupported dtype: {dtype}")

605
src/voxcpm/model/voxcpm.py Normal file
View File

@@ -0,0 +1,605 @@
"""
VoxCPM: A Tokenizer-free speech generation model
This module contains the main VoxCPM model implementation, including configuration classes
and the core VoxCPMModel for text-to-speech generation.
Copyright 2025 OpenBMB
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torchaudio
from einops import rearrange
from pydantic import BaseModel
from tqdm import tqdm
from transformers import LlamaTokenizerFast
from ..modules.audiovae import AudioVAE
from ..modules.layers import ScalarQuantizationLayer
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
class VoxCPMEncoderConfig(BaseModel):
hidden_dim: int = 1024
ffn_dim: int = 4096
num_heads: int = 16
num_layers: int = 4
kv_channels: int = None
class VoxCPMDitConfig(BaseModel):
hidden_dim: int = 1024
ffn_dim: int = 4096
num_heads: int = 16
num_layers: int = 4
kv_channels: int = None
cfm_config: CfmConfig
class VoxCPMConfig(BaseModel):
lm_config: MiniCPM4Config
patch_size: int = 2
feat_dim: int = 64
residual_lm_num_layers: int = 6
scalar_quantization_latent_dim: int = 256
scalar_quantization_scale: int = 9
encoder_config: VoxCPMEncoderConfig
dit_config: VoxCPMDitConfig
max_length: int = 4096
device: str = "cuda"
dtype: str = "bfloat16"
class VoxCPMModel(nn.Module):
def __init__(
self,
config: VoxCPMConfig,
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE,
):
super().__init__()
self.config = config
self.feat_dim = config.feat_dim
self.patch_size = config.patch_size
self.device = config.device
if not torch.cuda.is_available():
self.device = "cpu"
# Text-Semantic LM
self.base_lm = MiniCPMModel(config.lm_config)
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
self.audio_start_token = 101
self.audio_end_token = 102
# Residual Acoustic LM
residual_lm_config = config.lm_config.model_copy(deep=True)
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
residual_lm_config.vocab_size = 0
self.residual_lm = MiniCPMModel(residual_lm_config)
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
# Local Encoder
encoder_config = config.lm_config.model_copy(deep=True)
encoder_config.hidden_size = config.encoder_config.hidden_dim
encoder_config.intermediate_size = config.encoder_config.ffn_dim
encoder_config.num_attention_heads = config.encoder_config.num_heads
encoder_config.num_hidden_layers = config.encoder_config.num_layers
encoder_config.kv_channels = config.encoder_config.kv_channels
encoder_config.vocab_size = 0
self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)
# Local DiT
decoder_config = config.lm_config.model_copy(deep=True)
decoder_config.hidden_size = config.dit_config.hidden_dim
decoder_config.intermediate_size = config.dit_config.ffn_dim
decoder_config.num_attention_heads = config.dit_config.num_heads
decoder_config.num_hidden_layers = config.dit_config.num_layers
decoder_config.kv_channels = config.dit_config.kv_channels
decoder_config.vocab_size = 0
self.feat_decoder = UnifiedCFM(
in_channels=config.feat_dim,
cfm_params=config.dit_config.cfm_config,
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
)
# Projection layers
self.fsq_layer = ScalarQuantizationLayer(
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale
)
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
# Stop Predictor
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
self.stop_actn = nn.SiLU()
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
# Audio VAE
self.audio_vae = audio_vae
self.chunk_size = audio_vae.chunk_size
self.sample_rate = audio_vae.sample_rate
def optimize(self):
if self.device == "cuda":
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
else:
self.base_lm.forward_step = self.base_lm.forward_step
self.residual_lm.forward_step = self.residual_lm.forward_step
self.feat_encoder_step = self.feat_encoder
self.feat_decoder.estimator = self.feat_decoder.estimator
return self
@torch.inference_mode()
def generate(
self,
target_text: str,
prompt_text: str = "",
prompt_wav_path: str = "",
min_len: int = 2,
max_len: int = 2000,
inference_timesteps: int = 10,
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
):
if len(prompt_wav_path) == 0:
text = target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat(
[
text_token,
torch.tensor(
[self.audio_start_token],
dtype=torch.int32,
device=text_token.device,
),
],
dim=-1,
)
text_length = text_token.shape[0]
audio_feat = torch.zeros(
(text_length, self.patch_size, self.audio_vae.latent_dim),
dtype=torch.float32,
device=text_token.device,
)
text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)
else:
text = prompt_text + target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat(
[
text_token,
torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
],
dim=-1,
)
text_length = text_token.shape[0]
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# (B, D, T)
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
audio_feat = audio_feat.view(
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
audio_length = audio_feat.size(0)
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
text_token = torch.cat([text_token, text_pad_token])
audio_pad_feat = torch.zeros(
(text_length, self.patch_size, self.audio_vae.latent_dim),
dtype=torch.float32,
device=text_token.device,
)
audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
text_mask = (
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
)
audio_mask = (
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
)
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
audio_mask = audio_mask.unsqueeze(0).to(self.device)
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times:
latent_pred, pred_audio_feat = self.inference(
text_token,
text_mask,
audio_feat,
audio_mask,
min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
else:
break
else:
break
return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
@torch.inference_mode()
def build_prompt_cache(
self,
prompt_text: str,
prompt_wav_path: str,
):
"""
Build prompt cache for subsequent fast generation.
Args:
prompt_text: prompt text (required)
prompt_wav_path: prompt audio path (required)
Returns:
prompt_cache: dict with text tokens and audio features
"""
if not prompt_text or not prompt_wav_path:
raise ValueError("prompt_text and prompt_wav_path are required")
# build text tokens
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
# load audio
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# extract audio features
audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
audio_feat = audio_feat.view(
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0) # (D, T, P)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
# build prompt cache
prompt_cache = {
"text_token": text_token,
"audio_feat": audio_feat,
}
return prompt_cache
def merge_prompt_cache(
self,
original_cache: dict,
new_text_token: torch.Tensor,
new_audio_feat: torch.Tensor,
):
"""
Merge original prompt cache with newly generated content to stabilize voice.
Args:
original_cache: original prompt cache
new_text_token: newly generated text tokens
new_audio_feat: newly generated audio features
Returns:
merged_cache: merged cache
"""
if original_cache is None:
return {
"text_token": new_text_token,
"audio_feat": new_audio_feat,
}
original_text_token = original_cache["text_token"]
original_audio_feat = original_cache["audio_feat"]
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
# build new cache
merged_cache = {
"text_token": merged_text_token,
"audio_feat": merged_audio_feat,
}
return merged_cache
@torch.inference_mode()
def generate_with_prompt_cache(
self,
target_text: str,
prompt_cache: dict,
min_len: int = 2,
max_len: int = 2000,
inference_timesteps: int = 10,
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
):
"""
Generate audio using pre-built prompt cache.
Args:
target_text: Text to convert to speech
prompt_cache: Cache built by build_prompt_cache (can be None)
min_len: Minimum audio length to avoid very short audio
max_len: Maximum audio length
inference_timesteps: Number of diffusion sampling steps
cfg_value: Classifier-free guidance value
retry_badcase: Whether to retry on bad cases
retry_badcase_max_times: Maximum retry attempts
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
Returns:
tuple: (decoded audio tensor, new text tokens, new audio features)
"""
# get prompt from cache
if prompt_cache is None:
prompt_text_token = torch.empty(0, dtype=torch.int32)
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
else:
prompt_text_token = prompt_cache["text_token"]
prompt_audio_feat = prompt_cache["audio_feat"]
# build target text tokens
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
text_token = torch.cat(
[
text_token,
torch.tensor(
[self.audio_start_token],
dtype=torch.int32,
device=text_token.device,
),
],
dim=-1,
)
audio_length = prompt_audio_feat.size(0)
text_length = text_token.shape[0]
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
audio_pad_feat = torch.zeros(
(text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
dtype=torch.float32,
device=text_token.device,
)
text_token = torch.cat([text_token, text_pad_token])
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
audio_mask = audio_mask.unsqueeze(0).to(self.device)
# run inference
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times:
latent_pred, pred_audio_feat = self.inference(
text_token,
text_mask,
audio_feat,
audio_mask,
min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
else:
break
else:
break
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
return (
decode_audio,
target_text_token,
pred_audio_feat
)
@torch.inference_mode()
def inference(
self,
text: torch.Tensor,
text_mask: torch.Tensor,
feat: torch.Tensor,
feat_mask: torch.Tensor,
min_len: int = 2,
max_len: int = 2000,
inference_timesteps: int = 10,
cfg_value: float = 2.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Core inference method for audio generation.
This is the main inference loop that generates audio features
using the language model and diffusion transformer.
Args:
text: Input text tokens
text_mask: Mask for text tokens
feat: Input audio features
feat_mask: Mask for audio features
min_len: Minimum generation length
max_len: Maximum generation length
inference_timesteps: Number of diffusion steps
cfg_value: Classifier-free guidance value
Returns:
Tuple containing:
- Predicted latent features
- Predicted audio feature sequence
"""
B, T, P, D = feat.shape
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
feat_embed = self.enc_to_lm_proj(feat_embed)
if self.config.lm_config.use_mup:
scale_emb = self.config.lm_config.scale_emb
else:
scale_emb = 1.0
text_embed = self.base_lm.embed_tokens(text) * scale_emb
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
prefix_feat_cond = feat[:, -1, ...] # b, p, d
pred_feat_seq = [] # b, t, p, d
curr_embed = None
enc_outputs, kv_cache_tuple = self.base_lm(
inputs_embeds=combined_embed,
is_causal=True,
)
self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = enc_outputs[:, -1, :]
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
is_causal=True,
)
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
residual_hidden = residual_enc_outputs[:, -1, :]
for i in tqdm(range(max_len)):
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
dit_hidden = dit_hidden_1 + dit_hidden_2 # [b, h_dit]
pred_feat = self.feat_decoder(
mu=dit_hidden,
patch_size=self.patch_size,
cond=prefix_feat_cond.transpose(1, 2).contiguous(),
n_timesteps=inference_timesteps,
cfg_value=cfg_value,
).transpose(
1, 2
) # [b, p, d]
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.enc_to_lm_proj(curr_embed)
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
prefix_feat_cond = pred_feat
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1:
break
lm_hidden = self.base_lm.forward_step(
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
).clone()
lm_hidden = self.fsq_layer(lm_hidden)
residual_hidden = self.residual_lm.forward_step(
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
).clone()
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
feat_pred = feat_pred[..., 1:-1] # trick: remove the first and last token
return feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae = AudioVAE()
vae_state_dict = torch.load(
os.path.join(path, "audiovae.pth"),
map_location="cpu",
weights_only=True,
)["state_dict"]
model = cls(config, tokenizer, audio_vae)
lm_dtype = get_dtype(config.dtype)
model = model.to(lm_dtype)
model.audio_vae = model.audio_vae.to(torch.float32)
model_state_dict = torch.load(
os.path.join(path, "pytorch_model.bin"),
map_location="cpu",
weights_only=True,
)["state_dict"]
for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val
model.load_state_dict(model_state_dict, strict=True)
return model.to(model.device).eval().optimize()

View File

View File

@@ -0,0 +1 @@
from .audio_vae import AudioVAE

View File

@@ -0,0 +1,359 @@
import math
from typing import List, Union
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
class CausalConv1d(nn.Conv1d):
def __init__(self, *args, padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
def forward(self, x):
x_pad = F.pad(x, (self.__padding * 2, 0))
return super().forward(x_pad)
class CausalTransposeConv1d(nn.ConvTranspose1d):
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
super().__init__(*args, **kwargs)
self.__padding = padding
self.__output_padding = output_padding
def forward(self, x):
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
def WNCausalConv1d(*args, **kwargs):
return weight_norm(CausalConv1d(*args, **kwargs))
def WNCausalTransposeConv1d(*args, **kwargs):
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
shape = x.shape
x = x.reshape(shape[0], shape[1], -1)
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
x = x.reshape(shape)
return x
class Snake1d(nn.Module):
def __init__(self, channels):
super().__init__()
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
def forward(self, x):
return snake(x, self.alpha)
def init_weights(m):
if isinstance(m, nn.Conv1d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
class CausalResidualUnit(nn.Module):
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.block = nn.Sequential(
Snake1d(dim),
WNCausalConv1d(
dim,
dim,
kernel_size=kernel,
dilation=dilation,
padding=pad,
groups=groups,
),
Snake1d(dim),
WNCausalConv1d(dim, dim, kernel_size=1),
)
def forward(self, x):
y = self.block(x)
pad = (x.shape[-1] - y.shape[-1]) // 2
assert pad == 0
if pad > 0:
x = x[..., pad:-pad]
return x + y
class CausalEncoderBlock(nn.Module):
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
super().__init__()
input_dim = input_dim or output_dim // 2
self.block = nn.Sequential(
CausalResidualUnit(input_dim, dilation=1, groups=groups),
CausalResidualUnit(input_dim, dilation=3, groups=groups),
CausalResidualUnit(input_dim, dilation=9, groups=groups),
Snake1d(input_dim),
WNCausalConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
),
)
def forward(self, x):
return self.block(x)
class CausalEncoder(nn.Module):
def __init__(
self,
d_model: int = 64,
latent_dim: int = 32,
strides: list = [2, 4, 8, 8],
depthwise: bool = False,
):
super().__init__()
# Create first convolution
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
groups = d_model // 2 if depthwise else 1
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
groups = d_model if depthwise else 1
# Create two convolution, for mu and logvar
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
# Wrap black into nn.Sequential
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
def forward(self, x):
hidden_state = self.block(x)
return {
"hidden_state": hidden_state,
"mu": self.fc_mu(hidden_state),
"logvar": self.fc_logvar(hidden_state),
}
class NoiseBlock(nn.Module):
def __init__(self, dim):
super().__init__()
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
def forward(self, x):
B, C, T = x.shape
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
h = self.linear(x)
n = noise * h
x = x + n
return x
class CausalDecoderBlock(nn.Module):
def __init__(
self,
input_dim: int = 16,
output_dim: int = 8,
stride: int = 1,
groups=1,
use_noise_block: bool = False,
):
super().__init__()
layers = [
Snake1d(input_dim),
WNCausalTransposeConv1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=stride % 2,
),
]
if use_noise_block:
layers.append(NoiseBlock(output_dim))
layers.extend(
[
CausalResidualUnit(output_dim, dilation=1, groups=groups),
CausalResidualUnit(output_dim, dilation=3, groups=groups),
CausalResidualUnit(output_dim, dilation=9, groups=groups),
]
)
self.block = nn.Sequential(*layers)
def forward(self, x):
return self.block(x)
class TransposeLastTwoDim(torch.nn.Module):
def forward(self, x):
return torch.transpose(x, -1, -2)
class CausalDecoder(nn.Module):
def __init__(
self,
input_channel,
channels,
rates,
depthwise: bool = False,
d_out: int = 1,
use_noise_block: bool = False,
):
super().__init__()
# Add first conv layer
if depthwise:
layers = [
WNCausalConv1d(
input_channel,
input_channel,
kernel_size=7,
padding=3,
groups=input_channel,
),
WNCausalConv1d(input_channel, channels, kernel_size=1),
]
else:
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
# Add upsampling + MRF blocks
for i, stride in enumerate(rates):
input_dim = channels // 2**i
output_dim = channels // 2 ** (i + 1)
groups = output_dim if depthwise else 1
layers += [
CausalDecoderBlock(
input_dim,
output_dim,
stride,
groups=groups,
use_noise_block=use_noise_block,
)
]
# Add final conv layer
layers += [
Snake1d(output_dim),
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class AudioVAE(nn.Module):
"""
Args:
"""
def __init__(
self,
encoder_dim: int = 128,
encoder_rates: List[int] = [2, 5, 8, 8],
latent_dim: int = 64,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 5, 2],
depthwise: bool = True,
sample_rate: int = 16000,
use_noise_block: bool = False,
):
super().__init__()
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim
self.decoder_rates = decoder_rates
self.depthwise = depthwise
self.use_noise_block = use_noise_block
if latent_dim is None:
latent_dim = encoder_dim * (2 ** len(encoder_rates))
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = CausalEncoder(
encoder_dim,
latent_dim,
encoder_rates,
depthwise=depthwise,
)
self.decoder = CausalDecoder(
latent_dim,
decoder_dim,
decoder_rates,
depthwise=depthwise,
use_noise_block=use_noise_block,
)
self.sample_rate = sample_rate
self.chunk_size = math.prod(encoder_rates)
def preprocess(self, audio_data, sample_rate):
if sample_rate is None:
sample_rate = self.sample_rate
assert sample_rate == self.sample_rate
pad_to = self.hop_length
length = audio_data.shape[-1]
right_pad = math.ceil(length / pad_to) * pad_to - length
audio_data = nn.functional.pad(audio_data, (0, right_pad))
return audio_data
def decode(self, z: torch.Tensor):
"""Decode given latent codes and return audio data
Parameters
----------
z : Tensor[B x D x T]
Quantized continuous representation of input
length : int, optional
Number of samples in output audio, by default None
Returns
-------
dict
A dictionary with the following keys:
"audio" : Tensor[B x 1 x length]
Decoded audio data.
"""
return self.decoder(z)
def encode(self, audio_data: torch.Tensor, sample_rate: int):
"""
Args:
audio_data: Tensor[B x 1 x T]
sample_rate: int
Returns:
z: Tensor[B x D x T]
"""
if audio_data.ndim == 2:
audio_data = audio_data.unsqueeze(1)
audio_data = self.preprocess(audio_data, sample_rate)
return self.encoder(audio_data)["mu"]

View File

@@ -0,0 +1 @@
from .scalar_quantization_layer import ScalarQuantizationLayer

View File

@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
class ScalarQuantizationLayer(nn.Module):
def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.latent_dim = latent_dim
self.scale = scale
self.in_proj = nn.Linear(in_dim, latent_dim)
self.out_proj = nn.Linear(latent_dim, out_dim)
def forward(self, hidden):
hidden = self.in_proj(hidden)
hidden = torch.tanh(hidden)
if self.training:
quantized = torch.round(hidden * self.scale) / self.scale
hidden = hidden + (quantized - hidden).detach()
else:
hidden = torch.round(hidden * self.scale) / self.scale
return self.out_proj(hidden)

View File

@@ -0,0 +1,2 @@
from .unified_cfm import UnifiedCFM, CfmConfig
from .local_dit import VoxCPMLocDiT

View File

@@ -0,0 +1,114 @@
import torch
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
import torch.nn as nn
import math
class SinusoidalPosEmb(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
out_dim: int = None,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
def forward(self, sample):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class VoxCPMLocDiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
config: MiniCPM4Config,
in_channels: int = 64,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.config = config
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
self.time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
self.delta_time_mlp = TimestepEmbedding(
in_channels=config.hidden_size,
time_embed_dim=config.hidden_size,
)
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
self.decoder = MiniCPMModel(config)
def forward(
self,
x: torch.Tensor,
mu: torch.Tensor,
t: torch.Tensor,
cond: torch.Tensor,
dt: torch.Tensor,
):
"""
Forward pass of DiT.
x: (N, C, T) tensor of inputs
mu: (N, C) tensor of hidden embedding
t: (N,) tensor of diffusion timesteps
cond: (N, C, T') tensor of prefix conditions
dt: (N,) used for mean velocity (may be supported in the future...)
"""
x = self.in_proj(x.transpose(1, 2).contiguous())
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
prefix = cond.size(1)
t = self.time_embeddings(t).to(x.dtype)
t = self.time_mlp(t)
dt = self.time_embeddings(dt).to(x.dtype)
dt = self.delta_time_mlp(dt)
t = t + dt
x = torch.cat([(mu + t).unsqueeze(1), cond, x], dim=1)
hidden, _ = self.decoder(x, is_causal=False)
hidden = hidden[:, prefix + 1 :, :]
hidden = self.out_proj(hidden)
return hidden.transpose(1, 2).contiguous()

View File

@@ -0,0 +1,137 @@
import torch
from typing import List
from .local_dit import VoxCPMLocDiT
import math
from pydantic import BaseModel
class CfmConfig(BaseModel):
sigma_min: float = 1e-06
solver: str = "euler"
t_scheduler: str = "log-norm"
class UnifiedCFM(torch.nn.Module):
def __init__(
self,
in_channels,
cfm_params: CfmConfig,
estimator: VoxCPMLocDiT,
mean_mode: bool = False,
):
super().__init__()
self.solver = cfm_params.solver
self.sigma_min = cfm_params.sigma_min
self.t_scheduler = cfm_params.t_scheduler
self.in_channels = in_channels
self.mean_mode = mean_mode
# Just change the architecture of the estimator here
self.estimator = estimator
@torch.inference_mode()
def forward(
self,
mu: torch.Tensor,
n_timesteps: int,
patch_size: int,
cond: torch.Tensor,
temperature: float = 1.0,
cfg_value: float = 1.0,
sway_sampling_coef: float = 1.0,
use_cfg_zero_star: bool = True,
):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
n_timesteps (int): number of diffusion steps
cond: Not used but kept for future purposes
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
b, c = mu.shape
t = patch_size
z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
# Sway sampling strategy
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)
def optimized_scale(self, positive_flat, negative_flat):
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
st_star = dot_product / squared_norm
return st_star
def solve_euler(
self,
x: torch.Tensor,
t_span: torch.Tensor,
mu: torch.Tensor,
cond: torch.Tensor,
cfg_value: float = 1.0,
use_cfg_zero_star: bool = True,
):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
cond: Not used but kept for future purposes
cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
"""
t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
sol = []
zero_init_steps = max(1, int(len(t_span) * 0.04))
for step in range(1, len(t_span)):
if use_cfg_zero_star and step <= zero_init_steps:
dphi_dt = 0.
else:
# Classifier-Free Guidance inference introduced in VoiceBox
b = x.size(0)
x_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
x_in[:b], x_in[b:] = x, x
mu_in[:b] = mu
t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
dt_in[:b], dt_in[b:] = dt.unsqueeze(0), dt.unsqueeze(0)
# not used now
if not self.mean_mode:
dt_in = torch.zeros_like(dt_in)
cond_in[:b], cond_in[b:] = cond, cond
dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
if use_cfg_zero_star:
positive_flat = dphi_dt.view(b, -1)
negative_flat = cfg_dphi_dt.view(b, -1)
st_star = self.optimized_scale(positive_flat, negative_flat)
st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
else:
st_star = 1.0
dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)
x = x - dt * dphi_dt
t = t - dt
sol.append(x)
if step < len(t_span) - 1:
dt = t - t_span[step + 1]
return sol[-1]

View File

@@ -0,0 +1 @@
from .local_encoder import VoxCPMLocEnc

View File

@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
from einops import rearrange
class VoxCPMLocEnc(nn.Module):
def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
super().__init__()
self.config = config
self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)
assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
self.encoder = MiniCPMModel(config)
def forward(self, x):
"""
x: [B, T, P, D]
"""
B, T, P, D = x.shape
x = self.in_proj(x)
special_tokens = self.special_token.expand(B, T, 1, -1)
x = torch.cat([special_tokens, x], dim=2)
x = rearrange(x, "b t p c -> (b t) p c")
outputs, _ = self.encoder(x, is_causal=False)
cls_output = outputs[:, 0, :]
return rearrange(cls_output, "(b t) c -> b t c", b=B)

View File

@@ -0,0 +1,3 @@
from .config import MiniCPM4Config
from .model import MiniCPMModel
from .cache import StaticKVCache

View File

@@ -0,0 +1,47 @@
from typing import List, Tuple
import torch
class StaticKVCache:
def __init__(
self,
num_layers: int,
num_kv_heads: int,
dim_kv_head: int,
batch_size: int,
device: torch.device,
dtype: torch.dtype,
max_length: int = 8192,
):
self.max_length = max_length
self.num_layers = num_layers
self.kv_cache = torch.zeros(
2,
num_layers,
batch_size,
num_kv_heads,
max_length,
dim_kv_head,
device=device,
dtype=dtype,
)
self.current_length = 0
def get_layer_cache(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
return self.kv_cache[0, layer_idx], self.kv_cache[1, layer_idx]
def step(self) -> int:
if self.current_length >= self.max_length:
raise ValueError("KV cache is full")
ret = self.current_length
self.current_length += 1
return ret
def fill_caches(self, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]):
self.current_length = kv_caches[0][0].size(2)
self.kv_cache.zero_()
for i in range(self.num_layers):
self.kv_cache[0, i, :, :, : self.current_length, :] = kv_caches[i][0]
self.kv_cache[1, i, :, :, : self.current_length, :] = kv_caches[i][1]

View File

@@ -0,0 +1,29 @@
from pydantic import BaseModel
from typing import List
class RopeScalingConfig(BaseModel):
type: str
long_factor: List[float]
short_factor: List[float]
original_max_position_embeddings: int
class MiniCPM4Config(BaseModel):
bos_token_id: int
eos_token_id: int
hidden_size: int
intermediate_size: int
max_position_embeddings: int
num_attention_heads: int
num_hidden_layers: int
num_key_value_heads: int
rms_norm_eps: float
rope_scaling: RopeScalingConfig
vocab_size: int
use_mup: bool = True
scale_emb: float
dim_model_base: int
scale_depth: float
rope_theta: float
kv_channels: int = None

View File

@@ -0,0 +1,411 @@
from .config import MiniCPM4Config
import torch
import torch.nn as nn
from typing import List, Tuple
import math
from .cache import StaticKVCache
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
old_dtype = hidden.dtype
variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
return hidden * weight
class MiniCPMRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
MiniCPMRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
"""
Args:
q: Tensor(batch_size, num_heads, seq_len, head_dim)
k: Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
cos: Tensor(seq_len, head_dim)
sin: Tensor(seq_len, head_dim)
Returns:
Tensor(batch_size, num_heads, seq_len, head_dim), Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
"""
orig_dtype = q.dtype
q = q.to(torch.float32)
k = k.to(torch.float32)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
class MiniCPMLongRoPE(nn.Module):
"""MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, config: MiniCPM4Config):
super().__init__()
self.config = config
self.dim = config.kv_channels if config.kv_channels else config.hidden_size // config.num_attention_heads
self.base = config.rope_theta
self.max_position_embeddings = config.max_position_embeddings
self.short_factor = config.rope_scaling.short_factor
self.long_factor = config.rope_scaling.long_factor
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
self.scaling_factor = math.sqrt(
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
)
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.max_seq_len_cached = 0
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
self._set_cos_sin_cache(
seq_len=self.max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.float32
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
"""设置cos和sin缓存"""
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
if seq_len > self.original_max_position_embeddings:
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
else:
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
freqs = torch.mul(
torch.outer(t, 1.0 / ext_factors).to(device=device),
self.inv_freq.to(device=device).to(dtype)
)
# 创建embeddings
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
self.sin_cached = emb.sin().to(dtype) * self.scaling_factor
def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
position_ids: Tensor(seq_len) 或 Tensor(batch_size, seq_len)
Returns:
Tensor(seq_len, head_dim), Tensor(seq_len, head_dim)
"""
cos = self.cos_cached[position_ids]
sin = self.sin_cached[position_ids]
return cos, sin
class MiniCPMAttention(nn.Module):
def __init__(self, config: MiniCPM4Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = 10000.0
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
position_emb: Tuple[torch.Tensor, torch.Tensor],
is_causal: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
is_causal=is_causal,
enable_gqa=True,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
past_key_value = (key_states, value_states)
return attn_output, past_key_value
def forward_step(
self,
hidden_states: torch.Tensor,
position_emb: Tuple[torch.Tensor, torch.Tensor],
position_id: int,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
bsz, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, 1, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
key_cache, value_cache = kv_cache
key_cache[:, :, position_id, :] = key_states
value_cache[:, :, position_id, :] = value_states
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_cache,
value_cache,
attn_mask=attn_mask,
enable_gqa=True,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
return attn_output
class MiniCPMMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.SiLU()
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class MiniCPMDecoderLayer(nn.Module):
def __init__(self, config: MiniCPM4Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MiniCPMAttention(config=config, layer_idx=layer_idx)
self.mlp = MiniCPMMLP(config)
self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.scale_depth = config.scale_depth
self.num_hidden_layers = config.num_hidden_layers
self.use_mup = config.use_mup
def forward(
self,
hidden_states: torch.Tensor,
position_emb: Tuple[torch.Tensor, torch.Tensor],
is_causal: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
position_ids (`torch.LongTensor`): position ids of shape `(batch_size, seq_len)`
is_causal (`bool`): whether the attention mask is causal
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, present_key_value = self.self_attn(
hidden_states=hidden_states,
position_emb=position_emb,
is_causal=is_causal,
)
if self.use_mup:
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
else:
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.use_mup:
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
else:
hidden_states = residual + hidden_states
return hidden_states, present_key_value
def forward_step(
self,
hidden_states: torch.Tensor,
position_emb: Tuple[torch.Tensor, torch.Tensor],
position_id: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states = self.self_attn.forward_step(
hidden_states=hidden_states,
position_emb=position_emb,
position_id=position_id,
kv_cache=kv_cache,
)
if self.use_mup:
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
else:
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.use_mup:
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
else:
hidden_states = residual + hidden_states
return hidden_states
class MiniCPMModel(nn.Module):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
Args:
config: MiniCPMConfig
"""
def __init__(self, config: MiniCPM4Config):
super().__init__()
self.vocab_size = config.vocab_size
self.config = config
if config.vocab_size > 0:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
else:
self.embed_tokens = nn.Identity()
self.layers = nn.ModuleList(
[MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rope_emb = MiniCPMLongRoPE(config)
self.kv_cache = None
def forward(
self,
inputs_embeds: torch.Tensor,
is_causal: bool = True,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
"""
Args:
inputs_embeds: Tensor(batch_size, seq_length, hidden_size)
is_causal: bool, whether the attention mask is causal
Returns:
hidden_states: Tensor(batch_size, seq_length, hidden_size)
next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)]
"""
position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device)
position_emb = self.rope_emb(position_ids)
hidden_states = inputs_embeds
next_decoder_cache = []
for decoder_layer in self.layers:
hidden_states, this_cache = decoder_layer(
hidden_states,
position_emb,
is_causal,
)
next_decoder_cache.append(this_cache)
hidden_states = self.norm(hidden_states)
return hidden_states, next_decoder_cache
def forward_step(
self,
inputs_embeds: torch.Tensor,
position_id: torch.Tensor,
) -> torch.Tensor:
"""
Args:
inputs_embeds: Tensor(batch_size, hidden_size)
Returns:
hidden_states: Tensor(batch_size, hidden_size)
"""
assert self.kv_cache is not None, "KV cache is not setup"
position_emb = self.rope_emb(position_id)
hidden_states = inputs_embeds
for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer.forward_step(
hidden_states,
position_emb,
position_id,
self.kv_cache.get_layer_cache(i),
)
hidden_states = self.norm(hidden_states)
return hidden_states
def setup_cache(self, batch_size: int, max_length: int, device, dtype: torch.dtype):
self.kv_cache = StaticKVCache(
num_layers=self.config.num_hidden_layers,
num_kv_heads=self.config.num_key_value_heads,
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
batch_size=batch_size,
device=device,
dtype=dtype,
max_length=max_length,
)

View File

@@ -0,0 +1,244 @@
# some functions are copied from https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/utils/frontend_utils.py
import re
import regex
import inflect
from functools import partial
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
def normal_cut_sentence(text):
# 先处理括号内的逗号,将其替换为特殊标记
text = re.sub(r'([(][^)]*)([,])([^)]*[)])', r'\1&&&\3', text)
text = re.sub('([。!,?\?])([^’”])',r'\1\n\2',text)#普通断句符号且后面没有引号
text = re.sub('(\.{6})([^’”])',r'\1\n\2',text)#英文省略号且后面没有引号
text = re.sub('(\{2})([^’”])',r'\1\n\2',text)#中文省略号且后面没有引号
text = re.sub('([. ,。!;?\?\.{6}\{2}][’”])([^’”])',r'\1\n\2',text)#断句号+引号且后面没有引号
# 处理英文句子的分隔
text = re.sub(r'([.,!?])([^’”\'"])', r'\1\n\2', text) # 句号、感叹号、问号后面没有引号
text = re.sub(r'([.!?][’”\'"])([^’”\'"])', r'\1\n\2', text) # 句号、感叹号、问号加引号后面的部分
text = re.sub(r'([(][^)]*)(&&&)([^)]*[)])', r'\1\3', text)
text = [t for t in text.split("\n") if t]
return text
def cut_sentence_with_fix_length(text : str, length : int):
sentences = normal_cut_sentence(text)
cur_length = 0
res = ""
for sentence in sentences:
if not sentence:
continue
if cur_length > length or cur_length + len(sentence) > length:
yield res
res = ""
cur_length = 0
res += sentence
cur_length += len(sentence)
if res:
yield res
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
# whether contain chinese character
def contains_chinese(text):
return bool(chinese_char_pattern.search(text))
# replace special symbol
def replace_corner_mark(text):
text = text.replace('²', '平方')
text = text.replace('³', '立方')
text = text.replace('', '根号')
text = text.replace('', '约等于')
text = text.replace('<', '小于')
return text
# remove meaningless symbol
def remove_bracket(text):
text = text.replace('', ' ').replace('', ' ')
text = text.replace('', ' ').replace('', ' ')
text = text.replace('`', '').replace('`', '')
text = text.replace("——", " ")
return text
# spell Arabic numerals
def spell_out_number(text: str, inflect_parser):
new_text = []
st = None
for i, c in enumerate(text):
if not c.isdigit():
if st is not None:
num_str = inflect_parser.number_to_words(text[st: i])
new_text.append(num_str)
st = None
new_text.append(c)
else:
if st is None:
st = i
if st is not None and st < len(text):
num_str = inflect_parser.number_to_words(text[st:])
new_text.append(num_str)
return ''.join(new_text)
# split paragrah logic
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
# 2. cal sentence len according to lang
# 3. split sentence according to puncatation
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
def calc_utt_length(_text: str):
if lang == "zh":
return len(_text)
else:
return len(tokenize(_text))
def should_merge(_text: str):
if lang == "zh":
return len(_text) < merge_len
else:
return len(tokenize(_text)) < merge_len
if lang == "zh":
pounc = ['', '', '', '', '', '', '.', '?', '!', ';']
else:
pounc = ['.', '?', '!', ';', ':']
if comma_split:
pounc.extend(['', ','])
st = 0
utts = []
for i, c in enumerate(text):
if c in pounc:
if len(text[st: i]) > 0:
utts.append(text[st: i] + c)
if i + 1 < len(text) and text[i + 1] in ['"', '']:
tmp = utts.pop(-1)
utts.append(tmp + text[i + 1])
st = i + 2
else:
st = i + 1
if len(utts) == 0:
if lang == "zh":
utts.append(text + '')
else:
utts.append(text + '.')
final_utts = []
cur_utt = ""
for utt in utts:
if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
final_utts.append(cur_utt)
cur_utt = ""
cur_utt = cur_utt + utt
if len(cur_utt) > 0:
if should_merge(cur_utt) and len(final_utts) != 0:
final_utts[-1] = final_utts[-1] + cur_utt
else:
final_utts.append(cur_utt)
return final_utts
# remove blank between chinese character
def replace_blank(text: str):
out_str = []
for i, c in enumerate(text):
if c == " ":
if ((text[i + 1].isascii() and text[i + 1] != " ") and
(text[i - 1].isascii() and text[i - 1] != " ")):
out_str.append(c)
else:
out_str.append(c)
return "".join(out_str)
def clean_markdown(md_text: str) -> str:
# 去除代码块 ``` ```(包括多行)
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
# 去除内联代码 `code`
md_text = re.sub(r"`[^`]*`", "", md_text)
# 去除图片语法 ![alt](url)
md_text = re.sub(r"!\[[^\]]*\]\([^\)]+\)", "", md_text)
# 去除链接但保留文本 [text](url) -> text
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
# 替换无序列表符号
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
# 去除HTML标签
md_text = re.sub(r"<[^>]+>", "", md_text)
# 去除标题符号(#
md_text = re.sub(r"^#{1,6}\s*", "", md_text, flags=re.MULTILINE)
# 去除多余空格和空行
md_text = re.sub(r"\n\s*\n", "\n", md_text) # 多余空行
md_text = md_text.strip()
return md_text
def clean_text(text):
# 去除 Markdown 语法
text = clean_markdown(text)
# 匹配并移除表情符号
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
# 去除换行符
text = text.replace("\n", " ")
text = text.replace("\t", " ")
text = text.replace('"', "\")
return text
class TextNormalizer:
def __init__(self, tokenizer=None):
self.tokenizer = tokenizer
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, remove_interjections=False, overwrite_cache=True)
self.en_tn_model = EnNormalizer()
self.inflect_parser = inflect.engine()
def normalize(self, text, split=False):
# 去除 Markdown 语法,去除表情符号,去除换行符
lang = "zh" if contains_chinese(text) else "en"
text = clean_text(text)
if lang == "zh":
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
text = self.zh_tn_model.normalize(text)
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
text = self.zh_tn_model.normalize(text)
text = replace_blank(text)
text = replace_corner_mark(text)
text = remove_bracket(text)
text = re.sub(r'[,]+$', '', text)
else:
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
if split is False:
return text
if __name__ == "__main__":
text_normalizer = TextNormalizer()
text = r"""今天我们学习一元二次方程。一元二次方程的标准形式是:
ax2+bx+c=0ax^2 + bx + c = 0ax2+bx+c=0
其中aaa、bbb 和 ccc 是常数xxx 是变量。这个方程的解可以通过求根公式来找到。
一元二次方程的解法有几种:
- 因式分解法通过将方程因式分解来求解。我们首先尝试将方程表达成两个括号的形式解决方程的解。比如方程x25x+6=0x^2 - 5x + 6 = 0x25x+6=0可以因式分解为(x2)(x3)=0(x - 2)(x - 3) = 0(x2)(x3)=0因此根为2和3。
- 配方法:通过配方将方程转化为完全平方的形式,从而解出。我们通过加上或减去适当的常数来完成这一过程,使得方程可以直接写成一个完全平方的形式。
- 求根公式:我们可以使用求根公式直接求出方程的解。这个公式适用于所有的一元二次方程,即使我们无法通过因式分解或配方法来解决时,也能使用该公式。
公式x=b±b24ac2ax = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}x=2ab±b24ac这个公式可以帮助我们求解任何一元二次方程的根。
对于一元二次方程,我们需要了解判别式。判别式的作用是帮助我们判断方程的解的个数和性质。判别式 Δ\DeltaΔ 由下式给出:Δ=b24ac\Delta = b^2 - 4acΔ=b24ac 根据判别式的值,我们可以知道:
- 如果 Δ>0\Delta > 0Δ>0方程有两个不相等的实数解。这是因为判别式大于0时根号内的值是正数所以我们可以得到两个不同的解。
- 如果 Δ=0\Delta = 0Δ=0方程有一个实数解。这是因为根号内的值为零导致两个解相等也就是说方程有一个解。
- 如果 Δ<0\Delta < 0Δ<0方程没有实数解。这意味着根号内的值是负数无法进行实数运算因此方程没有实数解可能有复数解。"""
texts = ["这是一个公式 (a+b)³=a³+3a²b+3ab²+b³ S=(a×b)÷2", "这样的发展为AI仅仅作为“工具”这一观点提出了新的挑战", "550 + 320 = 870千卡。", "解一元二次方程3x^2+x-2=0", "你好啊"]
texts = [text]
for text in texts:
text = text_normalizer.normalize(text)
print(text)
for t in cut_sentence_with_fix_length(text, 15):
print(t)