Mirror of https://github.com/OpenBMB/VoxCPM, synced 2025-12-12 11:58:11 +00:00

Commit: init
5 src/voxcpm/__init__.py (new file)
@@ -0,0 +1,5 @@
from .core import VoxCPM

__all__ = [
    "VoxCPM",
]
292 src/voxcpm/cli.py (new file)
@@ -0,0 +1,292 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
VoxCPM Command Line Interface
|
||||
|
||||
Unified CLI for voice cloning, direct TTS synthesis, and batch processing.
|
||||
|
||||
Usage examples:
|
||||
# Direct synthesis (single sample)
|
||||
voxcpm --text "Hello world" --output output.wav
|
||||
|
||||
# Voice cloning (with reference audio and text)
|
||||
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output output.wav --denoise
|
||||
|
||||
# Batch processing (each line in the file is one sample)
|
||||
voxcpm --input texts.txt --output-dir ./outputs/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
import soundfile as sf
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
|
||||
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
||||
"""Validate that a file exists."""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
|
||||
return path
|
||||
|
||||
|
||||
def validate_output_path(output_path: str) -> Path:
|
||||
"""Validate the output path and create parent directories if needed."""
|
||||
path = Path(output_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def load_model(args) -> VoxCPM:
|
||||
"""Load VoxCPM model.
|
||||
|
||||
Prefer --model-path if provided; otherwise use from_pretrained (Hub).
|
||||
"""
|
||||
print("Loading VoxCPM model...")
|
||||
|
||||
# Backward compatibility: fall back to the ZIPENHANCER_MODEL_PATH environment variable as the default
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||
"ZIPENHANCER_MODEL_PATH", None
|
||||
)
|
||||
|
||||
# Load from local path if provided
|
||||
if getattr(args, "model_path", None):
|
||||
try:
|
||||
model = VoxCPM(
|
||||
voxcpm_model_path=args.model_path,
|
||||
zipenhancer_model_path=zipenhancer_path,
|
||||
enable_denoiser=not getattr(args, "no_denoiser", False),
|
||||
)
|
||||
print("Model loaded (local).")
|
||||
return model
|
||||
except Exception as e:
|
||||
print(f"Failed to load model (local): {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Otherwise, try from_pretrained (Hub); exit on failure
|
||||
try:
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM-0.5B"),
|
||||
load_denoiser=not getattr(args, "no_denoiser", False),
|
||||
zipenhancer_model_id=zipenhancer_path,
|
||||
cache_dir=getattr(args, "cache_dir", None),
|
||||
local_files_only=getattr(args, "local_files_only", False),
|
||||
)
|
||||
print("Model loaded (from_pretrained).")
|
||||
return model
|
||||
except Exception as e:
|
||||
print(f"Failed to load model (from_pretrained): {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_clone(args):
|
||||
"""Voice cloning command."""
|
||||
# Validate inputs
|
||||
if not args.text:
|
||||
print("Error: Please provide text to synthesize (--text)")
|
||||
sys.exit(1)
|
||||
|
||||
if not args.prompt_audio:
|
||||
print("Error: Voice cloning requires a reference audio (--prompt-audio)")
|
||||
sys.exit(1)
|
||||
|
||||
if not args.prompt_text:
|
||||
print("Error: Voice cloning requires a reference text (--prompt-text)")
|
||||
sys.exit(1)
|
||||
|
||||
# Validate files
|
||||
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
||||
output_path = validate_output_path(args.output)
|
||||
|
||||
# Load model
|
||||
model = load_model(args)
|
||||
|
||||
# Generate audio
|
||||
print(f"Synthesizing text: {args.text}")
|
||||
print(f"Reference audio: {prompt_audio_path}")
|
||||
print(f"Reference text: {args.prompt_text}")
|
||||
|
||||
audio_array = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=str(prompt_audio_path),
|
||||
prompt_text=args.prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
denoise=args.denoise
|
||||
)
|
||||
|
||||
# Save audio
|
||||
sf.write(str(output_path), audio_array, 16000)
|
||||
print(f"Saved audio to: {output_path}")
|
||||
|
||||
# Stats
|
||||
duration = len(audio_array) / 16000
|
||||
print(f"Duration: {duration:.2f}s")
|
||||
|
||||
|
||||
def cmd_synthesize(args):
|
||||
"""Direct TTS synthesis command."""
|
||||
# Validate inputs
|
||||
if not args.text:
|
||||
print("Error: Please provide text to synthesize (--text)")
|
||||
sys.exit(1)
|
||||
# Validate output path
|
||||
output_path = validate_output_path(args.output)
|
||||
# Load model
|
||||
model = load_model(args)
|
||||
# Generate audio
|
||||
print(f"Synthesizing text: {args.text}")
|
||||
|
||||
audio_array = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=None,
|
||||
prompt_text=None,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
denoise=False  # no reference audio, so prompt denoising is not needed
|
||||
)
|
||||
|
||||
# Save audio
|
||||
sf.write(str(output_path), audio_array, 16000)
|
||||
print(f"Saved audio to: {output_path}")
|
||||
|
||||
# Stats
|
||||
duration = len(audio_array) / 16000
|
||||
print(f"Duration: {duration:.2f}s")
|
||||
|
||||
|
||||
def cmd_batch(args):
|
||||
"""Batch synthesis command."""
|
||||
# Validate input file
|
||||
input_file = validate_file_exists(args.input, "input file")
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
texts = [line.strip() for line in f if line.strip()]
|
||||
except Exception as e:
|
||||
print(f"Failed to read input file: {e}")
|
||||
sys.exit(1)
|
||||
if not texts:
|
||||
print("Error: Input file is empty or contains no valid lines")
|
||||
sys.exit(1)
|
||||
print(f"Found {len(texts)} lines to process")
|
||||
|
||||
model = load_model(args)
|
||||
prompt_audio_path = None
|
||||
if args.prompt_audio:
|
||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
||||
|
||||
success_count = 0
|
||||
for i, text in enumerate(texts, 1):
|
||||
print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...")
|
||||
|
||||
try:
|
||||
audio_array = model.generate(
|
||||
text=text,
|
||||
prompt_wav_path=prompt_audio_path,
|
||||
prompt_text=args.prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
denoise=args.denoise and prompt_audio_path is not None
|
||||
)
|
||||
output_file = output_dir / f"output_{i:03d}.wav"
|
||||
sf.write(str(output_file), audio_array, 16000)
|
||||
|
||||
duration = len(audio_array) / 16000
|
||||
print(f" Saved: {output_file} ({duration:.2f}s)")
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" Failed: {e}")
|
||||
continue
|
||||
|
||||
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded")
|
||||
|
||||
def _build_unified_parser():
|
||||
"""Build unified argument parser (no subcommands, route by args)."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="VoxCPM CLI (single parser) - voice cloning, direct TTS, and batch processing",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Direct synthesis (single sample)
|
||||
voxcpm --text "Hello world" --output out.wav
|
||||
|
||||
# Voice cloning (reference audio + text)
|
||||
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output out.wav --denoise
|
||||
|
||||
# Batch processing
|
||||
voxcpm --input texts.txt --output-dir ./outs
|
||||
|
||||
# Select model (from Hub)
|
||||
voxcpm --text "Hello" --output out.wav --hf-model-id openbmb/VoxCPM-0.5B
|
||||
"""
|
||||
)
|
||||
|
||||
# Task selection (automatic routing by presence of args)
|
||||
parser.add_argument("--input", "-i", help="Input text file (one line per sample)")
|
||||
parser.add_argument("--output-dir", "-od", help="Output directory (for batch mode)")
|
||||
parser.add_argument("--text", "-t", help="Text to synthesize (single-sample mode)")
|
||||
parser.add_argument("--output", "-o", help="Output audio file path (single-sample mode)")
|
||||
|
||||
# Prompt audio (for voice cloning)
|
||||
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
|
||||
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
|
||||
|
||||
# Generation parameters
|
||||
parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG guidance scale (default: 2.0)")
|
||||
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (default: 10)")
|
||||
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
||||
|
||||
# Model loading parameters
|
||||
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
|
||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM-0.5B", help="Hugging Face repo id (e.g., openbmb/VoxCPM-0.5B)")
|
||||
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
||||
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
|
||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
"""Unified CLI entrypoint: route by provided arguments."""
|
||||
parser = _build_unified_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Routing: prefer batch → single (clone/direct)
|
||||
if args.input:
|
||||
if not args.output_dir:
|
||||
print("Error: Batch mode requires --output-dir")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
return cmd_batch(args)
|
||||
|
||||
# Single-sample mode
|
||||
if not args.text or not args.output:
|
||||
print("Error: Single-sample mode requires --text and --output")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
# If prompt audio+text provided → voice cloning
|
||||
if args.prompt_audio or args.prompt_text:
|
||||
if not args.prompt_audio or not args.prompt_text:
|
||||
print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
||||
sys.exit(1)
|
||||
return cmd_clone(args)
|
||||
|
||||
# Otherwise → direct synthesis
|
||||
return cmd_synthesize(args)
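# Routing summary: --input selects batch mode (and requires --output-dir); otherwise
# --text/--output select single-sample mode, which becomes voice cloning when both
# --prompt-audio and --prompt-text are given, and plain synthesis otherwise.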
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
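# Illustrative sketch (not part of the CLI): the same pipeline can be driven from
# Python directly, assuming the package and its model weights are available:
#
#     from voxcpm import VoxCPM
#     import soundfile as sf
#
#     model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
#     wav = model.generate(text="Hello world", denoise=False)
#     sf.write("out.wav", wav, 16000)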
|
||||
181 src/voxcpm/core.py (new file)
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
import os
|
||||
import tempfile
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel
|
||||
from .utils.text_normalize import TextNormalizer
|
||||
|
||||
|
||||
class VoxCPM:
|
||||
def __init__(self,
|
||||
voxcpm_model_path : str,
|
||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser : bool = True,
|
||||
):
|
||||
"""Initialize VoxCPM TTS pipeline.
|
||||
|
||||
Args:
|
||||
voxcpm_model_path: Local filesystem path to the VoxCPM model assets
|
||||
(weights, configs, etc.). Typically the directory returned by
|
||||
a prior download step.
|
||||
zipenhancer_model_path: ModelScope acoustic noise suppression model
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
"""
|
||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
|
||||
self.text_normalizer = TextNormalizer()
|
||||
if enable_denoiser and zipenhancer_model_path is not None:
|
||||
self.denoiser = pipeline(
|
||||
Tasks.acoustic_noise_suppression,
|
||||
model=zipenhancer_model_path)
|
||||
else:
|
||||
self.denoiser = None
|
||||
print("Warm up VoxCPMModel...")
|
||||
self.tts_model.generate(
|
||||
target_text="Hello, this is the first test sentence."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
hf_model_id: str = "openbmb/VoxCPM-0.5B",
|
||||
load_denoiser: bool = True,
|
||||
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
):
|
||||
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
||||
|
||||
Args:
|
||||
hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo").
|
||||
load_denoiser: Whether to initialize the denoiser pipeline.
|
||||
zipenhancer_model_id: Denoiser model id or path for ModelScope
|
||||
acoustic noise suppression.
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
|
||||
Returns:
|
||||
VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
|
||||
the downloaded snapshot directory.
|
||||
|
||||
Raises:
|
||||
ValueError: If a valid ``hf_model_id`` is not provided.
|
||||
"""
|
||||
repo_id = hf_model_id
|
||||
if not repo_id or repo_id.strip() == "":
|
||||
raise ValueError("You must provide a valid hf_model_id")
|
||||
|
||||
local_path = snapshot_download(
|
||||
repo_id=repo_id,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
return cls(
|
||||
voxcpm_model_path=local_path,
|
||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||
enable_denoiser=load_denoiser,
|
||||
)
|
||||
|
||||
def _normalize_loudness(self, wav_path: str):
|
||||
audio, sr = torchaudio.load(wav_path)
|
||||
loudness = torchaudio.functional.loudness(audio, sr)
|
||||
normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
|
||||
torchaudio.save(wav_path, normalized_audio, sr)
|
||||
|
||||
def generate(self,
|
||||
text : str,
|
||||
prompt_wav_path : str = None,
|
||||
prompt_text : str = None,
|
||||
cfg_value : float = 2.0,
|
||||
inference_timesteps : int = 10,
|
||||
max_length : int = 4096,
|
||||
normalize : bool = True,
|
||||
denoise : bool = True,
|
||||
retry_badcase : bool = True,
|
||||
retry_badcase_max_times : int = 3,
|
||||
retry_badcase_ratio_threshold : float = 6.0,
|
||||
):
|
||||
"""Synthesize speech for the given text and return a single waveform.
|
||||
|
||||
This method optionally builds and reuses a prompt cache. If an external
|
||||
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
|
||||
used for all sub-sentences. Otherwise, the prompt cache is built from
|
||||
the first generated result and reused for the remaining text chunks.
|
||||
|
||||
Args:
|
||||
text: Input text. Can include newlines; each non-empty line is
|
||||
treated as a sub-sentence.
|
||||
prompt_wav_path: Path to a reference audio file for prompting.
|
||||
prompt_text: Text content corresponding to the prompt audio.
|
||||
cfg_value: Guidance scale for the generation model.
|
||||
inference_timesteps: Number of inference steps.
|
||||
max_length: Maximum token length during generation.
|
||||
normalize: Whether to run text normalization before generation.
|
||||
denoise: Whether to denoise the prompt audio if a denoiser is
|
||||
available.
|
||||
retry_badcase: Whether to retry when a bad case (overly long audio relative to the text) is detected.
retry_badcase_max_times: Maximum number of retry attempts for bad cases.
retry_badcase_ratio_threshold: Maximum acceptable ratio of audio length to text length.
|
||||
Returns:
|
||||
numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||
"""
|
||||
texts = text.split("\n")
|
||||
texts = [t.strip() for t in texts if t.strip()]
|
||||
final_wav = []
|
||||
temp_prompt_wav_path = None
|
||||
|
||||
try:
|
||||
if prompt_wav_path is not None and prompt_text is not None:
|
||||
if denoise and self.denoiser is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
temp_prompt_wav_path = tmp_file.name
|
||||
|
||||
self.denoiser(prompt_wav_path, output_path=temp_prompt_wav_path)
|
||||
self._normalize_loudness(temp_prompt_wav_path)
|
||||
prompt_wav_path = temp_prompt_wav_path
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text
|
||||
)
|
||||
else:
|
||||
fixed_prompt_cache = None # will be built from the first inference
|
||||
|
||||
for sub_text in texts:
|
||||
if sub_text.strip() == "":
|
||||
continue
|
||||
print("sub_text:", sub_text)
|
||||
if normalize:
|
||||
sub_text = self.text_normalizer.normalize(sub_text)
|
||||
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
|
||||
target_text=sub_text,
|
||||
prompt_cache=fixed_prompt_cache,
|
||||
min_len=2,
|
||||
max_len=max_length,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
)
|
||||
if fixed_prompt_cache is None:
|
||||
fixed_prompt_cache = self.tts_model.merge_prompt_cache(
|
||||
original_cache=None,
|
||||
new_text_token=target_text_token,
|
||||
new_audio_feat=generated_audio_feat
|
||||
)
|
||||
final_wav.append(wav)
|
||||
|
||||
return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
|
||||
|
||||
finally:
|
||||
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
|
||||
try:
|
||||
os.unlink(temp_prompt_wav_path)
|
||||
except OSError:
|
||||
pass
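# Illustrative voice-cloning sketch (assumes a local "voice.wav" and its transcript;
# the file names here are placeholders):
#
#     model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
#     wav = model.generate(
#         text="Text to speak in the reference voice.",
#         prompt_wav_path="voice.wav",
#         prompt_text="Transcript of voice.wav.",
#         denoise=True,  # enhance the prompt audio if the denoiser was loaded
#     )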
|
||||
3 src/voxcpm/model/__init__.py (new file)
@@ -0,0 +1,3 @@
from .voxcpm import VoxCPMModel

__all__ = ["VoxCPMModel"]
122 src/voxcpm/model/utils.py (new file)
@@ -0,0 +1,122 @@
|
||||
from typing import List
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
|
||||
|
||||
This function creates a wrapper around the provided tokenizer that automatically
|
||||
splits multi-character Chinese tokens into individual characters. This is useful
|
||||
for ensuring consistent tokenization of Chinese text.
|
||||
|
||||
Args:
|
||||
tokenizer: The base tokenizer to wrap
|
||||
|
||||
Returns:
|
||||
A CharTokenizerWrapper instance that handles multi-character Chinese tokens
|
||||
|
||||
Example:
|
||||
>>> from transformers import LlamaTokenizerFast
|
||||
>>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
|
||||
>>> wrapped_tokenizer = mask_multichar_chinese_tokens(tokenizer)
|
||||
>>> tokens = wrapped_tokenizer("你好世界")
|
||||
"""
|
||||
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
|
||||
multichar_tokens = {
|
||||
token for token in tokenizer.vocab.keys()
|
||||
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
||||
}
|
||||
|
||||
class CharTokenizerWrapper:
|
||||
"""Wrapper class for tokenizers that handles multi-character Chinese tokens.
|
||||
|
||||
This wrapper automatically splits multi-character Chinese tokens into
|
||||
individual characters while preserving the original tokenizer's interface.
|
||||
"""
|
||||
|
||||
def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
|
||||
"""Initialize the wrapper with a base tokenizer.
|
||||
|
||||
Args:
|
||||
base_tokenizer: The tokenizer to wrap
|
||||
"""
|
||||
self.tokenizer = base_tokenizer
|
||||
self.multichar_tokens = multichar_tokens
|
||||
|
||||
def tokenize(self, text: str, **kwargs) -> List[str]:
|
||||
"""Tokenize text and split multi-character Chinese tokens into single characters.
|
||||
|
||||
Args:
|
||||
text: Input text to tokenize
|
||||
**kwargs: Additional arguments passed to the base tokenizer
|
||||
|
||||
Returns:
|
||||
List of processed tokens with multi-character Chinese tokens split
|
||||
|
||||
Example:
|
||||
>>> wrapper = CharTokenizerWrapper(tokenizer)
|
||||
>>> tokens = wrapper.tokenize("你好世界")
|
||||
>>> # Returns ["你", "好", "世", "界"] instead of ["你好", "世界"]
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
raise TypeError(f"Expected string input, got {type(text)}")
|
||||
|
||||
tokens = self.tokenizer.tokenize(text, **kwargs)
|
||||
processed = []
|
||||
|
||||
for token in tokens:
|
||||
# Remove possible subword prefix
|
||||
clean_token = token.replace("▁", "")
|
||||
|
||||
if clean_token in self.multichar_tokens:
|
||||
# Split multi-character token into single characters
|
||||
chars = list(clean_token)
|
||||
processed.extend(chars)
|
||||
else:
|
||||
processed.append(token)
|
||||
|
||||
return processed
|
||||
|
||||
def __call__(self, text: str, **kwargs) -> List[int]:
|
||||
"""Call the tokenizer and return token IDs.
|
||||
|
||||
This method provides the same interface as the original tokenizer
|
||||
but with multi-character Chinese token handling.
|
||||
|
||||
Args:
|
||||
text: Input text to tokenize
|
||||
**kwargs: Additional arguments passed to the base tokenizer
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not a string
|
||||
ValueError: If tokenization fails
|
||||
"""
|
||||
try:
|
||||
tokens = self.tokenize(text, **kwargs)
|
||||
result = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise ValueError(f"Tokenization failed: {str(e)}") from e
|
||||
|
||||
return CharTokenizerWrapper(tokenizer)
|
||||
|
||||
|
||||
def get_dtype(dtype: str):
|
||||
if dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
elif dtype == "bf16":
|
||||
return torch.bfloat16
|
||||
elif dtype == "float16":
|
||||
return torch.float16
|
||||
elif dtype == "fp16":
|
||||
return torch.float16
|
||||
elif dtype == "float32":
|
||||
return torch.float32
|
||||
elif dtype == "fp32":
|
||||
return torch.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
605 src/voxcpm/model/voxcpm.py (new file)
@@ -0,0 +1,605 @@
|
||||
"""
|
||||
VoxCPM: A Tokenizer-free speech generation model
|
||||
|
||||
This module contains the main VoxCPM model implementation, including configuration classes
|
||||
and the core VoxCPMModel for text-to-speech generation.
|
||||
|
||||
Copyright 2025 OpenBMB
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from einops import rearrange
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from ..modules.layers import ScalarQuantizationLayer
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||
|
||||
|
||||
class VoxCPMEncoderConfig(BaseModel):
|
||||
hidden_dim: int = 1024
|
||||
ffn_dim: int = 4096
|
||||
num_heads: int = 16
|
||||
num_layers: int = 4
|
||||
kv_channels: int = None
|
||||
|
||||
|
||||
class VoxCPMDitConfig(BaseModel):
|
||||
hidden_dim: int = 1024
|
||||
ffn_dim: int = 4096
|
||||
num_heads: int = 16
|
||||
num_layers: int = 4
|
||||
kv_channels: int = None
|
||||
|
||||
cfm_config: CfmConfig
|
||||
|
||||
|
||||
class VoxCPMConfig(BaseModel):
|
||||
lm_config: MiniCPM4Config
|
||||
patch_size: int = 2
|
||||
feat_dim: int = 64
|
||||
residual_lm_num_layers: int = 6
|
||||
scalar_quantization_latent_dim: int = 256
|
||||
scalar_quantization_scale: int = 9
|
||||
|
||||
encoder_config: VoxCPMEncoderConfig
|
||||
dit_config: VoxCPMDitConfig
|
||||
|
||||
max_length: int = 4096
|
||||
device: str = "cuda"
|
||||
dtype: str = "bfloat16"
|
||||
|
||||
|
||||
class VoxCPMModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: VoxCPMConfig,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAE,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
if not torch.cuda.is_available():
|
||||
self.device = "cpu"
|
||||
|
||||
# Text-Semantic LM
|
||||
self.base_lm = MiniCPMModel(config.lm_config)
|
||||
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
|
||||
|
||||
self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
|
||||
self.audio_start_token = 101
|
||||
self.audio_end_token = 102
|
||||
|
||||
# Residual Acoustic LM
|
||||
residual_lm_config = config.lm_config.model_copy(deep=True)
|
||||
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
|
||||
residual_lm_config.vocab_size = 0
|
||||
self.residual_lm = MiniCPMModel(residual_lm_config)
|
||||
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
|
||||
|
||||
# Local Encoder
|
||||
encoder_config = config.lm_config.model_copy(deep=True)
|
||||
encoder_config.hidden_size = config.encoder_config.hidden_dim
|
||||
encoder_config.intermediate_size = config.encoder_config.ffn_dim
|
||||
encoder_config.num_attention_heads = config.encoder_config.num_heads
|
||||
encoder_config.num_hidden_layers = config.encoder_config.num_layers
|
||||
encoder_config.kv_channels = config.encoder_config.kv_channels
|
||||
encoder_config.vocab_size = 0
|
||||
self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)
|
||||
|
||||
# Local DiT
|
||||
decoder_config = config.lm_config.model_copy(deep=True)
|
||||
decoder_config.hidden_size = config.dit_config.hidden_dim
|
||||
decoder_config.intermediate_size = config.dit_config.ffn_dim
|
||||
decoder_config.num_attention_heads = config.dit_config.num_heads
|
||||
decoder_config.num_hidden_layers = config.dit_config.num_layers
|
||||
decoder_config.kv_channels = config.dit_config.kv_channels
|
||||
decoder_config.vocab_size = 0
|
||||
self.feat_decoder = UnifiedCFM(
|
||||
in_channels=config.feat_dim,
|
||||
cfm_params=config.dit_config.cfm_config,
|
||||
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
|
||||
)
|
||||
|
||||
# Projection layers
|
||||
self.fsq_layer = ScalarQuantizationLayer(
|
||||
config.lm_config.hidden_size,
|
||||
config.lm_config.hidden_size,
|
||||
config.scalar_quantization_latent_dim,
|
||||
config.scalar_quantization_scale
|
||||
)
|
||||
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
|
||||
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||
self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||
|
||||
# Stop Predictor
|
||||
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
|
||||
self.stop_actn = nn.SiLU()
|
||||
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
|
||||
|
||||
# Audio VAE
|
||||
self.audio_vae = audio_vae
|
||||
self.chunk_size = audio_vae.chunk_size
|
||||
self.sample_rate = audio_vae.sample_rate
|
||||
|
||||
|
||||
def optimize(self):
|
||||
if self.device == "cuda":
|
||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||
else:
|
||||
self.base_lm.forward_step = self.base_lm.forward_step
|
||||
self.residual_lm.forward_step = self.residual_lm.forward_step
|
||||
self.feat_encoder_step = self.feat_encoder
|
||||
self.feat_decoder.estimator = self.feat_decoder.estimator
|
||||
return self
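# Note: torch.compile(mode="reduce-overhead") uses CUDA graphs, so the first few
# calls are slow while graphs are captured and code is compiled; this is presumably
# why VoxCPM.__init__ (core.py) runs a warm-up generate() right after loading.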
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
target_text: str,
|
||||
prompt_text: str = "",
|
||||
prompt_wav_path: str = "",
|
||||
min_len: int = 2,
|
||||
max_len: int = 2000,
|
||||
inference_timesteps: int = 10,
|
||||
cfg_value: float = 2.0,
|
||||
retry_badcase: bool = False,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,  # maximum acceptable ratio of audio length to text length (for bad-case detection)
|
||||
):
|
||||
if len(prompt_wav_path) == 0:
|
||||
text = target_text
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
text_token = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
torch.tensor(
|
||||
[self.audio_start_token],
|
||||
dtype=torch.int32,
|
||||
device=text_token.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
text_length = text_token.shape[0]
|
||||
|
||||
audio_feat = torch.zeros(
|
||||
(text_length, self.patch_size, self.audio_vae.latent_dim),
|
||||
dtype=torch.float32,
|
||||
device=text_token.device,
|
||||
)
|
||||
text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
|
||||
audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)
|
||||
|
||||
else:
|
||||
text = prompt_text + target_text
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
text_token = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
text_length = text_token.shape[0]
|
||||
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
|
||||
patch_len = self.patch_size * self.chunk_size
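# With the default configuration (patch_size=2, chunk_size=prod([2, 5, 8, 8])=640),
# patch_len is 1280 samples, i.e. one LM step corresponds to 80 ms of 16 kHz audio.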
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
|
||||
# (B, D, T)
|
||||
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
|
||||
|
||||
audio_feat = audio_feat.view(
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
audio_length = audio_feat.size(0)
|
||||
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
audio_pad_feat = torch.zeros(
|
||||
(text_length, self.patch_size, self.audio_vae.latent_dim),
|
||||
dtype=torch.float32,
|
||||
device=text_token.device,
|
||||
)
|
||||
audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
|
||||
text_mask = (
|
||||
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
audio_mask = (
|
||||
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
|
||||
text_token = text_token.unsqueeze(0).to(self.device)
|
||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
|
||||
retry_badcase_times = 0
|
||||
while retry_badcase_times < retry_badcase_max_times:
|
||||
latent_pred, pred_audio_feat = self.inference(
|
||||
text_token,
|
||||
text_mask,
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
|
||||
@torch.inference_mode()
|
||||
def build_prompt_cache(
|
||||
self,
|
||||
prompt_text: str,
|
||||
prompt_wav_path: str,
|
||||
):
|
||||
"""
|
||||
Build prompt cache for subsequent fast generation.
|
||||
|
||||
Args:
|
||||
prompt_text: prompt text (required)
|
||||
prompt_wav_path: prompt audio path (required)
|
||||
|
||||
Returns:
|
||||
prompt_cache: dict with text tokens and audio features
|
||||
"""
|
||||
if not prompt_text or not prompt_wav_path:
|
||||
raise ValueError("prompt_text and prompt_wav_path are required")
|
||||
|
||||
# build text tokens
|
||||
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
|
||||
|
||||
# load audio
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
|
||||
# extract audio features
|
||||
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()  # use self.device rather than .cuda() so prompt caching also works on CPU
|
||||
|
||||
audio_feat = audio_feat.view(
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0)  # (T, P, D)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
# build prompt cache
|
||||
prompt_cache = {
|
||||
"text_token": text_token,
|
||||
"audio_feat": audio_feat,
|
||||
}
|
||||
|
||||
return prompt_cache
|
||||
|
||||
|
||||
def merge_prompt_cache(
|
||||
self,
|
||||
original_cache: dict,
|
||||
new_text_token: torch.Tensor,
|
||||
new_audio_feat: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Merge original prompt cache with newly generated content to stabilize voice.
|
||||
|
||||
Args:
|
||||
original_cache: original prompt cache
|
||||
new_text_token: newly generated text tokens
|
||||
new_audio_feat: newly generated audio features
|
||||
|
||||
Returns:
|
||||
merged_cache: merged cache
|
||||
"""
|
||||
if original_cache is None:
|
||||
return {
|
||||
"text_token": new_text_token,
|
||||
"audio_feat": new_audio_feat,
|
||||
}
|
||||
original_text_token = original_cache["text_token"]
|
||||
original_audio_feat = original_cache["audio_feat"]
|
||||
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
|
||||
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
|
||||
|
||||
# build new cache
|
||||
merged_cache = {
|
||||
"text_token": merged_text_token,
|
||||
"audio_feat": merged_audio_feat,
|
||||
}
|
||||
|
||||
return merged_cache
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_with_prompt_cache(
|
||||
self,
|
||||
target_text: str,
|
||||
prompt_cache: dict,
|
||||
min_len: int = 2,
|
||||
max_len: int = 2000,
|
||||
inference_timesteps: int = 10,
|
||||
cfg_value: float = 2.0,
|
||||
retry_badcase: bool = False,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,
|
||||
):
|
||||
"""
|
||||
Generate audio using pre-built prompt cache.
|
||||
|
||||
Args:
|
||||
target_text: Text to convert to speech
|
||||
prompt_cache: Cache built by build_prompt_cache (can be None)
|
||||
min_len: Minimum audio length to avoid very short audio
|
||||
max_len: Maximum audio length
|
||||
inference_timesteps: Number of diffusion sampling steps
|
||||
cfg_value: Classifier-free guidance value
|
||||
retry_badcase: Whether to retry on bad cases
|
||||
retry_badcase_max_times: Maximum retry attempts
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
|
||||
|
||||
Returns:
|
||||
tuple: (decoded audio tensor, new text tokens, new audio features)
|
||||
"""
|
||||
# get prompt from cache
|
||||
if prompt_cache is None:
|
||||
prompt_text_token = torch.empty(0, dtype=torch.int32)
|
||||
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
|
||||
else:
|
||||
prompt_text_token = prompt_cache["text_token"]
|
||||
prompt_audio_feat = prompt_cache["audio_feat"]
|
||||
# build target text tokens
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
|
||||
text_token = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
torch.tensor(
|
||||
[self.audio_start_token],
|
||||
dtype=torch.int32,
|
||||
device=text_token.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
audio_length = prompt_audio_feat.size(0)
|
||||
text_length = text_token.shape[0]
|
||||
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
|
||||
audio_pad_feat = torch.zeros(
|
||||
(text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
|
||||
dtype=torch.float32,
|
||||
device=text_token.device,
|
||||
)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
|
||||
text_token = text_token.unsqueeze(0).to(self.device)
|
||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
# run inference
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
retry_badcase_times = 0
|
||||
while retry_badcase_times < retry_badcase_max_times:
|
||||
latent_pred, pred_audio_feat = self.inference(
|
||||
text_token,
|
||||
text_mask,
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
|
||||
return (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_mask: torch.Tensor,
|
||||
feat: torch.Tensor,
|
||||
feat_mask: torch.Tensor,
|
||||
min_len: int = 2,
|
||||
max_len: int = 2000,
|
||||
inference_timesteps: int = 10,
|
||||
cfg_value: float = 2.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Core inference method for audio generation.
|
||||
|
||||
This is the main inference loop that generates audio features
|
||||
using the language model and diffusion transformer.
|
||||
|
||||
Args:
|
||||
text: Input text tokens
|
||||
text_mask: Mask for text tokens
|
||||
feat: Input audio features
|
||||
feat_mask: Mask for audio features
|
||||
min_len: Minimum generation length
|
||||
max_len: Maximum generation length
|
||||
inference_timesteps: Number of diffusion steps
|
||||
cfg_value: Classifier-free guidance value
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- Predicted latent features
|
||||
- Predicted audio feature sequence
|
||||
"""
|
||||
B, T, P, D = feat.shape
|
||||
|
||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
scale_emb = self.config.lm_config.scale_emb
|
||||
else:
|
||||
scale_emb = 1.0
|
||||
|
||||
text_embed = self.base_lm.embed_tokens(text) * scale_emb
|
||||
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
|
||||
|
||||
prefix_feat_cond = feat[:, -1, ...] # b, p, d
|
||||
pred_feat_seq = [] # b, t, p, d
|
||||
curr_embed = None
|
||||
|
||||
enc_outputs, kv_cache_tuple = self.base_lm(
|
||||
inputs_embeds=combined_embed,
|
||||
is_causal=True,
|
||||
)
|
||||
self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
|
||||
|
||||
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
|
||||
lm_hidden = enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
|
||||
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
|
||||
is_causal=True,
|
||||
)
|
||||
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
|
||||
residual_hidden = residual_enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
for i in tqdm(range(max_len)):
|
||||
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
|
||||
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
|
||||
dit_hidden = dit_hidden_1 + dit_hidden_2 # [b, h_dit]
|
||||
|
||||
pred_feat = self.feat_decoder(
|
||||
mu=dit_hidden,
|
||||
patch_size=self.patch_size,
|
||||
cond=prefix_feat_cond.transpose(1, 2).contiguous(),
|
||||
n_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
).transpose(
|
||||
1, 2
|
||||
) # [b, p, d]
|
||||
|
||||
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.enc_to_lm_proj(curr_embed)
|
||||
|
||||
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
|
||||
prefix_feat_cond = pred_feat
|
||||
|
||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||
if i > min_len and stop_flag == 1:
|
||||
break
|
||||
|
||||
lm_hidden = self.base_lm.forward_step(
|
||||
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
|
||||
).clone()
|
||||
|
||||
|
||||
lm_hidden = self.fsq_layer(lm_hidden)
|
||||
residual_hidden = self.residual_lm.forward_step(
|
||||
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
|
||||
).clone()
|
||||
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
feat_pred = feat_pred[..., 1:-1] # trick: remove the first and last token
|
||||
return feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
|
||||
audio_vae = AudioVAE()
|
||||
vae_state_dict = torch.load(
|
||||
os.path.join(path, "audiovae.pth"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
|
||||
model = cls(config, tokenizer, audio_vae)
|
||||
lm_dtype = get_dtype(config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
model.audio_vae = model.audio_vae.to(torch.float32)
|
||||
|
||||
model_state_dict = torch.load(
|
||||
os.path.join(path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
|
||||
for kw, val in vae_state_dict.items():
|
||||
model_state_dict[f"audio_vae.{kw}"] = val
|
||||
model.load_state_dict(model_state_dict, strict=True)
|
||||
return model.to(model.device).eval().optimize()
|
||||
0 src/voxcpm/modules/__init__.py (new, empty file)
1 src/voxcpm/modules/audiovae/__init__.py (new file)
@@ -0,0 +1 @@
from .audio_vae import AudioVAE
359 src/voxcpm/modules/audiovae/audio_vae.py (new file)
@@ -0,0 +1,359 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
class CausalConv1d(nn.Conv1d):
|
||||
def __init__(self, *args, padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
|
||||
def forward(self, x):
|
||||
x_pad = F.pad(x, (self.__padding * 2, 0))
|
||||
return super().forward(x_pad)
|
||||
|
||||
|
||||
class CausalTransposeConv1d(nn.ConvTranspose1d):
|
||||
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
self.__output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
|
||||
|
||||
|
||||
def WNCausalConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNCausalTransposeConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
# Scripting this speeds the model up by roughly 1.4x
|
||||
@torch.jit.script
|
||||
def snake(x, alpha):
|
||||
shape = x.shape
|
||||
x = x.reshape(shape[0], shape[1], -1)
|
||||
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
||||
x = x.reshape(shape)
|
||||
return x
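# Elementwise Snake activation: snake(x) = x + (1/alpha) * sin^2(alpha * x),
# with a small epsilon on alpha for numerical stability; alpha is a learnable
# per-channel parameter (see Snake1d below).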
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return snake(x, self.alpha)
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class CausalResidualUnit(nn.Module):
|
||||
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel,
|
||||
dilation=dilation,
|
||||
padding=pad,
|
||||
groups=groups,
|
||||
),
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(dim, dim, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.block(x)
|
||||
pad = (x.shape[-1] - y.shape[-1]) // 2
|
||||
assert pad == 0
|
||||
if pad > 0:
|
||||
x = x[..., pad:-pad]
|
||||
return x + y
|
||||
|
||||
|
||||
class CausalEncoderBlock(nn.Module):
|
||||
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
|
||||
super().__init__()
|
||||
input_dim = input_dim or output_dim // 2
|
||||
self.block = nn.Sequential(
|
||||
CausalResidualUnit(input_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=9, groups=groups),
|
||||
Snake1d(input_dim),
|
||||
WNCausalConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class CausalEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 64,
|
||||
latent_dim: int = 32,
|
||||
strides: list = [2, 4, 8, 8],
|
||||
depthwise: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Create first convolution
|
||||
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
|
||||
|
||||
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||
for stride in strides:
|
||||
d_model *= 2
|
||||
groups = d_model // 2 if depthwise else 1
|
||||
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
|
||||
|
||||
groups = d_model if depthwise else 1
|
||||
|
||||
# Create two convolutions, one for mu and one for logvar
|
||||
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
|
||||
# Wrap block into nn.Sequential
|
||||
self.block = nn.Sequential(*self.block)
|
||||
self.enc_dim = d_model
|
||||
|
||||
def forward(self, x):
|
||||
hidden_state = self.block(x)
|
||||
return {
|
||||
"hidden_state": hidden_state,
|
||||
"mu": self.fc_mu(hidden_state),
|
||||
"logvar": self.fc_logvar(hidden_state),
|
||||
}
|
||||
|
||||
|
||||
class NoiseBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, T = x.shape
|
||||
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
|
||||
h = self.linear(x)
|
||||
n = noise * h
|
||||
x = x + n
|
||||
return x
|
||||
|
||||
|
||||
class CausalDecoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 16,
|
||||
output_dim: int = 8,
|
||||
stride: int = 1,
|
||||
groups=1,
|
||||
use_noise_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
layers = [
|
||||
Snake1d(input_dim),
|
||||
WNCausalTransposeConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
]
|
||||
if use_noise_block:
|
||||
layers.append(NoiseBlock(output_dim))
|
||||
layers.extend(
|
||||
[
|
||||
CausalResidualUnit(output_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=9, groups=groups),
|
||||
]
|
||||
)
|
||||
self.block = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class TransposeLastTwoDim(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.transpose(x, -1, -2)
|
||||
|
||||
|
||||
class CausalDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channel,
|
||||
channels,
|
||||
rates,
|
||||
depthwise: bool = False,
|
||||
d_out: int = 1,
|
||||
use_noise_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add first conv layer
|
||||
if depthwise:
|
||||
layers = [
|
||||
WNCausalConv1d(
|
||||
input_channel,
|
||||
input_channel,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
groups=input_channel,
|
||||
),
|
||||
WNCausalConv1d(input_channel, channels, kernel_size=1),
|
||||
]
|
||||
else:
|
||||
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
||||
|
||||
# Add upsampling + MRF blocks
|
||||
for i, stride in enumerate(rates):
|
||||
input_dim = channels // 2**i
|
||||
output_dim = channels // 2 ** (i + 1)
|
||||
groups = output_dim if depthwise else 1
|
||||
layers += [
|
||||
CausalDecoderBlock(
|
||||
input_dim,
|
||||
output_dim,
|
||||
stride,
|
||||
groups=groups,
|
||||
use_noise_block=use_noise_block,
|
||||
)
|
||||
]
|
||||
|
||||
# Add final conv layer
|
||||
layers += [
|
||||
Snake1d(output_dim),
|
||||
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
||||
nn.Tanh(),
|
||||
]
|
||||
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class AudioVAE(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int = 128,
|
||||
encoder_rates: List[int] = [2, 5, 8, 8],
|
||||
latent_dim: int = 64,
|
||||
decoder_dim: int = 1536,
|
||||
decoder_rates: List[int] = [8, 8, 5, 2],
|
||||
depthwise: bool = True,
|
||||
sample_rate: int = 16000,
|
||||
use_noise_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
self.decoder_rates = decoder_rates
|
||||
self.depthwise = depthwise
|
||||
|
||||
self.use_noise_block = use_noise_block
|
||||
|
||||
if latent_dim is None:
|
||||
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.hop_length = np.prod(encoder_rates)
|
||||
self.encoder = CausalEncoder(
|
||||
encoder_dim,
|
||||
latent_dim,
|
||||
encoder_rates,
|
||||
depthwise=depthwise,
|
||||
)
|
||||
|
||||
self.decoder = CausalDecoder(
|
||||
latent_dim,
|
||||
decoder_dim,
|
||||
decoder_rates,
|
||||
depthwise=depthwise,
|
||||
use_noise_block=use_noise_block,
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.chunk_size = math.prod(encoder_rates)
|
||||
|
||||
def preprocess(self, audio_data, sample_rate):
|
||||
if sample_rate is None:
|
||||
sample_rate = self.sample_rate
|
||||
assert sample_rate == self.sample_rate
|
||||
pad_to = self.hop_length
|
||||
length = audio_data.shape[-1]
|
||||
right_pad = math.ceil(length / pad_to) * pad_to - length
|
||||
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
||||
|
||||
return audio_data
|
||||
|
||||
def decode(self, z: torch.Tensor):
|
||||
"""Decode given latent codes and return audio data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
length : int, optional
|
||||
Number of samples in output audio, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"audio" : Tensor[B x 1 x length]
|
||||
Decoded audio data.
|
||||
"""
|
||||
return self.decoder(z)
|
||||
|
||||
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
||||
"""
|
||||
Args:
|
||||
audio_data: Tensor[B x 1 x T]
|
||||
sample_rate: int
|
||||
Returns:
|
||||
z: Tensor[B x D x T]
|
||||
"""
|
||||
if audio_data.ndim == 2:
|
||||
audio_data = audio_data.unsqueeze(1)
|
||||
|
||||
audio_data = self.preprocess(audio_data, sample_rate)
|
||||
return self.encoder(audio_data)["mu"]
|
||||
1 src/voxcpm/modules/layers/__init__.py (new file)
@@ -0,0 +1 @@
from .scalar_quantization_layer import ScalarQuantizationLayer
26 src/voxcpm/modules/layers/scalar_quantization_layer.py (new file)
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ScalarQuantizationLayer(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.latent_dim = latent_dim
|
||||
self.scale = scale
|
||||
|
||||
self.in_proj = nn.Linear(in_dim, latent_dim)
|
||||
self.out_proj = nn.Linear(latent_dim, out_dim)
|
||||
|
||||
def forward(self, hidden):
|
||||
hidden = self.in_proj(hidden)
|
||||
hidden = torch.tanh(hidden)
|
||||
|
||||
if self.training:
|
||||
quantized = torch.round(hidden * self.scale) / self.scale
|
||||
hidden = hidden + (quantized - hidden).detach()
|
||||
else:
|
||||
hidden = torch.round(hidden * self.scale) / self.scale
|
||||
|
||||
return self.out_proj(hidden)
|
||||
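In training mode the layer quantizes with a straight-through estimator (rounding in the forward pass, identity gradients in the backward pass); in eval mode it rounds directly. A small sketch of calling it (illustrative only; the dimensions are arbitrary):

import torch

sq = ScalarQuantizationLayer(in_dim=256, out_dim=256).eval()
h = torch.randn(2, 10, 256)
out = sq(h)            # tanh-squashed, snapped to a grid with step 1/9, then projected back out
print(out.shape)       # torch.Size([2, 10, 256])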
2
src/voxcpm/modules/locdit/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .unified_cfm import UnifiedCFM, CfmConfig
from .local_dit import VoxCPMLocDiT
114
src/voxcpm/modules/locdit/local_dit.py
Normal file
@@ -0,0 +1,114 @@
import torch
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
import torch.nn as nn
import math


class SinusoidalPosEmb(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        out_dim: int = None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act = nn.SiLU()
        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim

        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)

    def forward(self, sample):
        sample = self.linear_1(sample)
        sample = self.act(sample)
        sample = self.linear_2(sample)
        return sample


class VoxCPMLocDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        config: MiniCPM4Config,
        in_channels: int = 64,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.config = config

        self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
        self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)

        self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
        self.time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )
        self.delta_time_mlp = TimestepEmbedding(
            in_channels=config.hidden_size,
            time_embed_dim=config.hidden_size,
        )

        assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
        self.decoder = MiniCPMModel(config)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond: torch.Tensor,
        dt: torch.Tensor,
    ):
        """
        Forward pass of DiT.
        x: (N, C, T) tensor of inputs
        mu: (N, C) tensor of hidden embedding
        t: (N,) tensor of diffusion timesteps
        cond: (N, C, T') tensor of prefix conditions
        dt: (N,) used for mean velocity (may be supported in the future...)
        """
        x = self.in_proj(x.transpose(1, 2).contiguous())

        cond = self.cond_proj(cond.transpose(1, 2).contiguous())
        prefix = cond.size(1)

        t = self.time_embeddings(t).to(x.dtype)
        t = self.time_mlp(t)
        dt = self.time_embeddings(dt).to(x.dtype)
        dt = self.delta_time_mlp(dt)
        t = t + dt

        x = torch.cat([(mu + t).unsqueeze(1), cond, x], dim=1)
        hidden, _ = self.decoder(x, is_causal=False)
        hidden = hidden[:, prefix + 1 :, :]
        hidden = self.out_proj(hidden)

        return hidden.transpose(1, 2).contiguous()
137
src/voxcpm/modules/locdit/unified_cfm.py
Normal file
@@ -0,0 +1,137 @@
import torch
from typing import List
from .local_dit import VoxCPMLocDiT
import math
from pydantic import BaseModel


class CfmConfig(BaseModel):
    sigma_min: float = 1e-06
    solver: str = "euler"
    t_scheduler: str = "log-norm"


class UnifiedCFM(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        cfm_params: CfmConfig,
        estimator: VoxCPMLocDiT,
        mean_mode: bool = False,
    ):
        super().__init__()
        self.solver = cfm_params.solver
        self.sigma_min = cfm_params.sigma_min
        self.t_scheduler = cfm_params.t_scheduler
        self.in_channels = in_channels
        self.mean_mode = mean_mode

        # Just change the architecture of the estimator here
        self.estimator = estimator

    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,
        n_timesteps: int,
        patch_size: int,
        cond: torch.Tensor,
        temperature: float = 1.0,
        cfg_value: float = 1.0,
        sway_sampling_coef: float = 1.0,
        use_cfg_zero_star: bool = True,
    ):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats)
            n_timesteps (int): number of diffusion steps
            cond: Not used but kept for future purposes
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.

        Returns:
            sample: generated feature sequence
                shape: (batch_size, n_feats, patch_size)
        """
        b, c = mu.shape
        t = patch_size
        z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature

        t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        # Sway sampling strategy
        t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)

        return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)

    def optimized_scale(self, positive_flat, negative_flat):
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

        st_star = dot_product / squared_norm
        return st_star

    def solve_euler(
        self,
        x: torch.Tensor,
        t_span: torch.Tensor,
        mu: torch.Tensor,
        cond: torch.Tensor,
        cfg_value: float = 1.0,
        use_cfg_zero_star: bool = True,
    ):
        """
        Fixed Euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats)
            cond: Not used but kept for future purposes
            cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
        """
        t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]

        sol = []
        zero_init_steps = max(1, int(len(t_span) * 0.04))
        for step in range(1, len(t_span)):
            if use_cfg_zero_star and step <= zero_init_steps:
                dphi_dt = 0.
            else:
                # Classifier-Free Guidance inference introduced in VoiceBox
                b = x.size(0)
                x_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
                mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
                t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
                x_in[:b], x_in[b:] = x, x
                mu_in[:b] = mu
                t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
                dt_in[:b], dt_in[b:] = dt.unsqueeze(0), dt.unsqueeze(0)
                # not used now
                if not self.mean_mode:
                    dt_in = torch.zeros_like(dt_in)
                cond_in[:b], cond_in[b:] = cond, cond

                dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)

                if use_cfg_zero_star:
                    positive_flat = dphi_dt.view(b, -1)
                    negative_flat = cfg_dphi_dt.view(b, -1)
                    st_star = self.optimized_scale(positive_flat, negative_flat)
                    st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
                else:
                    st_star = 1.0

                dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)

            x = x - dt * dphi_dt
            t = t - dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t - t_span[step + 1]

        return sol[-1]
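The sway-sampled schedule above spends more, finer Euler steps near the high-noise end of the trajectory. A standalone sketch of just that schedule (illustrative; n_timesteps = 10 is arbitrary):

import torch

n_timesteps, sway_sampling_coef = 10, 1.0
t_span = torch.linspace(1, 0, n_timesteps + 1)
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
print(t_span)  # runs from 1.0 down to 0.0; spacing is finer near t = 1, the high-noise region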
1
src/voxcpm/modules/locenc/__init__.py
Normal file
@@ -0,0 +1 @@
from .local_encoder import VoxCPMLocEnc
30
src/voxcpm/modules/locenc/local_encoder.py
Normal file
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
from einops import rearrange


class VoxCPMLocEnc(nn.Module):
    def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
        super().__init__()
        self.config = config
        self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
        self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)

        assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
        self.encoder = MiniCPMModel(config)

    def forward(self, x):
        """
        x: [B, T, P, D]
        """
        B, T, P, D = x.shape

        x = self.in_proj(x)
        special_tokens = self.special_token.expand(B, T, 1, -1)
        x = torch.cat([special_tokens, x], dim=2)
        x = rearrange(x, "b t p c -> (b t) p c")
        outputs, _ = self.encoder(x, is_causal=False)
        cls_output = outputs[:, 0, :]

        return rearrange(cls_output, "(b t) c -> b t c", b=B)
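The encoder prepends a learned CLS-style token to each patch of P latent frames and keeps only that token's output, yielding one vector per patch. A tensor-only sketch of the reshaping (illustrative; it mimics the data flow without instantiating the model):

import torch
from einops import rearrange

B, T, P, H = 2, 6, 4, 64                    # batch, patches, frames per patch, hidden size
x = torch.randn(B, T, P, H)                 # already projected to hidden size
cls = torch.randn(1, 1, 1, H).expand(B, T, 1, -1)
x = torch.cat([cls, x], dim=2)              # (B, T, P + 1, H)
x = rearrange(x, "b t p c -> (b t) p c")    # each patch becomes one short sequence
pooled = x[:, 0, :]                         # in the real module, MiniCPMModel runs on x before this pooling
print(rearrange(pooled, "(b t) c -> b t c", b=B).shape)  # torch.Size([2, 6, 64])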
3
src/voxcpm/modules/minicpm4/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .config import MiniCPM4Config
from .model import MiniCPMModel
from .cache import StaticKVCache
47
src/voxcpm/modules/minicpm4/cache.py
Normal file
@@ -0,0 +1,47 @@
from typing import List, Tuple
import torch


class StaticKVCache:
    def __init__(
        self,
        num_layers: int,
        num_kv_heads: int,
        dim_kv_head: int,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
        max_length: int = 8192,
    ):
        self.max_length = max_length
        self.num_layers = num_layers

        self.kv_cache = torch.zeros(
            2,
            num_layers,
            batch_size,
            num_kv_heads,
            max_length,
            dim_kv_head,
            device=device,
            dtype=dtype,
        )
        self.current_length = 0

    def get_layer_cache(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.kv_cache[0, layer_idx], self.kv_cache[1, layer_idx]

    def step(self) -> int:
        if self.current_length >= self.max_length:
            raise ValueError("KV cache is full")

        ret = self.current_length
        self.current_length += 1
        return ret

    def fill_caches(self, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]):
        self.current_length = kv_caches[0][0].size(2)
        self.kv_cache.zero_()
        for i in range(self.num_layers):
            self.kv_cache[0, i, :, :, : self.current_length, :] = kv_caches[i][0]
            self.kv_cache[1, i, :, :, : self.current_length, :] = kv_caches[i][1]
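A small standalone sketch of the cache lifecycle (illustrative sizes; the prefill tensors are random placeholders shaped (batch, kv_heads, seq, head_dim)):

import torch

cache = StaticKVCache(num_layers=2, num_kv_heads=2, dim_kv_head=16,
                      batch_size=1, device="cpu", dtype=torch.float32, max_length=64)
prefill = [(torch.randn(1, 2, 8, 16), torch.randn(1, 2, 8, 16)) for _ in range(2)]
cache.fill_caches(prefill)         # copy 8 prefill positions into the static buffers
pos = cache.step()                 # returns 8: the slot the next decode step writes into
k, v = cache.get_layer_cache(0)    # (1, 2, 64, 16) views shared with the attention layers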
29
src/voxcpm/modules/minicpm4/config.py
Normal file
@@ -0,0 +1,29 @@
from pydantic import BaseModel
from typing import List


class RopeScalingConfig(BaseModel):
    type: str
    long_factor: List[float]
    short_factor: List[float]
    original_max_position_embeddings: int


class MiniCPM4Config(BaseModel):
    bos_token_id: int
    eos_token_id: int
    hidden_size: int
    intermediate_size: int
    max_position_embeddings: int
    num_attention_heads: int
    num_hidden_layers: int
    num_key_value_heads: int
    rms_norm_eps: float
    rope_scaling: RopeScalingConfig
    vocab_size: int
    use_mup: bool = True
    scale_emb: float
    dim_model_base: int
    scale_depth: float
    rope_theta: float
    kv_channels: int = None
411
src/voxcpm/modules/minicpm4/model.py
Normal file
@@ -0,0 +1,411 @@
from .config import MiniCPM4Config
import torch
import torch.nn as nn
from typing import List, Tuple
import math
from .cache import StaticKVCache


def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
    old_dtype = hidden.dtype
    variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
    return hidden * weight


class MiniCPMRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MiniCPMRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """
    Args:
        q: Tensor(batch_size, num_heads, seq_len, head_dim)
        k: Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
        cos: Tensor(seq_len, head_dim)
        sin: Tensor(seq_len, head_dim)
    Returns:
        Tensor(batch_size, num_heads, seq_len, head_dim), Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
    """
    orig_dtype = q.dtype
    q = q.to(torch.float32)
    k = k.to(torch.float32)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


class MiniCPMLongRoPE(nn.Module):
    """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, config: MiniCPM4Config):
        super().__init__()
        self.config = config
        self.dim = config.kv_channels if config.kv_channels else config.hidden_size // config.num_attention_heads
        self.base = config.rope_theta
        self.max_position_embeddings = config.max_position_embeddings

        self.short_factor = config.rope_scaling.short_factor
        self.long_factor = config.rope_scaling.long_factor
        self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings

        scale = (self.max_position_embeddings / self.original_max_position_embeddings)
        self.scaling_factor = math.sqrt(
            1 + math.log(scale) / math.log(self.original_max_position_embeddings)
        )
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self.max_seq_len_cached = 0

        self.register_buffer("cos_cached", torch.empty(0), persistent=False)
        self.register_buffer("sin_cached", torch.empty(0), persistent=False)

        self._set_cos_sin_cache(
            seq_len=self.max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Build the cos/sin caches."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        if seq_len > self.original_max_position_embeddings:
            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
        else:
            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)

        freqs = torch.mul(
            torch.outer(t, 1.0 / ext_factors).to(device=device),
            self.inv_freq.to(device=device).to(dtype)
        )

        # build the embeddings
        emb = torch.cat((freqs, freqs), dim=-1)

        self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
        self.sin_cached = emb.sin().to(dtype) * self.scaling_factor

    def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            position_ids: Tensor(seq_len) or Tensor(batch_size, seq_len)
        Returns:
            Tensor(seq_len, head_dim), Tensor(seq_len, head_dim)
        """
        cos = self.cos_cached[position_ids]
        sin = self.sin_cached[position_ids]

        return cos, sin


class MiniCPMAttention(nn.Module):
    def __init__(self, config: MiniCPM4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = 10000.0

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_emb

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            is_causal=is_causal,
            enable_gqa=True,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        past_key_value = (key_states, value_states)
        return attn_output, past_key_value

    def forward_step(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        position_id: int,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        bsz, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, 1, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_emb

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        key_cache, value_cache = kv_cache

        key_cache[:, :, position_id, :] = key_states
        value_cache[:, :, position_id, :] = value_states

        attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_cache,
            value_cache,
            attn_mask=attn_mask,
            enable_gqa=True,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output


class MiniCPMMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class MiniCPMDecoderLayer(nn.Module):
    def __init__(self, config: MiniCPM4Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MiniCPMAttention(config=config, layer_idx=layer_idx)

        self.mlp = MiniCPMMLP(config)
        self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.scale_depth = config.scale_depth
        self.num_hidden_layers = config.num_hidden_layers
        self.use_mup = config.use_mup

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            position_emb (`Tuple[torch.Tensor, torch.Tensor]`): rotary cos/sin tensors for the sequence
            is_causal (`bool`): whether the attention mask is causal
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            position_emb=position_emb,
            is_causal=is_causal,
        )

        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        return hidden_states, present_key_value

    def forward_step(
        self,
        hidden_states: torch.Tensor,
        position_emb: Tuple[torch.Tensor, torch.Tensor],
        position_id: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # Self Attention
        hidden_states = self.self_attn.forward_step(
            hidden_states=hidden_states,
            position_emb=position_emb,
            position_id=position_id,
            kv_cache=kv_cache,
        )

        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)

        hidden_states = self.mlp(hidden_states)
        if self.use_mup:
            hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
        else:
            hidden_states = residual + hidden_states

        return hidden_states


class MiniCPMModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]

    Args:
        config: MiniCPM4Config
    """

    def __init__(self, config: MiniCPM4Config):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.config = config

        if config.vocab_size > 0:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        else:
            self.embed_tokens = nn.Identity()

        self.layers = nn.ModuleList(
            [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rope_emb = MiniCPMLongRoPE(config)

        self.kv_cache = None

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        is_causal: bool = True,
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Args:
            inputs_embeds: Tensor(batch_size, seq_length, hidden_size)
            is_causal: bool, whether the attention mask is causal
        Returns:
            hidden_states: Tensor(batch_size, seq_length, hidden_size)
            next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)]
        """
        position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device)
        position_emb = self.rope_emb(position_ids)
        hidden_states = inputs_embeds

        next_decoder_cache = []

        for decoder_layer in self.layers:
            hidden_states, this_cache = decoder_layer(
                hidden_states,
                position_emb,
                is_causal,
            )
            next_decoder_cache.append(this_cache)
        hidden_states = self.norm(hidden_states)
        return hidden_states, next_decoder_cache

    def forward_step(
        self,
        inputs_embeds: torch.Tensor,
        position_id: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            inputs_embeds: Tensor(batch_size, hidden_size)
        Returns:
            hidden_states: Tensor(batch_size, hidden_size)
        """
        assert self.kv_cache is not None, "KV cache is not setup"

        position_emb = self.rope_emb(position_id)
        hidden_states = inputs_embeds

        for i, decoder_layer in enumerate(self.layers):
            hidden_states = decoder_layer.forward_step(
                hidden_states,
                position_emb,
                position_id,
                self.kv_cache.get_layer_cache(i),
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def setup_cache(self, batch_size: int, max_length: int, device, dtype: torch.dtype):
        self.kv_cache = StaticKVCache(
            num_layers=self.config.num_hidden_layers,
            num_kv_heads=self.config.num_key_value_heads,
            dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
            batch_size=batch_size,
            device=device,
            dtype=dtype,
            max_length=max_length,
        )
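A compact prefill-then-decode sketch of how this decoder is typically driven (illustrative only; the tiny config values below are made up, and the import paths assume the package layout shown above):

import torch
from voxcpm.modules.minicpm4.config import MiniCPM4Config, RopeScalingConfig
from voxcpm.modules.minicpm4.model import MiniCPMModel

cfg = MiniCPM4Config(
    bos_token_id=0, eos_token_id=0,
    hidden_size=64, intermediate_size=128,
    max_position_embeddings=128,
    num_attention_heads=4, num_hidden_layers=2, num_key_value_heads=2,
    rms_norm_eps=1e-5,
    rope_scaling=RopeScalingConfig(
        type="longrope",
        long_factor=[1.0] * 8, short_factor=[1.0] * 8,  # head_dim // 2 entries
        original_max_position_embeddings=128,
    ),
    vocab_size=0,  # operate on embeddings, not token ids
    scale_emb=12.0, dim_model_base=256, scale_depth=1.4, rope_theta=10000.0,
)

model = MiniCPMModel(cfg).eval()
prefix = torch.randn(1, 8, cfg.hidden_size)            # 8 "prompt" embeddings

with torch.no_grad():
    hidden, kv = model(prefix, is_causal=True)          # prefill pass
    model.setup_cache(1, max_length=64, device=prefix.device, dtype=prefix.dtype)
    model.kv_cache.fill_caches(kv)                      # copy prefill KV into the static cache
    step_in = hidden[:, -1, :]                          # decode one step at a time
    pos = torch.tensor([model.kv_cache.step()])         # next write position (= 8)
    step_out = model.forward_step(step_in, pos)
print(step_out.shape)  # torch.Size([1, 64])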
244
src/voxcpm/utils/text_normalize.py
Normal file
@@ -0,0 +1,244 @@
# some functions are copied from https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/utils/frontend_utils.py
import re
import regex
import inflect
from functools import partial
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer


def normal_cut_sentence(text):
    # first protect commas inside parentheses by replacing them with a placeholder
    text = re.sub(r'([((][^))]*)([,,])([^))]*[))])', r'\1&&&\3', text)
    text = re.sub('([。!,?\?])([^’”])', r'\1\n\2', text)    # ordinary sentence-final punctuation not followed by a closing quote
    text = re.sub('(\.{6})([^’”])', r'\1\n\2', text)          # English ellipsis not followed by a closing quote
    text = re.sub('(\…{2})([^’”])', r'\1\n\2', text)          # Chinese ellipsis not followed by a closing quote
    text = re.sub('([. ,。!;?\?\.{6}\…{2}][’”])([^’”])', r'\1\n\2', text)  # sentence-final punctuation plus closing quote, not followed by another quote
    # handle English sentence boundaries
    text = re.sub(r'([.,!?])([^’”\'"])', r'\1\n\2', text)     # sentence punctuation not followed by a quote
    text = re.sub(r'([.!?][’”\'"])([^’”\'"])', r'\1\n\2', text)  # sentence punctuation followed by a quote
    text = re.sub(r'([((][^))]*)(&&&)([^))]*[))])', r'\1,\3', text)  # restore the protected commas
    text = [t for t in text.split("\n") if t]
    return text


def cut_sentence_with_fix_length(text: str, length: int):
    sentences = normal_cut_sentence(text)
    cur_length = 0
    res = ""
    for sentence in sentences:
        if not sentence:
            continue
        if cur_length > length or cur_length + len(sentence) > length:
            yield res
            res = ""
            cur_length = 0
        res += sentence
        cur_length += len(sentence)
    if res:
        yield res


chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')


# whether the text contains a Chinese character
def contains_chinese(text):
    return bool(chinese_char_pattern.search(text))


# replace special symbols with their spoken Chinese forms
def replace_corner_mark(text):
    text = text.replace('²', '平方')
    text = text.replace('³', '立方')
    text = text.replace('√', '根号')
    text = text.replace('≈', '约等于')
    text = text.replace('<', '小于')
    return text


# remove meaningless symbols
def remove_bracket(text):
    text = text.replace('(', ' ').replace(')', ' ')
    text = text.replace('【', ' ').replace('】', ' ')
    text = text.replace('`', '').replace('`', '')
    text = text.replace("——", " ")
    return text


# spell out Arabic numerals
def spell_out_number(text: str, inflect_parser):
    new_text = []
    st = None
    for i, c in enumerate(text):
        if not c.isdigit():
            if st is not None:
                num_str = inflect_parser.number_to_words(text[st: i])
                new_text.append(num_str)
                st = None
            new_text.append(c)
        else:
            if st is None:
                st = i
    if st is not None and st < len(text):
        num_str = inflect_parser.number_to_words(text[st:])
        new_text.append(num_str)
    return ''.join(new_text)


# split paragraph logic:
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
# 2. calculate sentence length according to lang
# 3. split sentences according to punctuation
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
    def calc_utt_length(_text: str):
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    if lang == "zh":
        pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
    else:
        pounc = ['.', '?', '!', ';', ':']
    if comma_split:
        pounc.extend([',', ','])
    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st: i]) > 0:
                utts.append(text[st: i] + c)
            if i + 1 < len(text) and text[i + 1] in ['"', '”']:
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1
    if len(utts) == 0:
        if lang == "zh":
            utts.append(text + '。')
        else:
            utts.append(text + '.')
    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts


# remove blanks between Chinese characters
def replace_blank(text: str):
    out_str = []
    for i, c in enumerate(text):
        if c == " " and 0 < i < len(text) - 1:  # guard boundaries so leading/trailing spaces cannot index out of range
            if ((text[i + 1].isascii() and text[i + 1] != " ") and
                    (text[i - 1].isascii() and text[i - 1] != " ")):
                out_str.append(c)
        elif c != " ":
            out_str.append(c)
    return "".join(out_str)


def clean_markdown(md_text: str) -> str:
    # strip fenced code blocks ``` ``` (possibly multi-line)
    md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)

    # strip inline code `code`
    md_text = re.sub(r"`[^`]*`", "", md_text)

    # strip image syntax ![alt](url)
    md_text = re.sub(r"!\[[^\]]*\]\([^\)]+\)", "", md_text)

    # strip links but keep the text: [text](url) -> text
    md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)

    # drop unordered-list bullets
    md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)

    # strip HTML tags
    md_text = re.sub(r"<[^>]+>", "", md_text)

    # strip heading markers (#)
    md_text = re.sub(r"^#{1,6}\s*", "", md_text, flags=re.MULTILINE)

    # collapse extra blank lines and whitespace
    md_text = re.sub(r"\n\s*\n", "\n", md_text)  # redundant blank lines
    md_text = md_text.strip()

    return md_text


def clean_text(text):
    # strip Markdown syntax
    text = clean_markdown(text)
    # match and remove emoji
    text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("", text)
    # remove line breaks and tabs
    text = text.replace("\n", " ")
    text = text.replace("\t", " ")
    text = text.replace('"', "\“")
    return text


class TextNormalizer:
    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer
        self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, remove_interjections=False, overwrite_cache=True)
        self.en_tn_model = EnNormalizer()
        self.inflect_parser = inflect.engine()

    def normalize(self, text, split=False):
        # strip Markdown syntax, emoji and line breaks, then run language-specific normalization
        lang = "zh" if contains_chinese(text) else "en"
        text = clean_text(text)
        if lang == "zh":
            text = text.replace("=", "等于")  # keep "550 + 320 = 870千卡。" from being normalized as "五百五十加三百二十等于八七十千卡."
            if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text):  # avoid English hyphens being normalized as minus signs
                text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text)  # keep "x-2" from being read as "x negative 2"
                text = self.zh_tn_model.normalize(text)
            text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text)  # keep "x-2" from being read as "x negative 2"
            text = self.zh_tn_model.normalize(text)
            text = replace_blank(text)
            text = replace_corner_mark(text)
            text = remove_bracket(text)
            text = re.sub(r'[,,]+$', '。', text)
        else:
            text = self.en_tn_model.normalize(text)
            text = spell_out_number(text, self.inflect_parser)
        if split is False:
            return text


if __name__ == "__main__":
    text_normalizer = TextNormalizer()
    text = r"""今天我们学习一元二次方程。一元二次方程的标准形式是:
ax2+bx+c=0ax^2 + bx + c = 0ax2+bx+c=0
其中,aaa、bbb 和 ccc 是常数,xxx 是变量。这个方程的解可以通过求根公式来找到。
一元二次方程的解法有几种:
- 因式分解法:通过将方程因式分解来求解。我们首先尝试将方程表达成两个括号的形式,解决方程的解。比如,方程x2−5x+6=0x^2 - 5x + 6 = 0x2−5x+6=0可以因式分解为(x−2)(x−3)=0(x - 2)(x - 3) = 0(x−2)(x−3)=0,因此根为2和3。
- 配方法:通过配方将方程转化为完全平方的形式,从而解出。我们通过加上或减去适当的常数来完成这一过程,使得方程可以直接写成一个完全平方的形式。
- 求根公式:我们可以使用求根公式直接求出方程的解。这个公式适用于所有的一元二次方程,即使我们无法通过因式分解或配方法来解决时,也能使用该公式。
公式:x=−b±b2−4ac2ax = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}x=2a−b±b2−4ac这个公式可以帮助我们求解任何一元二次方程的根。
对于一元二次方程,我们需要了解判别式。判别式的作用是帮助我们判断方程的解的个数和性质。判别式 Δ\DeltaΔ 由下式给出:Δ=b2−4ac\Delta = b^2 - 4acΔ=b2−4ac 根据判别式的值,我们可以知道:
- 如果 Δ>0\Delta > 0Δ>0,方程有两个不相等的实数解。这是因为判别式大于0时,根号内的值是正数,所以我们可以得到两个不同的解。
- 如果 Δ=0\Delta = 0Δ=0,方程有一个实数解。这是因为根号内的值为零,导致两个解相等,也就是说方程有一个解。
- 如果 Δ<0\Delta < 0Δ<0,方程没有实数解。这意味着根号内的值是负数,无法进行实数运算,因此方程没有实数解,可能有复数解。"""
    texts = ["这是一个公式 (a+b)³=a³+3a²b+3ab²+b³ S=(a×b)÷2", "这样的发展为AI仅仅作为“工具”这一观点提出了新的挑战,", "550 + 320 = 870千卡。", "解一元二次方程:3x^2+x-2=0", "你好啊"]
    texts = [text]
    for text in texts:
        text = text_normalizer.normalize(text)
        print(text)
        for t in cut_sentence_with_fix_length(text, 15):
            print(t)