40 Commits

Author SHA1 Message Date
刘鑫
81467f649f Modify lora inference api 2025-12-05 22:29:44 +08:00
刘鑫
400f47a516 Modify lora inference api 2025-12-05 22:22:13 +08:00
Labmem-Zhouyx
b1f7593ae0 Update: default no denoise & normalize 2025-12-05 22:16:27 +08:00
Labmem-Zhouyx
6a5e713698 fix: streaming mode 2025-12-05 22:06:15 +08:00
Labmem-Zhouyx
3443dbb212 Update: VoxCPM1.5 and fine-tuning supprt 2025-12-05 21:04:51 +08:00
Labmem-Zhouyx
d1bb6aaf41 update technical report 2025-09-30 10:47:39 +08:00
刘鑫
2eb4d39719 FX: Add MPS support 2025-09-28 21:06:35 +08:00
刘鑫
fbf8984d4e Merge branch 'main' into dev 2025-09-27 16:20:47 +08:00
刘鑫
41752dc0fa FX: Raising the Python version to avoid issues with Gradio failing to start. 2025-09-22 21:16:23 +08:00
xliucs
b0714adcaa Merge pull request #26 from AbrahamSanders/main
Add a streaming API for VoxCPM
2025-09-22 20:47:07 +08:00
AbrahamSanders
89f4d917a0 Update readme with streaming example 2025-09-19 17:09:30 -04:00
AbrahamSanders
5c5da0dbe6 Add a streaming API for VoxCPM 2025-09-19 16:56:11 -04:00
刘鑫
961569e76d merge from main 2025-09-19 22:08:56 +08:00
刘鑫
5f56d5ff5d FX: update README 2025-09-19 13:44:33 +08:00
xliucs
169c17ddfd Merge pull request #17 from MayDomine/main
add prompt-file option to set prompt text
2025-09-19 13:35:36 +08:00
MayDomine
996c69a1a8 add prompt-file option to set prompt text 2025-09-19 12:53:23 +08:00
刘鑫
dc6b6d1d1c Fx: capture compile error on Windows 2025-09-18 19:23:13 +08:00
刘鑫
cef6aefb3d remove \n from input text 2025-09-18 14:57:45 +08:00
周逸轩
1a46c5d1ad update README 2025-09-18 14:53:37 +08:00
周逸轩
5257ec3dc5 FX: noise point 2025-09-18 14:50:01 +08:00
刘鑫
bdd516b579 remove target text anotation 2025-09-18 13:07:43 +08:00
刘鑫
11568f0776 remove target text anotation 2025-09-18 12:58:27 +08:00
刘鑫
e5bcb735f0 Remove segment text logic 2025-09-18 12:02:37 +08:00
刘鑫
f26a1ea2f7 Remove segment text logic 2025-09-18 12:01:26 +08:00
周逸轩
1fa9e2ca02 update README 2025-09-18 01:21:45 +08:00
周逸轩
10f48ba330 update README 2025-09-17 19:36:32 +08:00
周逸轩
639b2272ab update README 2025-09-17 19:34:08 +08:00
周逸轩
7e8f754ba1 update README 2025-09-17 19:33:37 +08:00
刘鑫
032c7fe403 capture torch compile error 2025-09-17 18:09:09 +08:00
刘鑫
5390a47862 Merge branch 'dev'; Replace the text normalization library 2025-09-16 22:17:30 +08:00
刘鑫
e7012f1a94 Replace the text normalization library 2025-09-16 22:17:14 +08:00
刘鑫
82332cfc99 Replace the text normalization library 2025-09-16 22:17:14 +08:00
刘鑫
605ac2d8e4 Replace the text normalization library 2025-09-16 22:16:40 +08:00
周逸轩
0fa8d894d1 update README 2025-09-16 21:33:57 +08:00
周逸轩
776c0d19fb FX: typo 2025-09-16 19:40:27 +08:00
周逸轩
ed6e6b4dac FX: typo 2025-09-16 19:37:55 +08:00
周逸轩
e3108d4a12 FX: typo 2025-09-16 19:36:17 +08:00
周逸轩
59fe3f30a1 update README 2025-09-16 19:05:00 +08:00
周逸轩
6f2fb45756 ModelScope 2025-09-16 17:12:52 +08:00
周逸轩
91128d823d ModelScope 2025-09-16 17:12:52 +08:00
33 changed files with 3326 additions and 395 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
launch.json
__pycache__
voxcpm.egg-info
.DS_Store

215
README.md
View File

@@ -1,14 +1,27 @@
## 🎙️ VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning
[![Project Page](https://img.shields.io/badge/Project%20Page-GitHub-blue)](https://github.com/OpenBMB/VoxCPM/) [![Technical Report](https://img.shields.io/badge/Technical%20Report-Arxiv-red)](https://arxiv.org/abs/2509.24650) [![Live Playground](https://img.shields.io/badge/Live%20PlayGround-Demo-orange)](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [![Samples](https://img.shields.io/badge/Audio%20Samples-Page-green)](https://openbmb.github.io/VoxCPM-demopage)
#### VoxCPM1.5 Model Weights
[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-OpenBMB-yellow)](https://huggingface.co/openbmb/VoxCPM1.5) [![ModelScope](https://img.shields.io/badge/ModelScope-OpenBMB-purple)](https://modelscope.cn/models/OpenBMB/VoxCPM1.5)
<div align="center">
<img src="assets/voxcpm_logo.png" alt="VoxCPM Logo" width="40%">
</div>
<div align="center">
👋 Contact us on [WeChat](assets/wechat.png)
</div>
## News
* [2025.12.05] 🎉 🎉 🎉 We Open Source the VoxCPM1.5 [weights](https://huggingface.co/openbmb/VoxCPM1.5)! The model now supports both full-parameter fine-tuning and efficient LoRA fine-tuning, empowering you to create your own tailored version. See [Release Notes](docs/release_note.md) for details.
* [2025.09.30] 🔥 🔥 🔥 We Release the VoxCPM [Technical Report](https://arxiv.org/abs/2509.24650)!
* [2025.09.16] 🔥 🔥 🔥 We Open Source the VoxCPM-0.5B [weights](https://huggingface.co/openbmb/VoxCPM-0.5B)!
* [2025.09.16] 🎉 🎉 🎉 We Provide the [Gradio PlayGround](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) for VoxCPM-0.5B, try it now!
@@ -25,15 +38,22 @@ Unlike mainstream approaches that convert speech to discrete tokens, VoxCPM uses
### 🚀 Key Features
- **Context-Aware, Expressive Speech Generation** - VoxCPM comprehends text to infer and generate appropriate prosody, delivering speech with remarkable expressiveness and natural flow. It spontaneously adapts speaking style based on content, producing highly fitting vocal expression trained on a massive 1.8 million-hour bilingual corpus.
- **True-to-Life Voice Cloning** - With only a short reference audio clip, VoxCPM performs accurate zero-shot voice cloning, capturing not only the speaker's timbre but also fine-grained characteristics such as accent, emotional tone, rhythm, and pacing to create a faithful and natural replica.
- **High-Efficiency Synthesis** - VoxCPM supports streaming synthesis with a Real-Time Factor (RTF) as low as 0.17 on a consumer-grade NVIDIA RTX 4090 GPU, making it possible for real-time applications.
### 📦 Model Versions
See [Release Notes](docs/release_note.md) for details.
- **VoxCPM1.5** (Latest):
  - Model Params: 750M
  - Sampling rate of AudioVAE: 44100
  - Token rate in LM Backbone: 6.25Hz (patch-size=4)
  - RTF on a single NVIDIA RTX 4090 GPU: ~0.15
- **VoxCPM-0.5B** (Original):
  - Model Params: 600M
  - Sampling rate of AudioVAE: 16000
  - Token rate in LM Backbone: 12.5Hz (patch-size=2)
  - RTF on a single NVIDIA RTX 4090 GPU: 0.17
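*(RTF is the ratio of synthesis time to audio duration, so an RTF of 0.15 means that generating 10 seconds of speech takes roughly 1.5 seconds of wall-clock time on that GPU.)*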
@@ -45,12 +65,18 @@ pip install voxcpm
```
### 1. Model Download (Optional)
By default, when you first run the script, the model will be downloaded automatically, but you can also download the model in advance.
- Download VoxCPM1.5
```
from huggingface_hub import snapshot_download
snapshot_download("openbmb/VoxCPM1.5")
```
- Or Download VoxCPM-0.5B
```
from huggingface_hub import snapshot_download
snapshot_download("openbmb/VoxCPM-0.5B")
```
- Download ZipEnhancer and SenseVoice-Small. We use ZipEnhancer to enhance speech prompts and SenseVoice-Small for speech prompt ASR in the web demo.
```
from modelscope import snapshot_download
snapshot_download('iic/speech_zipenhancer_ans_multiloss_16k_base')
@@ -60,25 +86,39 @@ By default, when you first run the script, the model will be downloaded automati
### 2. Basic Usage
```python
import soundfile as sf
import numpy as np
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM1.5")

# Non-streaming
wav = model.generate(
    text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
    prompt_wav_path=None,      # optional: path to a prompt speech for voice cloning
    prompt_text=None,          # optional: reference text
    cfg_value=2.0,             # LM guidance on LocDiT; higher values follow the prompt more closely but may degrade quality
    inference_timesteps=10,    # LocDiT inference timesteps; higher for better quality, lower for faster synthesis
    normalize=False,           # enable external TN tool (disables native raw-text support)
    denoise=False,             # enable external denoise tool (may cause some distortion and restricts the sampling rate to 16kHz)
    retry_badcase=True,        # enable retrying for bad cases (e.g., unstoppable generation)
    retry_badcase_max_times=3, # maximum number of retries
    retry_badcase_ratio_threshold=6.0, # length-ratio threshold for bad-case detection (simple but effective); adjust for slow-paced speech
)
sf.write("output.wav", wav, model.tts_model.sample_rate)
print("saved: output.wav")

# Streaming
chunks = []
for chunk in model.generate_streaming(
    text="Streaming text to speech is easy with VoxCPM!",
    # supports the same args as above
):
    chunks.append(chunk)
wav = np.concatenate(chunks)
sf.write("output_streaming.wav", wav, model.tts_model.sample_rate)
print("saved: output_streaming.wav")
```
### 3. CLI Usage
@@ -94,7 +134,14 @@ voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, desi
--prompt-audio path/to/voice.wav \
--prompt-text "reference transcript" \
--output out.wav \
# --denoise

# (Optional) Voice cloning (reference audio + transcript file)
voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
--prompt-audio path/to/voice.wav \
--prompt-file "/path/to/text-file" \
--output out.wav \
# --denoise

# 3) Batch processing (one text per line)
voxcpm --input examples/input.txt --output-dir outs
@@ -102,7 +149,7 @@ voxcpm --input examples/input.txt --output-dir outs
voxcpm --input examples/input.txt --output-dir outs \
--prompt-audio path/to/voice.wav \
--prompt-text "reference transcript" \
# --denoise

# 4) Inference parameters (quality/speed)
voxcpm --text "..." --output out.wav \
@@ -113,7 +160,7 @@ voxcpm --text "..." --output out.wav \
voxcpm --text "..." --output out.wav --model-path /path/to/VoxCPM_model_dir voxcpm --text "..." --output out.wav --model-path /path/to/VoxCPM_model_dir
# Or from Hugging Face (auto download/cache) # Or from Hugging Face (auto download/cache)
voxcpm --text "..." --output out.wav \ voxcpm --text "..." --output out.wav \
--hf-model-id openbmb/VoxCPM-0.5B --cache-dir ~/.cache/huggingface --local-files-only --hf-model-id openbmb/VoxCPM1.5 --cache-dir ~/.cache/huggingface --local-files-only
# 6) Denoiser control # 6) Denoiser control
voxcpm --text "..." --output out.wav \ voxcpm --text "..." --output out.wav \
@@ -128,104 +175,51 @@ python -m voxcpm.cli --help
You can start the UI interface by running `python app.py`, which allows you to perform Voice Cloning and Voice Creation.
### 5. Fine-tuning
VoxCPM1.5 supports both full fine-tuning (SFT) and LoRA fine-tuning, allowing you to train personalized voice models on your own data. See the [Fine-tuning Guide](docs/finetune.md) for detailed instructions.
**Quick Start:**
```bash
# Full fine-tuning
python scripts/train_voxcpm_finetune.py \
    --config_path conf/voxcpm_v1.5/voxcpm_finetune_all.yaml

# LoRA fine-tuning
python scripts/train_voxcpm_finetune.py \
    --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml
```

## 📚 Documentation

- **[Usage Guide](docs/usage_guide.md)** - Detailed guide on how to use VoxCPM effectively, including text input modes, voice cloning tips, and parameter tuning
- **[Fine-tuning Guide](docs/finetune.md)** - Complete guide for fine-tuning VoxCPM models with SFT and LoRA
- **[Release Notes](docs/release_note.md)** - Version history and updates
- **[Performance Benchmarks](docs/performance.md)** - Detailed performance comparisons on public benchmarks

---

## 📚 More Information

### 🌟 Community Projects

We're excited to see the VoxCPM community growing! Here are some amazing projects and features built by our community:
- **[ComfyUI-VoxCPM](https://github.com/wildminder/ComfyUI-VoxCPM)** A VoxCPM extension for ComfyUI.
- **[ComfyUI-VoxCPMTTS](https://github.com/1038lab/ComfyUI-VoxCPMTTS)** A VoxCPM extension for ComfyUI.
- **[WebUI-VoxCPM](https://github.com/rsxdalv/tts_webui_extension.vox_cpm)** A template extension for TTS WebUI.
- **[PR: Streaming API Support (by AbrahamSanders)](https://github.com/OpenBMB/VoxCPM/pull/26)**
- **[VoxCPM-NanoVLLM](https://github.com/a710128/nanovllm-voxcpm)** NanoVLLM integration for VoxCPM for faster, high-throughput inference on GPU.
- **[VoxCPM-ONNX](https://github.com/bluryar/VoxCPM-ONNX)** ONNX export for VoxCPM supports faster CPU inference.
- **[VoxCPMANE](https://github.com/0seba/VoxCPMANE)** VoxCPM TTS with Apple Neural Engine backend server.

*Note: The projects are not officially maintained by OpenBMB.*

*Have you built something cool with VoxCPM? We'd love to feature it here! Please open an issue or pull request to add your project.*

### 📊 Performance Highlights

VoxCPM achieves competitive results on public zero-shot TTS benchmarks. See [Performance Benchmarks](docs/performance.md) for detailed comparison tables.
@@ -236,6 +230,16 @@ VoxCPM achieves competitive results on public zero-shot TTS benchmarks:
- Bilingual Model: VoxCPM is trained primarily on Chinese and English data. Performance on other languages is not guaranteed and may result in unpredictable or low-quality audio.
- This model is released for research and development purposes only. We do not recommend its use in production or commercial applications without rigorous testing and safety evaluations. Please use VoxCPM responsibly.
---
## 📝 TO-DO List
Please stay tuned for updates!
- [x] Release the VoxCPM technical report.
- [x] Support higher sampling rate (44.1kHz in VoxCPM-1.5).
- [x] Support SFT and LoRA fine-tuning.
- [ ] Multilingual Support (besides ZH/EN).
- [ ] Controllable Speech Generation by Human Instruction.
## 📄 License
@@ -258,6 +262,8 @@ This project is developed by the following institutions:
- <img src="assets/thuhcsi_logo.png" width="28px"> [THUHCSI](https://github.com/thuhcsi) - <img src="assets/thuhcsi_logo.png" width="28px"> [THUHCSI](https://github.com/thuhcsi)
## ⭐ Star History
[![Star History Chart](https://api.star-history.com/svg?repos=OpenBMB/VoxCPM&type=Date)](https://star-history.com/#OpenBMB/VoxCPM&Date)
## 📚 Citation
@@ -265,11 +271,10 @@ This project is developed by the following institutions:
If you find our model helpful, please consider citing our projects 📝 and starring us ⭐️!
```bib
@article{voxcpm2025,
  title   = {VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning},
  author  = {Zhou, Yixuan and Zeng, Guoyang and Liu, Xin and Li, Xiang and Yu, Renjie and Wang, Ziyang and Ye, Runchuan and Sun, Weiyue and Gui, Jiancheng and Li, Kehan and Wu, Zhiyong and Liu, Zhiyuan},
  journal = {arXiv preprint arXiv:2509.24650},
  year    = {2025},
}
```

17
app.py
View File

@@ -8,7 +8,7 @@ from funasr import AutoModel
from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
    os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
import voxcpm
@@ -29,7 +29,7 @@ class VoxCPMDemo:
# TTS model (lazy init)
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.default_local_model_dir = "./models/VoxCPM1.5"
# ---------- Model helpers ----------
def _resolve_model_dir(self) -> str:
@@ -108,7 +108,7 @@ class VoxCPMDemo:
normalize=do_normalize,
denoise=denoise,
)
return (current_model.tts_model.sample_rate, wav)
# ---------- UI Builders ----------
@@ -170,7 +170,7 @@ def create_demo_interface(demo: VoxCPMDemo):
# Pro Tips
with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
gr.Markdown("""
### Prompt Speech Enhancement|参考语音降噪
- **Enable** to remove background noise for a clean, studio-like voice, with an external ZipEnhancer component.
**启用**:通过 ZipEnhancer 组件消除背景噪音,获得更好的音质。
@@ -194,10 +194,6 @@ def create_demo_interface(demo: VoxCPMDemo):
**调低**:合成速度更快。
- **Higher** for better synthesis quality.
**调高**:合成质量更佳。
### Long Text (e.g., >5 min speech)|长文本 (如 >5分钟的合成语音)
While VoxCPM can handle long texts directly, we recommend using empty lines to break very long content into paragraphs; the model will then synthesize each paragraph individually.
虽然 VoxCPM 支持直接生成长文本,但如果目标文本过长,我们建议使用换行符将内容分段;模型将对每个段落分别合成。
""") """)
# Main controls # Main controls
@@ -206,7 +202,7 @@ def create_demo_interface(demo: VoxCPMDemo):
prompt_wav = gr.Audio(
sources=["upload", 'microphone'],
type="filepath",
label="Prompt Speech (Optional, or let VoxCPM improvise)",
value="./examples/example.wav",
)
DoDenoisePromptAudio = gr.Checkbox(
@@ -244,14 +240,13 @@ def create_demo_interface(demo: VoxCPMDemo):
text = gr.Textbox(
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
label="Target Text",
info="Default processing splits text on \\n into paragraphs; each is synthesized as a chunk and then concatenated into the final audio."
)
with gr.Row():
DoNormalizeText = gr.Checkbox(
value=False,
label="Text Normalization",
elem_id="chk_normalize",
info="We use wetext library to normalize the input text."
)
audio_output = gr.Audio(label="Output Audio")

BIN
assets/wechat.png Normal file

Binary file not shown.


View File

@@ -0,0 +1,21 @@
pretrained_path: /path/to/VoxCPM1.5/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 44100
batch_size: 16
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.00001
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192 # Example: a single batch can have at most 8192 tokens; with batch_size=16, each sample can have at most 512 tokens
save_path: /path/to/checkpoints/finetune_all
tensorboard: /path/to/logs/finetune_all
lambdas:
loss/diff: 1.0
loss/stop: 1.0

View File

@@ -0,0 +1,28 @@
pretrained_path: /path/to/VoxCPM1.5/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 44100
batch_size: 16
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.0001
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192 # Example: a single batch can have at most 8192 tokens; with batch_size=16, each sample can have at most 512 tokens
save_path: /path/to/checkpoints/finetune_lora
tensorboard: /path/to/logs/finetune_lora
lambdas:
loss/diff: 1.0
loss/stop: 1.0
lora:
enable_lm: true
enable_dit: true
enable_proj: false
r: 32
alpha: 16
dropout: 0.0

View File

@@ -0,0 +1,21 @@
pretrained_path: /path/to/VoxCPM-0.5B/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 16000
batch_size: 16
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.00001
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192 # Example: a single batch can have at most 8192 tokens; with batch_size=16, each sample can have at most 512 tokens
save_path: /path/to/checkpoints/finetune_all
tensorboard: /path/to/logs/finetune_all
lambdas:
loss/diff: 1.0
loss/stop: 1.0

View File

@@ -0,0 +1,28 @@
pretrained_path: /path/to/VoxCPM-0.5B/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 16000
batch_size: 16
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.0001
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192 # Example: a single batch can have at most 8192 tokens; with batch_size=16, each sample can have at most 512 tokens
save_path: /path/to/checkpoints/finetune_lora
tensorboard: /path/to/logs/finetune_lora
lambdas:
loss/diff: 1.0
loss/stop: 1.0
lora:
enable_lm: true
enable_dit: true
enable_proj: false
r: 32
alpha: 16
dropout: 0.0

375
docs/finetune.md Normal file
View File

@@ -0,0 +1,375 @@
# VoxCPM Fine-tuning Guide
This guide covers how to fine-tune VoxCPM models with two approaches: full fine-tuning and LoRA fine-tuning.
### 🎓 SFT (Supervised Fine-Tuning)
Full fine-tuning updates all model parameters. Suitable for:
- 📊 Large, specialized datasets
- 🔄 Cases where significant behavior changes are needed
### ⚡ LoRA Fine-tuning
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that:
- 🎯 Trains only a small number of additional parameters
- 💾 Significantly reduces memory requirements and training time
- 🔀 Supports multiple LoRA adapters with hot-swapping
## Table of Contents
- [Data Preparation](#data-preparation)
- [Full Fine-tuning](#full-fine-tuning)
- [LoRA Fine-tuning](#lora-fine-tuning)
- [Inference](#inference)
- [LoRA Hot-swapping](#lora-hot-swapping)
- [FAQ](#faq)
---
## Data Preparation
Training data should be prepared as a JSONL manifest file, with one sample per line:
```jsonl
{"audio": "path/to/audio1.wav", "text": "Transcript of audio 1."}
{"audio": "path/to/audio2.wav", "text": "Transcript of audio 2."}
{"audio": "path/to/audio3.wav", "text": "Optional duration field.", "duration": 3.5}
{"audio": "path/to/audio4.wav", "text": "Optional dataset_id for multi-dataset.", "dataset_id": 1}
```
### Required Fields
| Field | Description |
|-------|-------------|
| `audio` | Path to audio file (absolute or relative) |
| `text` | Corresponding transcript |
### Optional Fields
| Field | Description |
|-------|-------------|
| `duration` | Audio duration in seconds (speeds up sample filtering) |
| `dataset_id` | Dataset ID for multi-dataset training (default: 0) |
### Requirements
- Audio format: WAV
- Sample rate: 16kHz for VoxCPM-0.5B, 44.1kHz for VoxCPM1.5
- Text: Transcript matching the audio content
See `examples/train_data_example.jsonl` for a complete example.
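Because the manifest is plain JSONL, it is easy to generate with a short script. The sketch below is illustrative only (it is not part of the repository) and assumes a folder of `*.wav` clips with matching `*.txt` transcript files:
```python
# Hypothetical helper: build train.jsonl from paired clip_001.wav / clip_001.txt files.
import json
from pathlib import Path

import soundfile as sf

data_dir = Path("data/my_voice")  # assumed layout; adjust to your dataset
with open("train.jsonl", "w", encoding="utf-8") as manifest:
    for wav_path in sorted(data_dir.glob("*.wav")):
        txt_path = wav_path.with_suffix(".txt")
        if not txt_path.exists():
            continue  # skip clips without a transcript
        info = sf.info(str(wav_path))
        entry = {
            "audio": str(wav_path),
            "text": txt_path.read_text(encoding="utf-8").strip(),
            "duration": round(info.frames / info.samplerate, 2),  # optional field, speeds up filtering
        }
        manifest.write(json.dumps(entry, ensure_ascii=False) + "\n")
```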
---
## Full Fine-tuning
Full fine-tuning updates all model parameters. Suitable for large datasets or when significant behavior changes are needed.
### Configuration
Create `conf/voxcpm_v1.5/voxcpm_finetune_all.yaml`:
```yaml
pretrained_path: /path/to/VoxCPM1.5/
train_manifest: /path/to/train.jsonl
val_manifest: ""
sample_rate: 44100
batch_size: 16
grad_accum_steps: 1
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.00001 # Use smaller LR for full fine-tuning
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192
save_path: /path/to/checkpoints/finetune_all
tensorboard: /path/to/logs/finetune_all
lambdas:
loss/diff: 1.0
loss/stop: 1.0
```
### Training
```bash
# Single GPU
python scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_all.yaml
# Multi-GPU
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_all.yaml
```
### Checkpoint Structure
Full fine-tuning saves a complete model directory that can be loaded directly:
```
checkpoints/finetune_all/
└── step_0002000/
├── model.safetensors # Model weights (excluding audio_vae)
├── config.json # Model config
├── audiovae.pth # Audio VAE weights
├── tokenizer.json # Tokenizer
├── tokenizer_config.json
├── special_tokens_map.json
├── optimizer.pth
└── scheduler.pth
```
---
## LoRA Fine-tuning
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that trains only a small number of additional parameters, significantly reducing memory requirements.
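Conceptually, LoRA keeps each pretrained weight matrix frozen and learns a small low-rank update whose contribution is scaled by `alpha / r`. The snippet below is a generic PyTorch sketch of that idea, not the implementation used in this repository:
```python
# Generic LoRA illustration: frozen base linear layer + trainable low-rank update.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 32, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                      # base weights stay frozen
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no change at step 0
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

layer = LoRALinear(nn.Linear(1024, 1024), r=32, alpha=16)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # only the two low-rank factors are trained
```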
### Configuration
Create `conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml`:
```yaml
pretrained_path: /path/to/VoxCPM1.5/
train_manifest: /path/to/train.jsonl
val_manifest: ""
sample_rate: 44100
batch_size: 16
grad_accum_steps: 1
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.0001 # LoRA can use larger LR
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192
save_path: /path/to/checkpoints/finetune_lora
tensorboard: /path/to/logs/finetune_lora
lambdas:
loss/diff: 1.0
loss/stop: 1.0
# LoRA configuration
lora:
enable_lm: true # Apply LoRA to Language Model
enable_dit: true # Apply LoRA to Diffusion Transformer
enable_proj: false # Apply LoRA to projection layers (optional)
r: 32 # LoRA rank (higher = more capacity)
alpha: 16 # LoRA alpha, scaling = alpha / r
dropout: 0.0
# Target modules
target_modules_lm: ["q_proj", "v_proj", "k_proj", "o_proj"]
target_modules_dit: ["q_proj", "v_proj", "k_proj", "o_proj"]
```
### LoRA Parameters
| Parameter | Description | Recommended |
|-----------|-------------|-------------|
| `enable_lm` | Apply LoRA to LM (language model) | `true` |
| `enable_dit` | Apply LoRA to DiT (diffusion model) | `true` (required for voice cloning) |
| `r` | LoRA rank (higher = more capacity) | 16-64 |
| `alpha` | Scaling factor, `scaling = alpha / r` | Usually `r/2` or `r` |
| `target_modules_*` | Layer names to add LoRA | attention layers |
### Training
```bash
# Single GPU
python scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml
# Multi-GPU
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml
```
### Checkpoint Structure
LoRA training saves only LoRA parameters:
```
checkpoints/finetune_lora/
└── step_0002000/
├── lora_weights.safetensors # Only lora_A, lora_B parameters
├── optimizer.pth
└── scheduler.pth
```
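To sanity-check what a LoRA checkpoint contains, the `safetensors` file can be inspected directly. A minimal sketch (the path comes from the example tree above; adjust it to your run):
```python
# List the LoRA tensors stored in a checkpoint and their shapes.
from safetensors import safe_open

path = "checkpoints/finetune_lora/step_0002000/lora_weights.safetensors"
with safe_open(path, framework="pt") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))
```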
---
## Inference
### Full Fine-tuning Inference
The checkpoint directory is a complete model; load it directly:
```bash
python scripts/test_voxcpm_ft_infer.py \
--ckpt_dir /path/to/checkpoints/finetune_all/step_0002000 \
--text "Hello, this is the fine-tuned model." \
--output output.wav
```
With voice cloning:
```bash
python scripts/test_voxcpm_ft_infer.py \
--ckpt_dir /path/to/checkpoints/finetune_all/step_0002000 \
--text "This is voice cloning result." \
--prompt_audio /path/to/reference.wav \
--prompt_text "Reference audio transcript" \
--output cloned_output.wav
```
### LoRA Inference
LoRA inference requires the training config (for LoRA structure) and LoRA checkpoint:
```bash
python scripts/test_voxcpm_lora_infer.py \
--config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
--lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
--text "Hello, this is LoRA fine-tuned result." \
--output lora_output.wav
```
With voice cloning:
```bash
python scripts/test_voxcpm_lora_infer.py \
--config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
--lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
--text "This is voice cloning with LoRA." \
--prompt_audio /path/to/reference.wav \
--prompt_text "Reference audio transcript" \
--output cloned_output.wav
```
---
## LoRA Hot-swapping
LoRA supports dynamic loading, unloading, and switching at inference time without reloading the entire model.
### API Reference
```python
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
# 1. Load model with LoRA structure and weights
lora_cfg = LoRAConfig(
enable_lm=True,
enable_dit=True,
r=32,
alpha=16,
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
)
model = VoxCPM.from_pretrained(
hf_model_id="openbmb/VoxCPM1.5", # or local path
load_denoiser=False, # Optional: disable denoiser for faster loading
optimize=True, # Enable torch.compile acceleration
lora_config=lora_cfg,
lora_weights_path="/path/to/lora_checkpoint",
)
# 2. Generate audio
audio = model.generate(
text="Hello, this is LoRA fine-tuned result.",
prompt_wav_path="/path/to/reference.wav", # Optional: for voice cloning
prompt_text="Reference audio transcript", # Optional: for voice cloning
)
# 3. Disable LoRA (use base model only)
model.set_lora_enabled(False)
# 4. Re-enable LoRA
model.set_lora_enabled(True)
# 5. Unload LoRA (reset weights to zero)
model.unload_lora()
# 6. Hot-swap to another LoRA
loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
# 7. Get current LoRA weights
lora_state = model.get_lora_state_dict()
```
### Simplified Usage (Auto LoRA Config)
If you only have LoRA weights and don't need custom config, just provide the path:
```python
from voxcpm.core import VoxCPM
# Auto-create default LoRAConfig when only lora_weights_path is provided
model = VoxCPM.from_pretrained(
hf_model_id="openbmb/VoxCPM1.5",
lora_weights_path="/path/to/lora_checkpoint", # Will auto-create LoRAConfig
)
```
### Method Reference
| Method | Description | torch.compile Compatible |
|--------|-------------|--------------------------|
| `load_lora(path)` | Load LoRA weights from file | ✅ |
| `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
| `get_lora_state_dict()` | Get current LoRA weights | ✅ |
| `lora_enabled` | Property: check if LoRA is configured | ✅ |
---
## FAQ
### 1. Out of Memory (OOM)
- Increase `grad_accum_steps` (gradient accumulation)
- Decrease `batch_size`
- Use LoRA fine-tuning instead of full fine-tuning
- Decrease `max_batch_tokens` to filter long samples
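For example, halving `batch_size` from 16 to 8 and setting `grad_accum_steps: 2` roughly halves peak activation memory while keeping the effective batch size at 16 (multiplied further by the number of GPUs when launched with `torchrun`).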
### 2. Poor LoRA Performance
- Increase `r` (LoRA rank)
- Adjust `alpha` (try `alpha = r/2` or `alpha = r`)
- Increase training steps
- Add more target modules
### 3. Training Not Converging
- Decrease `learning_rate`
- Increase `warmup_steps`
- Check data quality
### 4. LoRA Not Taking Effect at Inference
- Ensure inference config matches training config LoRA parameters
- Check `load_lora()` return value - `skipped_keys` should be empty
- Verify `set_lora_enabled(True)` is called
### 5. Checkpoint Loading Errors
- Full fine-tuning: checkpoint directory should contain `model.safetensors` (or `pytorch_model.bin`), `config.json`, `audiovae.pth`
- LoRA: checkpoint directory should contain `lora_weights.safetensors` (or `lora_weights.ckpt`)

46
docs/performance.md Normal file
View File

@@ -0,0 +1,46 @@
# 📊 Performance Highlights
VoxCPM achieves competitive results on public zero-shot TTS benchmarks.
## Seed-TTS-eval Benchmark
| Model | Parameters | Open-Source | test-EN | | test-ZH | | test-Hard | |
|------|------|------|:------------:|:--:|:------------:|:--:|:-------------:|:--:|
| | | | WER/%⬇ | SIM/%⬆| CER/%⬇| SIM/%⬆ | CER/%⬇ | SIM/%⬆ |
| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
| F5-TTS | 0.3B | ✅ | 2.00 | 67.0 | 1.53 | 76.0 | 8.67 | 71.3 |
| MaskGCT | 1B | ✅ | 2.62 | 71.7 | 2.27 | 77.4 | - | - |
| CosyVoice | 0.3B | ✅ | 4.29 | 60.9 | 3.63 | 72.3 | 11.75 | 70.9 |
| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | **6.83** | 72.4 |
| SparkTTS | 0.5B | ✅ | 3.14 | 57.3 | 1.54 | 66.0 | - | - |
| FireRedTTS | 0.5B | ✅ | 3.82 | 46.0 | 1.51 | 63.5 | 17.45 | 62.1 |
| FireRedTTS-2 | 1.5B | ✅ | 1.95 | 66.5 | 1.14 | 73.6 | - | - |
| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | **74.7** |
| OpenAudio-s1-mini | 0.5B | ✅ | 1.94 | 55.0 | 1.18 | 68.5 | 23.37 | 64.3 |
| IndexTTS2 | 1.5B | ✅ | 2.23 | 70.6 | 1.03 | 76.5 | 7.12 | 75.5 |
| VibeVoice | 1.5B | ✅ | 3.04 | 68.9 | 1.16 | 74.4 | - | - |
| HiggsAudio-v2 | 3B | ✅ | 2.44 | 67.7 | 1.50 | 74.0 | 55.07 | 65.6 |
| **VoxCPM** | 0.5B | ✅ | **1.85** | **72.9** | **0.93** | **77.2** | 8.87 | 73.0 |
## CV3-eval Benchmark
| Model | zh | en | hard-zh | | | hard-en | | |
|-------|:--:|:--:|:-------:|:--:|:--:|:-------:|:--:|:--:|
| | CER/%⬇ | WER/%⬇ | CER/%⬇ | SIM/%⬆ | DNSMOS⬆ | WER/%⬇ | SIM/%⬆ | DNSMOS⬆ |
| F5-TTS | 5.47 | 8.90 | - | - | - | - | - | - |
| SparkTTS | 5.15 | 11.0 | - | - | - | - | - | - |
| GPT-SoVits | 7.34 | 12.5 | - | - | - | - | - | - |
| CosyVoice2 | 4.08 | 6.32 | 12.58 | 72.6 | 3.81 | 11.96 | 66.7 | 3.95 |
| OpenAudio-s1-mini | 4.00 | 5.54 | 18.1 | 58.2 | 3.77 | 12.4 | 55.7 | 3.89 |
| IndexTTS2 | 3.58 | 4.45 | 12.8 | 74.6 | 3.65 | 8.78 | 74.5 | 3.80 |
| HiggsAudio-v2 | 9.54 | 7.89 | 41.0 | 60.2 | 3.39 | 10.3 | 61.8 | 3.68 |
| CosyVoice3-0.5B | 3.89 | 5.24 | 14.15 | 78.6 | 3.75 | 9.04 | 75.9 | 3.92 |
| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 78.5 | 3.79 | 10.55 | 76.1 | 3.95 |
| **VoxCPM** | **3.40** | **4.04** | 12.9 | 66.1 | 3.59 | **7.89** | 64.3 | 3.74 |

109
docs/release_note.md Normal file
View File

@@ -0,0 +1,109 @@
# VoxCPM1.5 Release Notes
**Release Date:** December 5, 2025
## 🎉 Overview
We're thrilled to introduce a major upgrade that improves the audio quality and efficiency of VoxCPM while maintaining the core capabilities of context-aware speech generation and zero-shot voice cloning.
| Feature | VoxCPM | VoxCPM1.5 |
|---------|------------|------------|
| **Audio VAE Sampling Rate** | 16kHz | 44.1kHz |
| **LM Token Rate** | 12.5Hz | 6.25Hz |
| **Patch Size** | 2 | 4 |
| **SFT Support** | ✅ | ✅ |
| **LoRA Support** | ✅ | ✅ |
## 🎵 Model Updates
### 🔊 AudioVAE Sampling Rate: 16kHz → 44.1kHz
The AudioVAE now supports 44.1kHz sampling rate, which allows the model to:
- 🎯 Clone voices more faithfully, preserving more high-frequency details and generating higher-quality voice outputs
*Note: This upgrade enables higher quality generation when using high-quality reference audio, but does not guarantee that all generated audio will be high-fidelity. The output quality depends on the **prompt speech** quality.*
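For reference, a 16 kHz signal can only represent frequencies up to 8 kHz, whereas 44.1 kHz raises that ceiling to about 22 kHz, which is where much of a voice's sibilance and "air" lives.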
### ⚡ Token Rate: 12.5Hz → 6.25Hz
We reduced the token rate in the LM backbone from 12.5Hz to 6.25Hz (the LocEnc & LocDiT patch size increased from 2 to 4) while maintaining similar performance on evaluation benchmarks. This change:
- 💨 Reduces computational requirements for generating the same length of audio
- 📈 Provides a foundation for longer audio generation
- 🏗️ Paves the way for training larger models in the future
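In concrete terms, a 60-second utterance now corresponds to roughly 60 × 6.25 ≈ 375 LM tokens instead of 60 × 12.5 = 750, halving the sequence length the backbone processes for the same audio.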
## 🔧 Fine-tuning Support
We support full fine-tuning and LoRA fine-tuning now, please see the [Fine-tuning Guide](finetune.md) for detailed instructions.
## 📚 Documentation
- Updated README with version comparison
- Added comprehensive fine-tuning guide
- Improved code comments and documentation
## 🙏 Our Thanks to You
This release wouldn't be possible without the incredible feedback, testing, and contributions from our open-source community. Thank you for helping shape VoxCPM1.5!
## 📞 Let's Build Together
Questions, ideas, or want to contribute?
- 🐛 Report an issue: [GitHub Issues on OpenBMB/VoxCPM](https://github.com/OpenBMB/VoxCPM/issues)
- 📖 Dig into the docs: Check the [docs/](../docs/) folder for guides and API details
Enjoy the richer sound and powerful new features of VoxCPM1.5 🎉
We can't wait to hear what you create next! 🥂
## 🚀 What We're Working On
We're continuously improving VoxCPM and working on exciting new features:
- 🌍 **Multilingual TTS Support**: We are actively developing support for languages beyond Chinese and English.
- 🎯 **Controllable Expressive Speech Generation**: We are researching controllable speech generation that allows fine-grained control over speech attributes (emotion, timbre, prosody, etc.) through natural language instructions.
- 🎵 **Universal Audio Generation Foundation**: We also hope to explore VoxCPM as a unified audio generation foundation model capable of joint generation of speech, music, and sound effects. However, this is a longer-term vision.
**📅 Next Release**: We plan to release the next version in Q1 2026, which will include significant improvements and new features. Stay tuned for updates! We're committed to making VoxCPM even more powerful and versatile.
## ❓ Frequently Asked Questions (FAQ)
### Q: Does VoxCPM support fine-tuning for personalized voice customization?
**A:** Yes! VoxCPM now supports both full fine-tuning (SFT) and efficient LoRA fine-tuning. You can train personalized voice models on your own data. Please refer to the [Fine-tuning Guide](finetune.md) for detailed instructions and examples.
### Q: Is 16kHz audio quality sufficient for my use case?
**A:** We have upgraded the AudioVAE to support 44.1kHz sampling rate in VoxCPM1.5, which provides higher quality audio output with better preservation of high-frequency details. This upgrade enables better voice cloning quality and more natural speech synthesis when using high-quality reference audio.
### Q: Has the stability issue been resolved?
**A:** We have made stability optimizations in VoxCPM1.5, including improvements to the training data and model architecture. Based on community feedback, we collected some stability issues such as:
- Increased noise and reverberation
- Audio artifacts (e.g., howling/squealing)
- Unstable speaking rate (speeding up)
- Volume fluctuations (increases or decreases)
- Noise artifacts at the beginning and end of audio
- Synthesis issues with very short texts (e.g., "hello")
While we have made improvements to these issues, they have not been completely resolved and may still occasionally occur, especially with very long or highly expressive inputs. We continue to work on further stability improvements in future versions.
### Q: Does VoxCPM plan to support multilingual TTS?
**A:** Currently, VoxCPM is primarily trained on Chinese and English data. We are actively researching and developing multilingual TTS support for more languages beyond Chinese and English. Please let us know what languages you'd like to see supported!
### Q: Does VoxCPM plan to support controllable generation (emotion, style, fine-grained control)?
**A:** Currently, VoxCPM only supports zero-shot voice cloning and context-aware speech generation. Direct control over specific speech attributes (emotion, style, fine-grained prosody) is limited. However, we are actively researching instruction-controllable expressive speech generation with fine-grained control capabilities, working towards a human instruction-to-speech generation model!
### Q: Does VoxCPM support different hardware chips (e.g., Ascend 910B, XPU, NPU)?
**A:** Currently, we have not yet adapted VoxCPM for different hardware chips. Our main focus remains on developing new model capabilities and improving stability. We encourage you to check if community developers have done similar work, and we warmly welcome everyone to contribute and promote such adaptations together!
These features are under active development, and we look forward to sharing updates in future releases!

53
docs/usage_guide.md Normal file
View File

@@ -0,0 +1,53 @@
# 👩‍🍳 A Voice Chef's Guide
Welcome to the VoxCPM kitchen! Follow this recipe to cook up perfect generated speech. Let's begin.
---
## 🥚 Step 1: Prepare Your Base Ingredients (Content)
First, choose how you'd like to input your text:
### 1. Regular Text (Classic Mode)
- ✅ Keep "Text Normalization" ON. Type naturally (e.g., "Hello, world! 123"). The system will automatically process numbers, abbreviations, and punctuation using WeTextProcessing library.
### 2. Phoneme Input (Native Mode)
- ❌ Turn "Text Normalization" OFF. Enter phoneme text like `{HH AH0 L OW1}` (EN) or `{ni3}{hao3}` (ZH) for precise pronunciation control. In this mode, VoxCPM also supports native understanding of other complex non-normalized text—try it out!
- **Phoneme Conversion**: For Chinese, phonemes are converted using pinyin. For English, phonemes are converted using CMUDict. Please refer to the relevant documentation for more details.
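As a concrete illustration, a phoneme-controlled request through the Python API might look like the sketch below (it assumes the `generate()` API shown in the README; the model ID and file name are placeholders):
```python
import soundfile as sf
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM1.5")
# Phoneme spans go in curly braces; keep normalization OFF so the markup reaches the model untouched.
wav = model.generate(
    text="{HH AH0 L OW1}",   # or e.g. "{ni3}{hao3}" for Chinese
    normalize=False,
)
sf.write("phoneme_demo.wav", wav, model.tts_model.sample_rate)
```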
---
## 🍳 Step 2: Choose Your Flavor Profile (Voice Style)
This is the secret sauce that gives your audio its unique sound.
### 1. Cooking with a Prompt Speech (Following a Famous Recipe)
- A prompt speech provides the desired acoustic characteristics for VoxCPM. The speaker's timbre, speaking style, and even the background sounds and ambiance will be replicated.
- **For a Clean, Studio-Quality Voice:**
- ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone.
### 2. Cooking au Naturel (Letting the Model Improvise)
- If no reference is provided, VoxCPM becomes a creative chef! It will infer a fitting speaking style based on the text itself, thanks to the text-smartness of its foundation model, MiniCPM-4.
- **Pro Tip**: Challenge VoxCPM with any text—poetry, song lyrics, dramatic monologues—it may deliver some interesting results!
---
## 🧂 Step 3: The Final Seasoning (Fine-Tuning Your Results)
You're ready to serve! But for master chefs who want to tweak the flavor, here are two key spices.
### CFG Value (How Closely to Follow the Recipe)
- **Default**: A great starting point.
- **Voice sounds strained or weird?** Lower this value. It tells the model to be more relaxed and improvisational, great for expressive prompts.
- **Need maximum clarity and adherence to the text?** Raise it slightly to keep the model on a tighter leash.
- **Short sentences?** Consider increasing the CFG value for better clarity and adherence.
- **Long texts?** Consider lowering the CFG value to improve stability and naturalness over extended passages.
### Inference Timesteps (Simmering Time: Quality vs. Speed)
- **Need a quick snack?** Use a lower number. Perfect for fast drafts and experiments.
- **Cooking a gourmet meal?** Use a higher number. This lets the model "simmer" longer, refining the audio for superior detail and naturalness.
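When in doubt, it is cheap to sweep both knobs on a short sentence and pick the result by ear. A minimal sketch using the Python API from the README (the values and file names are only illustrative):
```python
import soundfile as sf
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM1.5")
text = "The quick brown fox jumps over the lazy dog."
for cfg_value in (1.5, 2.0, 2.5):   # recipe adherence
    for steps in (6, 10, 16):       # simmering time
        wav = model.generate(text=text, cfg_value=cfg_value, inference_timesteps=steps)
        sf.write(f"sample_cfg{cfg_value}_steps{steps}.wav", wav, model.tts_model.sample_rate)
```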
---
Happy creating! 🎉 Start with the default settings and tweak from there to suit your project. The kitchen is yours!

View File

@@ -0,0 +1,6 @@
{"audio": "examples/example.wav", "text": "This is an example audio transcript for training."}
{"audio": "/absolute/path/to/audio1.wav", "text": "You can use absolute paths for audio files."}
{"audio": "relative/path/to/audio2.wav", "text": "Or relative paths from the working directory."}
{"audio": "data/audio3.wav", "text": "Each line is a JSON object with audio path and text.", "duration": 3.5}
{"audio": "data/audio4.wav", "text": "Optional: add duration field to skip audio loading during filtering.", "duration": 2.8}
{"audio": "data/audio5.wav", "text": "Optional: add dataset_id for multi-dataset training.", "dataset_id": 1}

View File

@@ -20,23 +20,21 @@ classifiers = [
"Intended Audience :: Developers", "Intended Audience :: Developers",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
] ]
requires-python = ">=3.8" requires-python = ">=3.10"
dependencies = [ dependencies = [
"torch>=2.5.0", "torch>=2.5.0",
"torchaudio>=2.5.0", "torchaudio>=2.5.0",
"transformers>=4.36.2", "transformers>=4.36.2",
"einops", "einops",
"gradio", "gradio<6",
"inflect", "inflect",
"addict", "addict",
"WeTextProcessing", "wetext",
"modelscope>=1.22.0", "modelscope>=1.22.0",
"datasets>=2,<4", "datasets>=3,<4",
"huggingface-hub", "huggingface-hub",
"pydantic", "pydantic",
"tqdm", "tqdm",
@@ -44,7 +42,10 @@ dependencies = [
"sortedcontainers", "sortedcontainers",
"soundfile", "soundfile",
"funasr", "funasr",
"spaces" "spaces",
"argbind",
"safetensors"
] ]
[project.optional-dependencies] [project.optional-dependencies]
@@ -78,7 +79,7 @@ version_scheme = "post-release"
[tool.black]
line-length = 120
target-version = ['py310']
include = '\.pyi?$'
extend-exclude = '''
/(

View File

@@ -0,0 +1,131 @@
#!/usr/bin/env python3
"""
Full finetune inference script (no LoRA).
The checkpoint directory contains complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.) and can be loaded directly via VoxCPM.
Usage:
python scripts/test_voxcpm_ft_infer.py \
--ckpt_dir /path/to/checkpoints/step_0001000 \
--text "Hello, I am the finetuned VoxCPM." \
--output ft_test.wav
With voice cloning:
python scripts/test_voxcpm_ft_infer.py \
--ckpt_dir /path/to/checkpoints/step_0001000 \
--text "Hello, this is voice cloning result." \
--prompt_audio path/to/ref.wav \
--prompt_text "Reference audio transcript" \
--output ft_clone.wav
"""
import argparse
from pathlib import Path
import soundfile as sf
from voxcpm.core import VoxCPM
def parse_args():
parser = argparse.ArgumentParser("VoxCPM full-finetune inference test (no LoRA)")
parser.add_argument(
"--ckpt_dir",
type=str,
required=True,
help="Checkpoint directory (contains pytorch_model.bin, config.json, audiovae.pth, etc.)",
)
parser.add_argument(
"--text",
type=str,
required=True,
help="Target text to synthesize",
)
parser.add_argument(
"--prompt_audio",
type=str,
default="",
help="Optional: reference audio path for voice cloning",
)
parser.add_argument(
"--prompt_text",
type=str,
default="",
help="Optional: transcript of reference audio",
)
parser.add_argument(
"--output",
type=str,
default="ft_test.wav",
help="Output wav file path",
)
parser.add_argument(
"--cfg_value",
type=float,
default=2.0,
help="CFG scale (default: 2.0)",
)
parser.add_argument(
"--inference_timesteps",
type=int,
default=10,
help="Diffusion inference steps (default: 10)",
)
parser.add_argument(
"--max_len",
type=int,
default=600,
help="Max generation steps",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Enable text normalization",
)
return parser.parse_args()
def main():
args = parse_args()
# Load model from checkpoint directory (no denoiser)
print(f"[FT Inference] Loading model: {args.ckpt_dir}")
model = VoxCPM.from_pretrained(
hf_model_id=args.ckpt_dir,
load_denoiser=False,
optimize=True,
)
# Run inference
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
prompt_text = args.prompt_text if args.prompt_text else None
print(f"[FT Inference] Synthesizing: text='{args.text}'")
if prompt_wav_path:
print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
print(f"[FT Inference] Reference text: {prompt_text}")
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_length=args.max_len,
normalize=args.normalize,
denoise=False,
)
# Save audio
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,225 @@
#!/usr/bin/env python3
"""
LoRA inference test script.
Usage:
python scripts/test_voxcpm_lora_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--lora_ckpt checkpoints/step_0002000 \
--text "Hello, this is LoRA finetuned result." \
--output lora_test.wav
With voice cloning:
python scripts/test_voxcpm_lora_infer.py \
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
--lora_ckpt checkpoints/step_0002000 \
--text "This is voice cloning result." \
--prompt_audio path/to/ref.wav \
--prompt_text "Reference audio transcript" \
--output lora_clone.wav
"""
import argparse
from pathlib import Path
import soundfile as sf
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training.config import load_yaml_config
def parse_args():
parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
parser.add_argument(
"--config_path",
type=str,
required=True,
help="Training YAML config path (contains pretrained_path and lora config)",
)
parser.add_argument(
"--lora_ckpt",
type=str,
required=True,
help="LoRA checkpoint directory (contains lora_weights.ckpt with lora_A/lora_B only)",
)
parser.add_argument(
"--text",
type=str,
required=True,
help="Target text to synthesize",
)
parser.add_argument(
"--prompt_audio",
type=str,
default="",
help="Optional: reference audio path for voice cloning",
)
parser.add_argument(
"--prompt_text",
type=str,
default="",
help="Optional: transcript of reference audio",
)
parser.add_argument(
"--output",
type=str,
default="lora_test.wav",
help="Output wav file path",
)
parser.add_argument(
"--cfg_value",
type=float,
default=2.0,
help="CFG scale (default: 2.0)",
)
parser.add_argument(
"--inference_timesteps",
type=int,
default=10,
help="Diffusion inference steps (default: 10)",
)
parser.add_argument(
"--max_len",
type=int,
default=600,
help="Max generation steps",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Enable text normalization",
)
return parser.parse_args()
def main():
args = parse_args()
# 1. Load YAML config
cfg = load_yaml_config(args.config_path)
pretrained_path = cfg["pretrained_path"]
lora_cfg_dict = cfg.get("lora", {}) or {}
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
# 2. Check LoRA checkpoint
ckpt_dir = args.lora_ckpt
if not Path(ckpt_dir).exists():
raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
# 3. Load model with LoRA (no denoiser)
print(f"[1/2] Loading model with LoRA: {pretrained_path}")
print(f" LoRA weights: {ckpt_dir}")
model = VoxCPM.from_pretrained(
hf_model_id=pretrained_path,
load_denoiser=False,
optimize=True,
lora_config=lora_cfg,
lora_weights_path=ckpt_dir,
)
# 4. Synthesize audio
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
prompt_text = args.prompt_text if args.prompt_text else None
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[2/2] Starting synthesis tests...")
# === Test 1: With LoRA ===
print(f"\n [Test 1] Synthesize with LoRA...")
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_length=args.max_len,
normalize=args.normalize,
denoise=False,
)
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 2: Disable LoRA (via set_lora_enabled) ===
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
model.set_lora_enabled(False)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_length=args.max_len,
normalize=args.normalize,
denoise=False,
)
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 3: Re-enable LoRA ===
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
model.set_lora_enabled(True)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_length=args.max_len,
normalize=args.normalize,
denoise=False,
)
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 4: Unload LoRA (reset_lora_weights) ===
print(f"\n [Test 4] Unload LoRA (unload_lora)...")
model.unload_lora()
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_length=args.max_len,
normalize=args.normalize,
denoise=False,
)
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 5: Hot-reload LoRA (load_lora) ===
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
loaded, skipped = model.load_lora(str(ckpt_dir))
print(f" Reloaded {len(loaded)} parameters")
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_length=args.max_len,
normalize=args.normalize,
denoise=False,
)
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f"\n[Done] All tests completed!")
print(f" - with_lora: {lora_output}")
print(f" - lora_disabled: {disabled_output}")
print(f" - lora_reenabled: {reenabled_output}")
print(f" - lora_reset: {reset_output}")
print(f" - lora_reloaded: {reload_output}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,362 @@
#!/usr/bin/env python3
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "src"))
import contextlib
from typing import Dict, Optional
import argbind
import torch
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
try:
from safetensors.torch import save_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not available, will use pytorch format")
from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import (
Accelerator,
BatchProcessor,
TrainingTracker,
build_dataloader,
load_audio_text_datasets,
)
@argbind.bind(without_prefix=True)
def train(
pretrained_path: str,
train_manifest: str,
val_manifest: str = "",
sample_rate: int = 16_000,
batch_size: int = 1,
grad_accum_steps: int = 1,
num_workers: int = 2,
num_iters: int = 100_000,
log_interval: int = 100,
valid_interval: int = 1_000,
save_interval: int = 10_000,
learning_rate: float = 1e-4,
weight_decay: float = 1e-2,
warmup_steps: int = 1_000,
max_steps: int = 100_000,
max_batch_tokens: int = 0,
save_path: str = "checkpoints",
tensorboard: str = "",
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
lora: dict = None,
config_path: str = "",
):
_ = config_path
accelerator = Accelerator(amp=True)
save_dir = Path(save_path)
tb_dir = Path(tensorboard) if tensorboard else save_dir / "logs"
# Only create directories on rank 0 to avoid race conditions
if accelerator.rank == 0:
save_dir.mkdir(parents=True, exist_ok=True)
tb_dir.mkdir(parents=True, exist_ok=True)
accelerator.barrier() # Wait for directory creation
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
tokenizer = base_model.text_tokenizer
train_ds, val_ds = load_audio_text_datasets(
train_manifest=train_manifest,
val_manifest=val_manifest,
sample_rate=sample_rate,
)
def tokenize(batch):
text_list = batch["text"]
text_ids = [tokenizer(text) for text in text_list]
return {"text_ids": text_ids}
train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
if val_ds is not None:
val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
num_train_samples = len(train_ds)
# ------------------------------------------------------------------ #
# Optional: filter samples by estimated token count to avoid OOM
# Enabled when max_batch_tokens > 0:
# max_sample_len = max_batch_tokens // batch_size
# Samples exceeding this length will be dropped
# ------------------------------------------------------------------ #
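    # For example (numbers are illustrative, not defaults from this script): with
    # max_batch_tokens=8000 and batch_size=4, max_sample_len=2000, so any sample whose
    # estimated token length exceeds 2000 is dropped from the training set.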
if max_batch_tokens and max_batch_tokens > 0:
from voxcpm.training.data import compute_sample_lengths
audio_vae_fps = base_model.audio_vae.sample_rate / base_model.audio_vae.hop_length
est_lengths = compute_sample_lengths(
train_ds,
audio_vae_fps=audio_vae_fps,
patch_size=base_model.config.patch_size,
)
max_sample_len = max_batch_tokens // batch_size if batch_size > 0 else max(est_lengths)
keep_indices = [i for i, L in enumerate(est_lengths) if L <= max_sample_len]
if len(keep_indices) < len(train_ds) and accelerator.rank == 0:
tracker.print(
f"Filtering {len(train_ds) - len(keep_indices)} / {len(train_ds)} "
f"training samples longer than {max_sample_len} tokens "
f"(max_batch_tokens={max_batch_tokens})."
)
train_ds = train_ds.select(keep_indices)
train_loader = build_dataloader(
train_ds,
accelerator=accelerator,
batch_size=batch_size,
num_workers=num_workers,
drop_last=True,
)
val_loader = (
build_dataloader(
val_ds,
accelerator=accelerator,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
if val_ds is not None
else None
)
batch_processor = BatchProcessor(
config=base_model.config,
audio_vae=base_model.audio_vae,
dataset_cnt=dataset_cnt,
device=accelerator.device,
)
del base_model.audio_vae
model = accelerator.prepare_model(base_model)
unwrapped_model = accelerator.unwrap(model)
unwrapped_model.train()
# Only print param info on rank 0 to avoid cluttered output
if accelerator.rank == 0:
for name, param in model.named_parameters():
print(name, param.requires_grad)
optimizer = AdamW(
(p for p in model.parameters() if p.requires_grad),
lr=learning_rate,
weight_decay=weight_decay,
)
# Cosine + warmup scheduler from transformers:
# - num_warmup_steps: warmup steps
# - num_training_steps: total training steps (outer step count)
total_training_steps = max_steps if max_steps > 0 else num_iters
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_training_steps,
)
# Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
grad_accum_steps = max(int(grad_accum_steps), 1)
data_epoch = 0
train_iter = iter(train_loader)
def get_next_batch():
"""Get next batch, handles epoch boundary and DistributedSampler."""
nonlocal train_iter, data_epoch
try:
return next(train_iter)
except StopIteration:
data_epoch += 1
# Key: set DistributedSampler epoch to ensure different data order each epoch
sampler = getattr(train_loader, 'sampler', None)
if hasattr(sampler, 'set_epoch'):
sampler.set_epoch(data_epoch)
train_iter = iter(train_loader)
return next(train_iter)
with tracker.live():
for step in range(num_iters):
tracker.step = step
optimizer.zero_grad(set_to_none=True)
# Gradient accumulation: accumulate gradients over micro-batches before optimizer step
loss_dict = {}
for micro_step in range(grad_accum_steps):
batch = get_next_batch()
processed = batch_processor(batch)
# Only sync gradients on the last micro-batch
# Use no_sync() for intermediate steps to reduce communication overhead
is_last_micro_step = (micro_step == grad_accum_steps - 1)
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
with sync_context:
with accelerator.autocast(dtype=torch.bfloat16):
outputs = model(
processed["text_tokens"],
processed["text_mask"],
processed["audio_feats"],
processed["audio_mask"],
processed["loss_mask"],
processed["position_ids"],
processed["labels"],
progress=step / max(1, num_iters),
)
total_loss = 0.0
for key, value in outputs.items():
if key.startswith("loss/"):
weight = lambdas.get(key, 1.0)
loss_value = value * weight / grad_accum_steps
total_loss = total_loss + loss_value
# Record raw loss from last micro-batch for logging
loss_dict[key] = value.detach()
# Accumulate gradients (normalized by grad_accum_steps)
accelerator.backward(total_loss)
# After all micro-batches, do unscale / grad_norm / step
scaler = getattr(accelerator, "scaler", None)
if scaler is not None:
scaler.unscale_(optimizer)
# Use large max_norm to only compute grad_norm without actual clipping
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)
accelerator.step(optimizer)
accelerator.update()
scheduler.step()
if step % log_interval == 0:
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
# Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
loss_values["epoch"] = float(epoch)
loss_values["grad_norm"] = float(grad_norm)
tracker.log_metrics(loss_values, split="train")
if val_loader is not None and step % valid_interval == 0 and step != 0:
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)
if step % save_interval == 0 and accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path)
if accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path)
if writer:
writer.close()
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
model.eval()
losses = []
num_batches = 0
max_val_batches = 10
with torch.no_grad():
for batch in val_loader:
if num_batches >= max_val_batches:
break
processed = batch_processor(batch)
with accelerator.autocast(dtype=torch.bfloat16):
outputs = model(
processed["text_tokens"],
processed["text_mask"],
processed["audio_feats"],
processed["audio_mask"],
processed["loss_mask"],
processed["position_ids"],
processed["labels"],
progress=0.0,
sample_generate=False,
)
total = 0.0
for key, value in outputs.items():
if key.startswith("loss/"):
total += lambdas.get(key, 1.0) * value
losses.append(total.detach())
num_batches += 1
if losses:
mean_loss = torch.stack(losses).mean()
# All-reduce validation loss across processes for global average
accelerator.all_reduce(mean_loss)
tracker.log_metrics({"loss": mean_loss.item()}, split="val")
model.train()
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None):
"""
Save checkpoint with different strategies for full finetune vs LoRA:
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
"""
import shutil
save_dir.mkdir(parents=True, exist_ok=True)
tag = "latest" if step == 0 else f"step_{step:07d}"
folder = save_dir / tag
folder.mkdir(parents=True, exist_ok=True)
unwrapped = model.module if hasattr(model, "module") else model
full_state = unwrapped.state_dict()
lora_cfg = unwrapped.lora_config
if lora_cfg is not None:
# LoRA finetune: save only lora_A/lora_B weights
state_dict = {k: v for k, v in full_state.items() if "lora_" in k}
if SAFETENSORS_AVAILABLE:
save_file(state_dict, folder / "lora_weights.safetensors")
else:
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
else:
# Full finetune: save non-vae weights to model.safetensors
state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")}
if SAFETENSORS_AVAILABLE:
save_file(state_dict, folder / "model.safetensors")
else:
torch.save({"state_dict": state_dict}, folder / "pytorch_model.bin")
# Copy config files from pretrained path
if pretrained_path:
pretrained_dir = Path(pretrained_path)
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
for fname in files_to_copy:
src = pretrained_dir / fname
if src.exists():
shutil.copy2(src, folder / fname)
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
if __name__ == "__main__":
from voxcpm.training.config import load_yaml_config
args = argbind.parse_args()
config_file = args.get("config_path")
# If YAML config provided, use YAML args to call train
if config_file:
yaml_args = load_yaml_config(config_file)
train(**yaml_args)
else:
# Otherwise use command line args (parsed by argbind)
with argbind.scope(args):
train()

View File

@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
"ZIPENHANCER_MODEL_PATH", None "ZIPENHANCER_MODEL_PATH", None
) )
# Build LoRA config if lora_path is provided
lora_config = None
lora_weights_path = getattr(args, "lora_path", None)
if lora_weights_path:
from voxcpm.model.voxcpm import LoRAConfig
lora_config = LoRAConfig(
enable_lm=getattr(args, "lora_enable_lm", True),
enable_dit=getattr(args, "lora_enable_dit", True),
enable_proj=getattr(args, "lora_enable_proj", False),
r=getattr(args, "lora_r", 32),
alpha=getattr(args, "lora_alpha", 16),
dropout=getattr(args, "lora_dropout", 0.0),
)
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
# Load from local path if provided # Load from local path if provided
if getattr(args, "model_path", None): if getattr(args, "model_path", None):
try: try:
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
voxcpm_model_path=args.model_path, voxcpm_model_path=args.model_path,
zipenhancer_model_path=zipenhancer_path, zipenhancer_model_path=zipenhancer_path,
enable_denoiser=not getattr(args, "no_denoiser", False), enable_denoiser=not getattr(args, "no_denoiser", False),
lora_config=lora_config,
lora_weights_path=lora_weights_path,
) )
print("Model loaded (local).") print("Model loaded (local).")
return model return model
@@ -69,11 +87,13 @@ def load_model(args) -> VoxCPM:
# Otherwise, try from_pretrained (Hub); exit on failure # Otherwise, try from_pretrained (Hub); exit on failure
try: try:
model = VoxCPM.from_pretrained( model = VoxCPM.from_pretrained(
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM-0.5B"), hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM1.5"),
load_denoiser=not getattr(args, "no_denoiser", False), load_denoiser=not getattr(args, "no_denoiser", False),
zipenhancer_model_id=zipenhancer_path, zipenhancer_model_id=zipenhancer_path,
cache_dir=getattr(args, "cache_dir", None), cache_dir=getattr(args, "cache_dir", None),
local_files_only=getattr(args, "local_files_only", False), local_files_only=getattr(args, "local_files_only", False),
lora_config=lora_config,
lora_weights_path=lora_weights_path,
) )
print("Model loaded (from_pretrained).") print("Model loaded (from_pretrained).")
return model return model
@@ -120,11 +140,11 @@ def cmd_clone(args):
) )
# Save audio # Save audio
sf.write(str(output_path), audio_array, 16000) sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}") print(f"Saved audio to: {output_path}")
# Stats # Stats
duration = len(audio_array) / 16000 duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s") print(f"Duration: {duration:.2f}s")
@@ -152,11 +172,11 @@ def cmd_synthesize(args):
) )
# Save audio # Save audio
sf.write(str(output_path), audio_array, 16000) sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}") print(f"Saved audio to: {output_path}")
# Stats # Stats
duration = len(audio_array) / 16000 duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s") print(f"Duration: {duration:.2f}s")
@@ -198,9 +218,9 @@ def cmd_batch(args):
denoise=args.denoise and prompt_audio_path is not None denoise=args.denoise and prompt_audio_path is not None
) )
output_file = output_dir / f"output_{i:03d}.wav" output_file = output_dir / f"output_{i:03d}.wav"
sf.write(str(output_file), audio_array, 16000) sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
duration = len(audio_array) / 16000 duration = len(audio_array) / model.tts_model.sample_rate
print(f" Saved: {output_file} ({duration:.2f}s)") print(f" Saved: {output_file} ({duration:.2f}s)")
success_count += 1 success_count += 1
@@ -240,6 +260,7 @@ Examples:
# Prompt audio (for voice cloning) # Prompt audio (for voice cloning)
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path") parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio") parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio")
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)") parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
# Generation parameters # Generation parameters
@@ -249,12 +270,21 @@ Examples:
# Model loading parameters # Model loading parameters
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)") parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM-0.5B", help="Hugging Face repo id (e.g., openbmb/VoxCPM-0.5B)") parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (e.g., openbmb/VoxCPM1.5 or openbmb/VoxCPM-0.5B)")
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads") parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)") parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading") parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)") parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
# LoRA parameters
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
return parser return parser
@@ -279,6 +309,12 @@ def main():
# If prompt audio+text provided → voice cloning # If prompt audio+text provided → voice cloning
if args.prompt_audio or args.prompt_text: if args.prompt_audio or args.prompt_text:
if not args.prompt_text and args.prompt_file:
assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."
with open(args.prompt_file, 'r', encoding='utf-8') as f:
args.prompt_text = f.read()
if not args.prompt_audio or not args.prompt_text: if not args.prompt_audio or not args.prompt_text:
print("Error: Voice cloning requires both --prompt-audio and --prompt-text") print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
sys.exit(1) sys.exit(1)

View File

@@ -1,17 +1,19 @@
import torch
import torchaudio
import os import os
import re
import tempfile import tempfile
import numpy as np
from typing import Generator, Optional
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel from .model.voxcpm import VoxCPMModel, LoRAConfig
from .utils.text_normalize import TextNormalizer
class VoxCPM: class VoxCPM:
def __init__(self, def __init__(self,
voxcpm_model_path : str, voxcpm_model_path : str,
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base", zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser : bool = True, enable_denoiser : bool = True,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
): ):
"""Initialize VoxCPM TTS pipeline. """Initialize VoxCPM TTS pipeline.
@@ -22,10 +24,32 @@ class VoxCPM:
zipenhancer_model_path: ModelScope acoustic noise suppression model zipenhancer_model_path: ModelScope acoustic noise suppression model
id or local path. If None, denoiser will not be initialized. id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline. enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
""" """
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}") print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
self.text_normalizer = TextNormalizer() # If lora_weights_path is provided but no lora_config, create a default one
if lora_weights_path is not None and lora_config is None:
lora_config = LoRAConfig(
enable_lm=True,
enable_dit=True,
enable_proj=False,
)
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
# Load LoRA weights if path is provided
if lora_weights_path is not None:
print(f"Loading LoRA weights from: {lora_weights_path}")
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
self.text_normalizer = None
if enable_denoiser and zipenhancer_model_path is not None: if enable_denoiser and zipenhancer_model_path is not None:
from .zipenhancer import ZipEnhancer from .zipenhancer import ZipEnhancer
self.denoiser = ZipEnhancer(zipenhancer_model_path) self.denoiser = ZipEnhancer(zipenhancer_model_path)
@@ -33,27 +57,41 @@ class VoxCPM:
self.denoiser = None self.denoiser = None
print("Warm up VoxCPMModel...") print("Warm up VoxCPMModel...")
self.tts_model.generate( self.tts_model.generate(
target_text="Hello, this is the first test sentence." target_text="Hello, this is the first test sentence.",
) max_len=10,
)
@classmethod @classmethod
def from_pretrained(cls, def from_pretrained(cls,
hf_model_id: str = "openbmb/VoxCPM-0.5B", hf_model_id: str = "openbmb/VoxCPM1.5",
load_denoiser: bool = True, load_denoiser: bool = True,
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base", zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None, cache_dir: str = None,
local_files_only: bool = False, local_files_only: bool = False,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
): ):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot. """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
Args: Args:
hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path. hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path.
load_denoiser: Whether to initialize the denoiser pipeline. load_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
zipenhancer_model_id: Denoiser model id or path for ModelScope zipenhancer_model_id: Denoiser model id or path for ModelScope
acoustic noise suppression. acoustic noise suppression.
cache_dir: Custom cache directory for the snapshot. cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt local_files_only: If True, only use local files and do not attempt
to download. to download.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created with
enable_lm=True and enable_dit=True.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
containing lora_weights.ckpt). If provided, LoRA weights will be loaded
after model initialization.
Kwargs:
Additional keyword arguments passed to the ``VoxCPM`` constructor.
Returns: Returns:
VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
@@ -82,21 +120,33 @@ class VoxCPM:
voxcpm_model_path=local_path, voxcpm_model_path=local_path,
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None, zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser, enable_denoiser=load_denoiser,
optimize=optimize,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
**kwargs,
) )
def generate(self, def generate(self, *args, **kwargs) -> np.ndarray:
return next(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs)
def _generate(self,
text : str, text : str,
prompt_wav_path : str = None, prompt_wav_path : str = None,
prompt_text : str = None, prompt_text : str = None,
cfg_value : float = 2.0, cfg_value : float = 2.0,
inference_timesteps : int = 10, inference_timesteps : int = 10,
max_length : int = 4096, min_len : int = 2,
normalize : bool = True, max_len : int = 4096,
denoise : bool = True, normalize : bool = False,
denoise : bool = False,
retry_badcase : bool = True, retry_badcase : bool = True,
retry_badcase_max_times : int = 3, retry_badcase_max_times : int = 3,
retry_badcase_ratio_threshold : float = 6.0, retry_badcase_ratio_threshold : float = 6.0,
): streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
"""Synthesize speech for the given text and return a single waveform. """Synthesize speech for the given text and return a single waveform.
This method optionally builds and reuses a prompt cache. If an external This method optionally builds and reuses a prompt cache. If an external
@@ -111,20 +161,32 @@ class VoxCPM:
prompt_text: Text content corresponding to the prompt audio. prompt_text: Text content corresponding to the prompt audio.
cfg_value: Guidance scale for the generation model. cfg_value: Guidance scale for the generation model.
inference_timesteps: Number of inference steps. inference_timesteps: Number of inference steps.
max_length: Maximum token length during generation. max_len: Maximum token length during generation.
normalize: Whether to run text normalization before generation. normalize: Whether to run text normalization before generation.
denoise: Whether to denoise the prompt audio if a denoiser is denoise: Whether to denoise the prompt audio if a denoiser is
available. available.
retry_badcase: Whether to retry badcase. retry_badcase: Whether to retry badcase.
retry_badcase_max_times: Maximum number of times to retry badcase. retry_badcase_max_times: Maximum number of times to retry badcase.
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio. retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
streaming: Whether to return a generator of audio chunks.
Returns: Returns:
numpy.ndarray: 1D waveform array (float32) on CPU. Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
Yields audio chunks for each generation step if ``streaming=True``,
otherwise yields a single array containing the final audio.
""" """
texts = text.split("\n") if not isinstance(text, str) or not text.strip():
texts = [t.strip() for t in texts if t.strip()] raise ValueError("target text must be a non-empty string")
final_wav = []
temp_prompt_wav_path = None if prompt_wav_path is not None:
if not os.path.exists(prompt_wav_path):
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
if (prompt_wav_path is None) != (prompt_text is None):
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
text = text.replace("\n", " ")
text = re.sub(r'\s+', ' ', text)
temp_prompt_wav_path = None
try: try:
if prompt_wav_path is not None and prompt_text is not None: if prompt_wav_path is not None and prompt_text is not None:
@@ -140,36 +202,79 @@ class VoxCPM:
else: else:
fixed_prompt_cache = None # will be built from the first inference fixed_prompt_cache = None # will be built from the first inference
for sub_text in texts: if normalize:
if sub_text.strip() == "": if self.text_normalizer is None:
continue from .utils.text_normalize import TextNormalizer
print("sub_text:", sub_text) self.text_normalizer = TextNormalizer()
if normalize: text = self.text_normalizer.normalize(text)
sub_text = self.text_normalizer.normalize(sub_text)
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache( generate_result = self.tts_model._generate_with_prompt_cache(
target_text=sub_text, target_text=text,
prompt_cache=fixed_prompt_cache, prompt_cache=fixed_prompt_cache,
min_len=2, min_len=min_len,
max_len=max_length, max_len=max_len,
inference_timesteps=inference_timesteps, inference_timesteps=inference_timesteps,
cfg_value=cfg_value, cfg_value=cfg_value,
retry_badcase=retry_badcase, retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times, retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
) streaming=streaming,
if fixed_prompt_cache is None: )
fixed_prompt_cache = self.tts_model.merge_prompt_cache(
original_cache=None,
new_text_token=target_text_token,
new_audio_feat=generated_audio_feat
)
final_wav.append(wav)
return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy() for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy()
finally: finally:
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path): if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
try: try:
os.unlink(temp_prompt_wav_path) os.unlink(temp_prompt_wav_path)
except OSError: except OSError:
pass pass
# ------------------------------------------------------------------ #
# LoRA Interface (delegated to VoxCPMModel)
# ------------------------------------------------------------------ #
def load_lora(self, lora_weights_path: str) -> tuple:
"""Load LoRA weights from a checkpoint file.
Args:
lora_weights_path: Path to LoRA weights (.pth file or directory
containing lora_weights.ckpt).
Returns:
tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
Raises:
RuntimeError: If model was not initialized with LoRA config.
"""
if self.tts_model.lora_config is None:
raise RuntimeError(
"Cannot load LoRA weights: model was not initialized with LoRA config. "
"Please reinitialize with lora_config or lora_weights_path parameter."
)
return self.tts_model.load_lora_weights(lora_weights_path)
def unload_lora(self):
"""Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
self.tts_model.reset_lora_weights()
def set_lora_enabled(self, enabled: bool):
"""Enable or disable LoRA layers without unloading weights.
Args:
enabled: If True, LoRA layers are active; if False, only base model is used.
"""
self.tts_model.set_lora_enabled(enabled)
def get_lora_state_dict(self) -> dict:
"""Get current LoRA parameters state dict.
Returns:
dict: State dict containing all LoRA parameters (lora_A, lora_B).
"""
return self.tts_model.get_lora_state_dict()
@property
def lora_enabled(self) -> bool:
"""Check if LoRA is currently configured."""
return self.tts_model.lora_config is not None
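For example, a minimal usage sketch of the LoRA interface above (repo id, checkpoint path, and texts are illustrative):

from voxcpm.core import VoxCPM

# Load the base model and attach LoRA weights (a default LoRAConfig is auto-created).
model = VoxCPM.from_pretrained(
    "openbmb/VoxCPM1.5",
    lora_weights_path="checkpoints/step_0002000",
)
audio = model.generate(text="Hello with LoRA enabled.")         # LoRA layers active
model.set_lora_enabled(False)                                    # temporarily use only the base model
base_audio = model.generate(text="Hello from the base model.")
model.set_lora_enabled(True)                                     # re-enable without reloading weights
model.unload_lora()                                              # reset LoRA weights to their initial state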

View File

@@ -19,18 +19,27 @@ limitations under the License.
""" """
import os import os
from typing import Dict, Optional, Tuple, Union from typing import Tuple, Union, Generator, List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import torchaudio import torchaudio
import warnings
from einops import rearrange from einops import rearrange
from pydantic import BaseModel from pydantic import BaseModel
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
from tqdm import tqdm from tqdm import tqdm
from transformers import LlamaTokenizerFast from transformers import LlamaTokenizerFast
from ..modules.audiovae import AudioVAE from ..modules.audiovae import AudioVAE, AudioVAEConfig
from ..modules.layers import ScalarQuantizationLayer from ..modules.layers import ScalarQuantizationLayer
from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
@@ -65,10 +74,31 @@ class VoxCPMConfig(BaseModel):
encoder_config: VoxCPMEncoderConfig encoder_config: VoxCPMEncoderConfig
dit_config: VoxCPMDitConfig dit_config: VoxCPMDitConfig
audio_vae_config: Optional[AudioVAEConfig] = None
max_length: int = 4096 max_length: int = 4096
device: str = "cuda" device: str = "cuda"
dtype: str = "bfloat16" dtype: str = "bfloat16"
dit_mean_mode: bool = False
class LoRAConfig(BaseModel):
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
r: int = 8
alpha: int = 16
dropout: float = 0.0
# Target linear layer names for LM & DiT (matched by attribute name)
target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
# Projection layer attribute names to find on VoxCPMModel
target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
VoxCPMConfig.model_rebuild()
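# A minimal sketch (field values illustrative) of enabling LoRA on the LM and DiT attention
# projections when loading a local checkpoint, as the training script above does:
#
#   lora_cfg = LoRAConfig(enable_lm=True, enable_dit=True, r=32, alpha=16, dropout=0.0)
#   model = VoxCPMModel.from_local(pretrained_path, optimize=False, lora_config=lora_cfg)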
class VoxCPMModel(nn.Module): class VoxCPMModel(nn.Module):
@@ -77,18 +107,24 @@ class VoxCPMModel(nn.Module):
config: VoxCPMConfig, config: VoxCPMConfig,
tokenizer: LlamaTokenizerFast, tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE, audio_vae: AudioVAE,
lora_config: LoRAConfig = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.lora_config = lora_config
self.feat_dim = config.feat_dim self.feat_dim = config.feat_dim
self.patch_size = config.patch_size self.patch_size = config.patch_size
self.device = config.device self.device = config.device
if not torch.cuda.is_available(): if not torch.cuda.is_available():
self.device = "cpu" if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
# Text-Semantic LM # Text-Semantic LM
self.base_lm = MiniCPMModel(config.lm_config) self.base_lm = MiniCPMModel(config.lm_config)
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype)) self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer) self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
self.audio_start_token = 101 self.audio_start_token = 101
@@ -99,7 +135,7 @@ class VoxCPMModel(nn.Module):
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
residual_lm_config.vocab_size = 0 residual_lm_config.vocab_size = 0
self.residual_lm = MiniCPMModel(residual_lm_config) self.residual_lm = MiniCPMModel(residual_lm_config)
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype)) self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
# Local Encoder # Local Encoder
encoder_config = config.lm_config.model_copy(deep=True) encoder_config = config.lm_config.model_copy(deep=True)
@@ -123,6 +159,7 @@ class VoxCPMModel(nn.Module):
in_channels=config.feat_dim, in_channels=config.feat_dim,
cfm_params=config.dit_config.cfm_config, cfm_params=config.dit_config.cfm_config,
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim), estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
mean_mode=config.dit_mean_mode,
) )
# Projection layers # Projection layers
@@ -131,7 +168,7 @@ class VoxCPMModel(nn.Module):
config.lm_config.hidden_size, config.lm_config.hidden_size,
config.scalar_quantization_latent_dim, config.scalar_quantization_latent_dim,
config.scalar_quantization_scale config.scalar_quantization_scale
) )
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size) self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim) self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim) self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
@@ -140,29 +177,168 @@ class VoxCPMModel(nn.Module):
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size) self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
self.stop_actn = nn.SiLU() self.stop_actn = nn.SiLU()
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False) self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
self.stop_loss = nn.CrossEntropyLoss(reduction="none")
# Audio VAE # Audio VAE
self.audio_vae = audio_vae self.audio_vae = audio_vae
self.chunk_size = audio_vae.chunk_size self.chunk_size = audio_vae.chunk_size
self.sample_rate = audio_vae.sample_rate self.sample_rate = audio_vae.sample_rate
if self.lora_config is not None:
def optimize(self): self._apply_lora()
if self.device == "cuda":
def _apply_lora(self):
"""注入 LoRA 到 LM / DiT / 投影层"""
cfg = self.lora_config
lora_kwargs = dict(r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)
# LM: base_lm + residual_lm
if cfg.enable_lm:
for lm in [self.base_lm, self.residual_lm]:
apply_lora_to_named_linear_modules(
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
)
# DiT: feat_decoder.estimator
if cfg.enable_dit:
apply_lora_to_named_linear_modules(
self.feat_decoder.estimator, target_submodule_names=cfg.target_modules_dit, **lora_kwargs
)
# Projection layers
if cfg.enable_proj:
from ..modules.layers.lora import LoRALinear
for attr_name in cfg.target_proj_modules:
module = getattr(self, attr_name, None)
if isinstance(module, nn.Linear):
setattr(self, attr_name, LoRALinear(base=module, **lora_kwargs))
def optimize(self, disable: bool = False):
if disable:
return self
try:
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
import triton
except ImportError:
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True) self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True) self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True) self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True) self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
else: except Exception as e:
self.base_lm.forward_step = self.base_lm.forward_step print(f"Warning: torch.compile disabled - {e}")
self.residual_lm.forward_step = self.residual_lm.forward_step
self.feat_encoder_step = self.feat_encoder
self.feat_decoder.estimator = self.feat_decoder.estimator
return self return self
def forward(
self,
text_tokens: torch.Tensor,
text_mask: torch.Tensor,
audio_feats: torch.Tensor,
audio_mask: torch.Tensor,
loss_mask: torch.Tensor,
position_ids: torch.Tensor,
labels: torch.Tensor,
*,
progress: float = 0.0,
sample_generate: bool = False,
):
del position_ids # not used yet
text_tokens = text_tokens.to(self.device, dtype=torch.long)
text_mask = text_mask.to(self.device, dtype=self._dtype())
audio_feats = audio_feats.to(self.device, dtype=self._dtype())
audio_mask = audio_mask.to(self.device, dtype=self._dtype())
loss_mask = loss_mask.to(self.device, dtype=self._dtype())
labels = labels.to(self.device, dtype=torch.long)
B, T, P, D = audio_feats.shape
feat_embed = self.feat_encoder(audio_feats)
feat_embed = self.enc_to_lm_proj(feat_embed)
scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
if not getattr(self.config.lm_config, "use_mup", False):
scale_emb = 1.0
text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
enc_outputs = enc_outputs.to(self._dtype())
enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)
residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
residual_outputs = residual_outputs.to(self._dtype())
residual_hidden = torch.cat(
(torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
dim=1,
)
dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")
# Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
target_dtype = self._dtype()
feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
feat_cond = torch.cat(
(torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
dim=1,
)
feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")
loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)
diff_loss = self.feat_decoder.compute_loss(
feat_gt.transpose(1, 2).contiguous(),
dit_hidden,
cond=feat_cond.transpose(1, 2).contiguous(),
tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
progress=progress,
)
stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
denom = torch.clamp(loss_mask.sum(), min=1.0)
stop_loss = (stop_losses * loss_mask).sum() / denom
feat_pred = None
if sample_generate:
feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
feat_pred_seq = self.feat_decoder(
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10,
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)
return {
"loss/diff": diff_loss,
"loss/stop": stop_loss,
"feat_gt": feat_gt_tensor,
"feat_pred": feat_pred,
}
def _dtype(self):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@torch.inference_mode() @torch.inference_mode()
def generate( def _generate(
self, self,
target_text: str, target_text: str,
prompt_text: str = "", prompt_text: str = "",
@@ -174,7 +350,11 @@ class VoxCPMModel(nn.Module):
retry_badcase: bool = False, retry_badcase: bool = False,
retry_badcase_max_times: int = 3, retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection) retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
): streaming: bool = False,
) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming:
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
retry_badcase = False
if len(prompt_wav_path) == 0: if len(prompt_wav_path) == 0:
text = target_text text = target_text
text_token = torch.LongTensor(self.text_tokenizer(text)) text_token = torch.LongTensor(self.text_tokenizer(text))
@@ -213,25 +393,25 @@ class VoxCPMModel(nn.Module):
audio, sr = torchaudio.load(prompt_wav_path) audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1: if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True) audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate: if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate) audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
patch_len = self.patch_size * self.chunk_size patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0: if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len)) # Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
padding_size = patch_len - audio.size(1) % patch_len
audio = torch.nn.functional.pad(audio, (padding_size, 0))
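# Illustration with hypothetical sizes: if patch_size=2 and chunk_size=320, then
# patch_len=640; a 1000-sample prompt gets 280 leading samples of padding so that
# its length (1280) is an exact multiple of patch_len.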
# (B, D, T) # (B, D, T)
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu() audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
audio_feat = audio_feat.view( audio_feat = audio_feat.view(
self.audio_vae.latent_dim, self.audio_vae.latent_dim,
-1, -1,
self.patch_size, self.patch_size,
).permute(1, 2, 0) ).permute(1, 2, 0)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
audio_length = audio_feat.size(0) audio_length = audio_feat.size(0)
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device) text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
text_token = torch.cat([text_token, text_pad_token]) text_token = torch.cat([text_token, text_pad_token])
@@ -250,33 +430,46 @@ class VoxCPMModel(nn.Module):
text_token = text_token.unsqueeze(0).to(self.device) text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device) text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16) audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
audio_mask = audio_mask.unsqueeze(0).to(self.device) audio_mask = audio_mask.unsqueeze(0).to(self.device)
target_text_length = len(self.text_tokenizer(target_text)) target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0 retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times: while retry_badcase_times < retry_badcase_max_times:
latent_pred, pred_audio_feat = self.inference( inference_result = self._inference(
text_token, text_token,
text_mask, text_mask,
audio_feat, audio_feat,
audio_mask, audio_mask,
min_len=min_len, min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len, max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid overly long audio
inference_timesteps=inference_timesteps, inference_timesteps=inference_timesteps,
cfg_value=cfg_value, cfg_value=cfg_value,
streaming=streaming,
) )
if retry_badcase: if streaming:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: patch_len = self.patch_size * self.chunk_size
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...") for latent_pred, _ in inference_result:
retry_badcase_times += 1 decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
continue decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
else: yield decode_audio
break
else:
break break
return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
else:
break
else:
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
yield decode_audio
@torch.inference_mode() @torch.inference_mode()
def build_prompt_cache( def build_prompt_cache(
@@ -292,13 +485,11 @@ class VoxCPMModel(nn.Module):
prompt_wav_path: prompt audio path (required) prompt_wav_path: prompt audio path (required)
Returns: Returns:
prompt_cache: dict with text tokens and audio features prompt_cache: dict with prompt_text (raw text) and audio features.
Text tokenization will be done during generation for consistency.
""" """
if not prompt_text or not prompt_wav_path: if not prompt_text or not prompt_wav_path:
raise ValueError("prompt_text and prompt_wav_path are required") raise ValueError("prompt_text and prompt_wav_path are required")
# build text tokens
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
# load audio # load audio
audio, sr = torchaudio.load(prompt_wav_path) audio, sr = torchaudio.load(prompt_wav_path)
@@ -311,20 +502,21 @@ class VoxCPMModel(nn.Module):
patch_len = self.patch_size * self.chunk_size patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0: if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len)) # Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
padding_size = patch_len - audio.size(1) % patch_len
audio = torch.nn.functional.pad(audio, (padding_size, 0))
# extract audio features # extract audio features
audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu() audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
audio_feat = audio_feat.view( audio_feat = audio_feat.view(
self.audio_vae.latent_dim, self.audio_vae.latent_dim,
-1, -1,
self.patch_size, self.patch_size,
).permute(1, 2, 0) # (D, T, P) ).permute(1, 2, 0) # (D, T, P)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token # build prompt cache - only save raw text and audio features
# build prompt cache
prompt_cache = { prompt_cache = {
"text_token": text_token, "prompt_text": prompt_text,
"audio_feat": audio_feat, "audio_feat": audio_feat,
} }
@@ -334,7 +526,7 @@ class VoxCPMModel(nn.Module):
def merge_prompt_cache( def merge_prompt_cache(
self, self,
original_cache: dict, original_cache: dict,
new_text_token: torch.Tensor, new_text: str,
new_audio_feat: torch.Tensor, new_audio_feat: torch.Tensor,
): ):
""" """
@@ -342,32 +534,44 @@ class VoxCPMModel(nn.Module):
Args: Args:
original_cache: original prompt cache original_cache: original prompt cache
new_text_token: newly generated text tokens new_text: newly generated text
new_audio_feat: newly generated audio features new_audio_feat: newly generated audio features
Returns: Returns:
merged_cache: merged cache merged_cache: merged cache with prompt_text and audio_feat
""" """
if original_cache is None: if original_cache is None:
return { return {
"text_token": new_text_token, "prompt_text": new_text,
"audio_feat": new_audio_feat, "audio_feat": new_audio_feat,
} }
original_text_token = original_cache["text_token"] original_prompt_text = original_cache["prompt_text"]
original_audio_feat = original_cache["audio_feat"] original_audio_feat = original_cache["audio_feat"]
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0) # Merge text by concatenation
merged_prompt_text = original_prompt_text + new_text
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0) merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
# build new cache # build new cache
merged_cache = { merged_cache = {
"text_token": merged_text_token, "prompt_text": merged_prompt_text,
"audio_feat": merged_audio_feat, "audio_feat": merged_audio_feat,
} }
return merged_cache return merged_cache
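Since the cache now carries raw `prompt_text` rather than pre-tokenized ids, tokenization happens once over the concatenated prompt-plus-target string at generation time. A minimal sketch of the intended flow (illustrative only; `model` stands for an already loaded `VoxCPMModel`, and the paths and texts are placeholders):

```python
# Hypothetical round trip of the prompt-cache API; paths and texts are placeholders.
cache = model.build_prompt_cache(
    prompt_text="A short reference utterance.",
    prompt_wav_path="prompt.wav",
)

audio, new_text_token, new_audio_feat = model.generate_with_prompt_cache(
    target_text="Hello from VoxCPM.",
    prompt_cache=cache,
)

# Fold the new text and audio features back into the cache so later calls
# keep the accumulated context.
cache = model.merge_prompt_cache(cache, "Hello from VoxCPM.", new_audio_feat)
```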
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
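The streaming wrapper returns the generator unchanged, so each yielded item already contains a decoded audio chunk that can be played or buffered incrementally. A consumption sketch under the same assumptions as above (`model` and `cache` are placeholders):

```python
import torch

chunks = []
for audio_chunk, text_tokens, feats_so_far in model.generate_with_prompt_cache_streaming(
    target_text="Streaming synthesis example.",
    prompt_cache=cache,
):
    # Each chunk holds the last patch_size * chunk_size samples of newly decoded audio.
    chunks.append(audio_chunk)

waveform = torch.cat(chunks, dim=-1)  # concatenate chunks along the time axis
```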
@torch.inference_mode() @torch.inference_mode()
def generate_with_prompt_cache( def _generate_with_prompt_cache(
self, self,
target_text: str, target_text: str,
prompt_cache: dict, prompt_cache: dict,
@@ -378,7 +582,8 @@ class VoxCPMModel(nn.Module):
retry_badcase: bool = False, retry_badcase: bool = False,
retry_badcase_max_times: int = 3, retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, retry_badcase_ratio_threshold: float = 6.0,
): streaming: bool = False,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
""" """
Generate audio using pre-built prompt cache. Generate audio using pre-built prompt cache.
@@ -392,20 +597,27 @@ class VoxCPMModel(nn.Module):
retry_badcase: Whether to retry on bad cases retry_badcase: Whether to retry on bad cases
retry_badcase_max_times: Maximum retry attempts retry_badcase_max_times: Maximum retry attempts
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
streaming: Whether to return a generator of audio chunks
Returns: Returns:
tuple: (decoded audio tensor, new text tokens, new audio features) Generator of Tuple containing:
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
- Tensor of new text tokens
- New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
""" """
if retry_badcase and streaming:
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
retry_badcase = False
# get prompt from cache # get prompt from cache
if prompt_cache is None: if prompt_cache is None:
prompt_text_token = torch.empty(0, dtype=torch.int32)
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32) prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
text = target_text
else: else:
prompt_text_token = prompt_cache["text_token"]
prompt_audio_feat = prompt_cache["audio_feat"] prompt_audio_feat = prompt_cache["audio_feat"]
# build target text tokens prompt_text = prompt_cache["prompt_text"]
target_text_token = torch.LongTensor(self.text_tokenizer(target_text)) text = prompt_text + target_text
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat( text_token = torch.cat(
[ [
text_token, text_token,
@@ -417,6 +629,8 @@ class VoxCPMModel(nn.Module):
], ],
dim=-1, dim=-1,
) )
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
audio_length = prompt_audio_feat.size(0) audio_length = prompt_audio_feat.size(0)
text_length = text_token.shape[0] text_length = text_token.shape[0]
@@ -433,42 +647,63 @@ class VoxCPMModel(nn.Module):
text_token = text_token.unsqueeze(0).to(self.device) text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device) text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16) audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
audio_mask = audio_mask.unsqueeze(0).to(self.device) audio_mask = audio_mask.unsqueeze(0).to(self.device)
# run inference # run inference
target_text_length = len(self.text_tokenizer(target_text)) target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0 retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times: while retry_badcase_times < retry_badcase_max_times:
latent_pred, pred_audio_feat = self.inference( inference_result = self._inference(
text_token, text_token,
text_mask, text_mask,
audio_feat, audio_feat,
audio_mask, audio_mask,
min_len=min_len, min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len, max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
inference_timesteps=inference_timesteps, inference_timesteps=inference_timesteps,
cfg_value=cfg_value, cfg_value=cfg_value,
streaming=streaming,
) )
if retry_badcase: if streaming:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold: patch_len = self.patch_size * self.chunk_size
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...") for latent_pred, pred_audio_feat in inference_result:
retry_badcase_times += 1 decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
continue decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
else:
break
else: else:
break break
else: if not streaming:
break decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
yield (
return ( decode_audio,
decode_audio, target_text_token,
target_text_token, pred_audio_feat
pred_audio_feat )
)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
@torch.inference_mode() @torch.inference_mode()
def inference( def _inference(
self, self,
text: torch.Tensor, text: torch.Tensor,
text_mask: torch.Tensor, text_mask: torch.Tensor,
@@ -478,7 +713,9 @@ class VoxCPMModel(nn.Module):
max_len: int = 2000, max_len: int = 2000,
inference_timesteps: int = 10, inference_timesteps: int = 10,
cfg_value: float = 2.0, cfg_value: float = 2.0,
) -> Tuple[torch.Tensor, torch.Tensor]: streaming: bool = False,
streaming_prefix_len: int = 3,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""Core inference method for audio generation. """Core inference method for audio generation.
This is the main inference loop that generates audio features This is the main inference loop that generates audio features
@@ -493,11 +730,12 @@ class VoxCPMModel(nn.Module):
max_len: Maximum generation length max_len: Maximum generation length
inference_timesteps: Number of diffusion steps inference_timesteps: Number of diffusion steps
cfg_value: Classifier-free guidance value cfg_value: Classifier-free guidance value
streaming: Whether to yield each step latent feature or just the final result
Returns: Returns:
Tuple containing: Generator of Tuple containing:
- Predicted latent features - Predicted latent feature at the current step if ``streaming=True``, else final latent features
- Predicted audio feature sequence - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
""" """
B, T, P, D = feat.shape B, T, P, D = feat.shape
@@ -549,11 +787,18 @@ class VoxCPMModel(nn.Module):
1, 2 1, 2
) # [b, p, d] ) # [b, p, d]
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.enc_to_lm_proj(curr_embed) curr_embed = self.enc_to_lm_proj(curr_embed)
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
prefix_feat_cond = pred_feat prefix_feat_cond = pred_feat
if streaming:
# yield the last streaming_prefix_len predicted latent patches to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item() stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1: if i > min_len and stop_flag == 1:
@@ -569,37 +814,140 @@ class VoxCPMModel(nn.Module):
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device) lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
).clone() ).clone()
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size) feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
feat_pred = feat_pred[..., 1:-1] # trick: remove the first and last token yield feat_pred, pred_feat_seq.squeeze(0).cpu()
return feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod @classmethod
def from_local(cls, path: str): def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read()) config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path) tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, 'audio_vae_config', None)
audio_vae = AudioVAE() audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
vae_state_dict = torch.load( vae_state_dict = torch.load(
os.path.join(path, "audiovae.pth"), os.path.join(path, "audiovae.pth"),
map_location="cpu", map_location="cpu",
weights_only=True, weights_only=True,
)["state_dict"] )["state_dict"]
model = cls(config, tokenizer, audio_vae, lora_config)
model = cls(config, tokenizer, audio_vae) if not training:
lm_dtype = get_dtype(config.dtype) lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype) model = model.to(lm_dtype)
else: # training mode
for name, param in model.named_parameters():
if "audio_vae" in name: # freeze VAE weights
param.requires_grad = False
continue
if lora_config is not None:
if "lora" not in name: # freeze non-LoRA weights
param.requires_grad = False
model.audio_vae = model.audio_vae.to(torch.float32) model.audio_vae = model.audio_vae.to(torch.float32)
model_state_dict = torch.load( # Try to load from safetensors first, fallback to pytorch_model.bin
os.path.join(path, "pytorch_model.bin"), safetensors_path = os.path.join(path, "model.safetensors")
map_location="cpu", pytorch_model_path = os.path.join(path, "pytorch_model.bin")
weights_only=True,
)["state_dict"] if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading model from safetensors: {safetensors_path}")
model_state_dict = load_file(safetensors_path)
elif os.path.exists(pytorch_model_path):
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
checkpoint = torch.load(
pytorch_model_path,
map_location="cpu",
weights_only=True,
)
model_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
)
for kw, val in vae_state_dict.items(): for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val model_state_dict[f"audio_vae.{kw}"] = val
model.load_state_dict(model_state_dict, strict=True)
return model.to(model.device).eval().optimize() # LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
model.load_state_dict(model_state_dict, strict=False)
if training:
return model
return model.to(model.device).eval().optimize(disable=not optimize)
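For orientation, the revised loader might be used as below; the checkpoint path is a placeholder, and the `LoRAConfig` construction is omitted because its fields are defined elsewhere in the project:

```python
# Inference: load in eval mode and (optionally) let optimize() try torch.compile.
model = VoxCPMModel.from_local("path/to/voxcpm_checkpoint", optimize=True)

# Fine-tuning: keep training mode and pass a LoRAConfig so the base weights are
# frozen and only the injected lora_A / lora_B parameters stay trainable.
# model = VoxCPMModel.from_local("path/to/voxcpm_checkpoint", training=True, lora_config=lora_cfg)
```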
# ------------------------------------------------------------------ #
# LoRA Weight Management
# ------------------------------------------------------------------ #
def _iter_lora_modules(self):
"""Iterate over all LoRA modules."""
from ..modules.layers.lora import LoRALinear
for module in self.modules():
if isinstance(module, LoRALinear):
yield module
def load_lora_weights(self, lora_path: str, device: str = None):
"""
Load LoRA weights from file, supports calling after torch.compile.
Uses named_parameters() to handle compile's _orig_mod wrapper.
Supports both safetensors and pytorch formats.
Args:
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
device: Target device, defaults to model's current device
Returns:
tuple: (loaded_keys, skipped_keys)
"""
from pathlib import Path
device = device or self.device
lora_path = Path(lora_path)
# Try safetensors first, then fallback to .ckpt
if lora_path.is_dir():
safetensors_file = lora_path / "lora_weights.safetensors"
ckpt_file = lora_path / "lora_weights.ckpt"
else:
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
# Load from safetensors if available
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
elif ckpt_file and ckpt_file.exists():
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
)
# Build param mapping (handle torch.compile's _orig_mod prefix)
model_params = dict(self.named_parameters())
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
loaded_keys, skipped_keys = [], []
for key, value in state_dict.items():
target_key = key if key in model_params else key_mapping.get(key)
if target_key:
model_params[target_key].data.copy_(value.to(device))
loaded_keys.append(key)
else:
skipped_keys.append(key)
return loaded_keys, skipped_keys
def set_lora_enabled(self, enabled: bool):
"""Enable/disable all LoRA layers."""
for module in self._iter_lora_modules():
module.set_enabled(enabled)
def reset_lora_weights(self):
"""Reset all LoRA weights (A: kaiming, B: zeros), effectively unloading LoRA."""
for module in self._iter_lora_modules():
module.reset_lora_parameters()
def get_lora_state_dict(self) -> dict:
"""Get all LoRA parameters (lora_A/lora_B)."""
return {name: param.data.clone()
for name, param in self.named_parameters()
if "lora_" in name}


@@ -1 +1 @@
from .audio_vae import AudioVAE from .audio_vae import AudioVAE, AudioVAEConfig


@@ -1,11 +1,12 @@
import math import math
from typing import List, Union from typing import List, Union, Optional
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.utils import weight_norm from torch.nn.utils import weight_norm
from pydantic import BaseModel
def WNConv1d(*args, **kwargs): def WNConv1d(*args, **kwargs):
@@ -266,6 +267,17 @@ class CausalDecoder(nn.Module):
return self.model(x) return self.model(x)
class AudioVAEConfig(BaseModel):
encoder_dim: int = 128
encoder_rates: List[int] = [2, 5, 8, 8]
latent_dim: int = 64
decoder_dim: int = 1536
decoder_rates: List[int] = [8, 8, 5, 2]
depthwise: bool = True
sample_rate: int = 16000
use_noise_block: bool = False
class AudioVAE(nn.Module): class AudioVAE(nn.Module):
""" """
Args: Args:
@@ -273,17 +285,23 @@ class AudioVAE(nn.Module):
def __init__( def __init__(
self, self,
encoder_dim: int = 128, config: Optional[AudioVAEConfig] = None,
encoder_rates: List[int] = [2, 5, 8, 8],
latent_dim: int = 64,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 5, 2],
depthwise: bool = True,
sample_rate: int = 16000,
use_noise_block: bool = False,
): ):
# Use the default configuration when no config is passed in
if config is None:
config = AudioVAEConfig()
super().__init__() super().__init__()
encoder_dim = config.encoder_dim
encoder_rates = config.encoder_rates
latent_dim = config.latent_dim
decoder_dim = config.decoder_dim
decoder_rates = config.decoder_rates
depthwise = config.depthwise
sample_rate = config.sample_rate
use_noise_block = config.use_noise_block
self.encoder_dim = encoder_dim self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim self.decoder_dim = decoder_dim
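With this refactor the VAE hyperparameters travel in one pydantic object, so a non-default variant can be described declaratively. A small sketch, assuming the package is importable as `voxcpm` (the values shown are just the defaults):

```python
from voxcpm.modules.audiovae import AudioVAE, AudioVAEConfig

cfg = AudioVAEConfig(latent_dim=64, sample_rate=16000)  # remaining fields keep their defaults
vae = AudioVAE(config=cfg)

vae_default = AudioVAE()  # passing no config falls back to AudioVAEConfig()
```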


@@ -0,0 +1,133 @@
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoRALinear(nn.Module):
"""
LoRA linear layer: holds weight/bias directly, keeping the same state_dict key structure as nn.Linear.
state_dict layout:
- weight: original weight (same as nn.Linear)
- bias: original bias (same as nn.Linear)
- lora_A: LoRA low-rank matrix A
- lora_B: LoRA low-rank matrix B
Benefit of this design: pretrained weights can be loaded without any key remapping.
"""
def __init__(
self,
base: nn.Linear,
r: int,
alpha: float = 1.0,
dropout: float = 0.0,
):
super().__init__()
assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."
self.in_features = base.in_features
self.out_features = base.out_features
self.r = r
self.alpha = alpha
self._base_scaling = alpha / r if r > 0 else 0.0
# Store scaling in a buffer so changing its value does not trigger torch.compile recompilation
# persistent=False keeps it out of the state_dict, avoiding a missing key at load time
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
# Hold weight and bias directly (transferred from the original Linear)
self.weight = base.weight
self.bias = base.bias  # may be None
# LoRA parameters
if r > 0:
self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
else:
self.register_parameter("lora_A", None)
self.register_parameter("lora_B", None)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Base linear computation
result = F.linear(x, self.weight, self.bias)
if self.r <= 0 or self.lora_A is None:
return result
# LoRA: result + dropout(x @ A^T @ B^T) * scaling
lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
return result + self.dropout(lora_out) * self.scaling
def reset_lora_parameters(self):
"""重置 LoRA 参数到初始状态"""
if self.r > 0 and self.lora_A is not None:
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def set_enabled(self, enabled: bool):
"""启用/禁用 LoRA通过 scaling 控制,兼容 torch.compile"""
# 使用 fill_ 原地修改 buffer 值,不会触发重编译
self.scaling.fill_(self._base_scaling if enabled else 0.0)
@property
def enabled(self) -> bool:
return self.scaling.item() != 0.0
def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
"""
Given a full name like 'layers.0.self_attn.q_proj', return the parent module (i.e. the module one level above q_proj).
"""
parts = name.split(".")
if len(parts) == 1:
return root
parent = root
for p in parts[:-1]:
if not hasattr(parent, p):
return None
parent = getattr(parent, p)
return parent
def apply_lora_to_named_linear_modules(
root: nn.Module,
*,
target_submodule_names: list[str],
r: int,
alpha: float,
dropout: float,
) -> None:
"""
Inject LoRA into every nn.Linear inside the given module (and its submodules) whose name ends with one of target_submodule_names.
For example, with target_submodule_names=["q_proj", "v_proj"],
every nn.Linear named *.q_proj / *.v_proj is replaced by a LoRALinear.
"""
for full_name, module in list(root.named_modules()):
if not isinstance(module, nn.Linear):
continue
short_name = full_name.split(".")[-1]
if short_name not in target_submodule_names:
continue
parent = _get_parent_module(root, full_name)
if parent is None:
continue
# Replace the original Linear with a LoRALinear
lora_layer = LoRALinear(
base=module,
r=r,
alpha=alpha,
dropout=dropout,
)
setattr(parent, short_name, lora_layer)
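To make the injection behaviour concrete, here is a self-contained toy example (the `TinyAttention` module is illustrative and not part of VoxCPM; the import path follows the layout used elsewhere in this diff):

```python
import torch.nn as nn
from voxcpm.modules.layers.lora import LoRALinear, apply_lora_to_named_linear_modules

class TinyAttention(nn.Module):
    def __init__(self, dim: int = 32):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

toy = TinyAttention()
apply_lora_to_named_linear_modules(
    toy,
    target_submodule_names=["q_proj", "v_proj"],  # out_proj stays a plain nn.Linear
    r=8,
    alpha=16.0,
    dropout=0.0,
)

assert isinstance(toy.q_proj, LoRALinear) and isinstance(toy.out_proj, nn.Linear)
# The wrapped layers keep the original weight/bias keys, so a pretrained
# state_dict still loads without renaming; lora_A / lora_B are simply extra keys.
```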


@@ -1,20 +1,29 @@
from typing import List, Tuple
import torch import torch
from typing import List import torch.nn.functional as F
from .local_dit import VoxCPMLocDiT from torch.func import jvp
import math
from pydantic import BaseModel from pydantic import BaseModel
from .local_dit import VoxCPMLocDiT
class CfmConfig(BaseModel): class CfmConfig(BaseModel):
sigma_min: float = 1e-06 sigma_min: float = 1e-6
solver: str = "euler" solver: str = "euler"
t_scheduler: str = "log-norm" t_scheduler: str = "log-norm"
training_cfg_rate: float = 0.1
inference_cfg_rate: float = 1.0
reg_loss_type: str = "l1"
ratio_r_neq_t_range: Tuple[float, float] = (0.25, 0.75)
noise_cond_prob_range: Tuple[float, float] = (0.0, 0.0)
noise_cond_scale: float = 0.0
class UnifiedCFM(torch.nn.Module): class UnifiedCFM(torch.nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels: int,
cfm_params: CfmConfig, cfm_params: CfmConfig,
estimator: VoxCPMLocDiT, estimator: VoxCPMLocDiT,
mean_mode: bool = False, mean_mode: bool = False,
@@ -23,12 +32,21 @@ class UnifiedCFM(torch.nn.Module):
self.solver = cfm_params.solver self.solver = cfm_params.solver
self.sigma_min = cfm_params.sigma_min self.sigma_min = cfm_params.sigma_min
self.t_scheduler = cfm_params.t_scheduler self.t_scheduler = cfm_params.t_scheduler
self.training_cfg_rate = cfm_params.training_cfg_rate
self.inference_cfg_rate = cfm_params.inference_cfg_rate
self.reg_loss_type = cfm_params.reg_loss_type
self.ratio_r_neq_t_range = cfm_params.ratio_r_neq_t_range
self.noise_cond_prob_range = cfm_params.noise_cond_prob_range
self.noise_cond_scale = cfm_params.noise_cond_scale
self.in_channels = in_channels self.in_channels = in_channels
self.mean_mode = mean_mode self.mean_mode = mean_mode
# Just change the architecture of the estimator here
self.estimator = estimator self.estimator = estimator
# ------------------------------------------------------------------ #
# Inference
# ------------------------------------------------------------------ #
@torch.inference_mode() @torch.inference_mode()
def forward( def forward(
self, self,
@@ -41,33 +59,25 @@ class UnifiedCFM(torch.nn.Module):
sway_sampling_coef: float = 1.0, sway_sampling_coef: float = 1.0,
use_cfg_zero_star: bool = True, use_cfg_zero_star: bool = True,
): ):
"""Forward diffusion b, _ = mu.shape
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
n_timesteps (int): number of diffusion steps
cond: Not used but kept for future purposes
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
b, c = mu.shape
t = patch_size t = patch_size
z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype) t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
# Sway sampling strategy
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span) t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star) return self.solve_euler(
x=z,
t_span=t_span,
mu=mu,
cond=cond,
cfg_value=cfg_value,
use_cfg_zero_star=use_cfg_zero_star,
)
def optimized_scale(self, positive_flat, negative_flat): def optimized_scale(self, positive_flat: torch.Tensor, negative_flat: torch.Tensor):
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
st_star = dot_product / squared_norm st_star = dot_product / squared_norm
return st_star return st_star
@@ -80,24 +90,13 @@ class UnifiedCFM(torch.nn.Module):
cfg_value: float = 1.0, cfg_value: float = 1.0,
use_cfg_zero_star: bool = True, use_cfg_zero_star: bool = True,
): ):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
cond: condition -- prefix prompt
cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
"""
t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1] t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
sol = [] sol = []
zero_init_steps = max(1, int(len(t_span) * 0.04)) zero_init_steps = max(1, int(len(t_span) * 0.04))
for step in range(1, len(t_span)): for step in range(1, len(t_span)):
if use_cfg_zero_star and step <= zero_init_steps: if use_cfg_zero_star and step <= zero_init_steps:
dphi_dt = 0. dphi_dt = torch.zeros_like(x)
else: else:
# Classifier-Free Guidance inference introduced in VoiceBox # Classifier-Free Guidance inference introduced in VoiceBox
b = x.size(0) b = x.size(0)
@@ -105,7 +104,7 @@ class UnifiedCFM(torch.nn.Module):
mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype) mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype) t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype) dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype) cond_in = torch.zeros([2 * b, self.in_channels, cond.size(2)], device=x.device, dtype=x.dtype)
x_in[:b], x_in[b:] = x, x x_in[:b], x_in[b:] = x, x
mu_in[:b] = mu mu_in[:b] = mu
t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0) t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
@@ -135,3 +134,98 @@ class UnifiedCFM(torch.nn.Module):
dt = t - t_span[step + 1] dt = t - t_span[step + 1]
return sol[-1] return sol[-1]
# ------------------------------------------------------------------ #
# Training loss
# ------------------------------------------------------------------ #
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
weights = 1.0 / ((losses + epsilon).pow(p))
if mask is not None:
weights = weights * mask
return weights.detach()
def sample_r_t(self, x: torch.Tensor, mu: float = -0.4, sigma: float = 1.0, ratio_r_neq_t: float = 0.0):
batch_size = x.shape[0]
if self.t_scheduler == "log-norm":
s_r = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
s_t = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
r = torch.sigmoid(s_r)
t = torch.sigmoid(s_t)
elif self.t_scheduler == "uniform":
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
else:
raise ValueError(f"Unsupported t_scheduler: {self.t_scheduler}")
mask = torch.rand(batch_size, device=x.device, dtype=x.dtype) < ratio_r_neq_t
r, t = torch.where(
mask,
torch.stack([torch.min(r, t), torch.max(r, t)], dim=0),
torch.stack([t, t], dim=0),
)
return r.squeeze(), t.squeeze()
def compute_loss(
self,
x1: torch.Tensor,
mu: torch.Tensor,
cond: torch.Tensor | None = None,
tgt_mask: torch.Tensor | None = None,
progress: float = 0.0,
):
b, _, _ = x1.shape
if self.training_cfg_rate > 0:
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
mu = mu * cfg_mask.view(-1, 1)
if cond is None:
cond = torch.zeros_like(x1)
noisy_mask = torch.rand(b, device=x1.device) > (
1.0
- (
self.noise_cond_prob_range[0]
+ progress * (self.noise_cond_prob_range[1] - self.noise_cond_prob_range[0])
)
)
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
ratio_r_neq_t = (
self.ratio_r_neq_t_range[0]
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
if self.mean_mode
else 0.0
)
r, t = self.sample_r_t(x1, ratio_r_neq_t=ratio_r_neq_t)
r_ = r.detach().clone()
t_ = t.detach().clone()
z = torch.randn_like(x1)
y = (1 - t_.view(-1, 1, 1)) * x1 + t_.view(-1, 1, 1) * z
v = z - x1
def model_fn(z_sample, r_sample, t_sample):
return self.estimator(z_sample, mu, t_sample, cond, dt=t_sample - r_sample)
if self.mean_mode:
v_r = torch.zeros_like(r)
v_t = torch.ones_like(t)
from torch.backends.cuda import sdp_kernel
with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
u_pred, dudt = jvp(model_fn, (y, r, t), (v, v_r, v_t))
u_tgt = v - (t_ - r_).view(-1, 1, 1) * dudt
else:
u_pred = model_fn(y, r, t)
u_tgt = v
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
if tgt_mask is not None:
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
loss = (weights * losses).sum() / torch.sum(tgt_mask)
else:
loss = losses.mean()
return loss
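The loss path expects latent patches in channel-first layout. A shape-oriented sketch, assuming `cfm` is an already constructed `UnifiedCFM` whose `in_channels` matches `D`; the conditioning width of 512 is illustrative, and the estimator construction is omitted because its signature lives elsewhere:

```python
import torch

B, D, T = 4, 64, 2           # batch, latent channels, patch length
x1 = torch.randn(B, D, T)    # clean latent target
mu = torch.randn(B, 512)     # conditioning vector from the LM hidden state (width is illustrative)
cond = torch.zeros_like(x1)  # prefix condition; zeros when no prompt latent is available
tgt_mask = torch.ones(B, 1, T)

loss = cfm.compute_loss(x1, mu, cond=cond, tgt_mask=tgt_mask, progress=0.5)  # progress = training fraction
loss.backward()
```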


@@ -153,7 +153,12 @@ class MiniCPMAttention(nn.Module):
cos, sin = position_emb cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# ref: https://github.com/pytorch/pytorch/issues/163597
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention( attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, query_states,
key_states, key_states,
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
# ref: https://github.com/pytorch/pytorch/issues/163597
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
query_states = query_states.contiguous()
key_cache = key_cache.contiguous()
value_cache = value_cache.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention( attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, query_states,
key_cache, key_cache,


@@ -0,0 +1,28 @@
"""
Training utilities for VoxCPM fine-tuning.
This package mirrors the training mechanics used in the minicpm-audio
tooling while relying solely on local audio-text datasets managed via
the HuggingFace ``datasets`` library.
"""
from .accelerator import Accelerator
from .tracker import TrainingTracker
from .data import (
load_audio_text_datasets,
HFVoxCPMDataset,
build_dataloader,
BatchProcessor,
)
from .state import TrainingState
__all__ = [
"Accelerator",
"TrainingTracker",
"HFVoxCPMDataset",
"BatchProcessor",
"TrainingState",
"load_audio_text_datasets",
"build_dataloader",
]


@@ -0,0 +1,166 @@
from __future__ import annotations
import contextlib
import os
import random
import typing
import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data
from torch.nn.parallel import DistributedDataParallel
class Accelerator:
"""
Simplified accelerator that mirrors the behaviour of the minicpm-audio
training utilities. It initializes a distributed process group when
``torchrun`` is used and exposes helpers for AMP, gradient scaling and
preparing models/dataloaders for DDP.
"""
def __init__(self, amp: bool = False, seed: int = 42):
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
if self.world_size > 1 and not dist.is_initialized():
dist.init_process_group("nccl", init_method="env://")
self.rank = dist.get_rank() if dist.is_initialized() else 0
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
self.amp = amp
# Set random seed to ensure model initialization consistency
self._set_seed(seed)
class DummyScaler:
def step(self, optimizer):
optimizer.step()
def scale(self, loss):
return loss
def unscale_(self, optimizer):
return optimizer
def update(self):
pass
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
self.device_ctx = (
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
self._ddp_model = None # For no_sync support
def _set_seed(self, seed: int):
"""Set random seed to ensure model initialization consistency across multiple GPUs"""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def __enter__(self):
if self.device_ctx is not None:
self.device_ctx.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.device_ctx is not None:
self.device_ctx.__exit__(exc_type, exc_value, traceback)
def barrier(self):
"""Synchronize all processes"""
if dist.is_initialized():
dist.barrier()
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.AVG):
"""All-reduce tensor across processes"""
if dist.is_initialized():
dist.all_reduce(tensor, op=op)
return tensor
# ------------------------------------------------------------------ #
# Model helpers
# ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs):
if hasattr(model, 'device'):  # make sure tensors created inside the model land on the correct device
model.device = self.device
model = model.to(self.device)
if self.world_size > 1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
self._ddp_model = model # Save DDP model reference for no_sync support
return model
@contextlib.contextmanager
def no_sync(self):
"""
Context manager to skip gradient synchronization during gradient accumulation.
Only used outside the last micro-batch.
"""
if self._ddp_model is not None:
with self._ddp_model.no_sync():
yield
else:
yield
@property
def device(self):
if torch.cuda.is_available():
return torch.device("cuda", self.local_rank)
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
# ------------------------------------------------------------------ #
# AMP helpers
# ------------------------------------------------------------------ #
def autocast(self, *args, **kwargs):
return torch.amp.autocast("cuda", enabled=self.amp, *args, **kwargs)
def backward(self, loss: torch.Tensor):
self.scaler.scale(loss).backward()
def step(self, optimizer: torch.optim.Optimizer):
self.scaler.step(optimizer)
def update(self):
self.scaler.update()
# ------------------------------------------------------------------ #
# Data helpers
# ------------------------------------------------------------------ #
def prepare_dataloader(
self,
dataset: typing.Iterable,
*,
batch_size: int,
num_workers: int = 0,
shuffle: bool = True,
collate_fn=None,
drop_last: bool = False,
) -> torch.utils.data.DataLoader:
if self.world_size > 1:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
)
shuffle = False
else:
sampler = None
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
drop_last=drop_last,
pin_memory=True,
)
@staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model
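A minimal sketch of how the accelerator is meant to be driven on a single CUDA machine; the linear model, optimizer and random tensors are stand-ins, not VoxCPM components:

```python
import torch

acc = Accelerator(amp=True, seed=42)
model = acc.prepare_model(torch.nn.Linear(8, 8))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loader = acc.prepare_dataloader(
    torch.utils.data.TensorDataset(torch.randn(32, 8), torch.randn(32, 8)),
    batch_size=4,
    num_workers=0,
)

for x, y in loader:
    x, y = x.to(acc.device), y.to(acc.device)
    with acc.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    acc.backward(loss)
    acc.step(optimizer)
    acc.update()
    optimizer.zero_grad()
```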


@@ -0,0 +1,40 @@
from __future__ import annotations
import argbind
import yaml
from pathlib import Path
from typing import Dict, Any
def load_yaml_config(path: str | Path) -> Dict[str, Any]:
"""
Load a YAML configuration file into a dictionary suitable for argbind.
"""
path = Path(path)
with path.open("r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Configuration file {path} must contain a top-level mapping.")
return data
def parse_args_with_config(config_path: str | Path | None = None):
"""
Helper to unify CLI arguments and YAML configuration.
Usage mirrors minicpm-audio:
args = parse_args_with_config("conf/voxcpm/finetune.yml")
with argbind.scope(args):
...
"""
cli_args = argbind.parse_args()
if config_path is None:
return cli_args
yaml_args = load_yaml_config(config_path)
with argbind.scope(cli_args):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args)
return cli_args
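The only structural requirement on the YAML file is that it parses to a top-level mapping; anything else makes `load_yaml_config` raise. A toy check with placeholder keys (the real schema is defined by the argbind-bound functions):

```python
from pathlib import Path

Path("finetune_demo.yml").write_text(
    "train_manifest: data/train.jsonl\n"
    "batch_size: 8\n"
    "learning_rate: 1.0e-4\n",
    encoding="utf-8",
)

cfg = load_yaml_config("finetune_demo.yml")
assert cfg["batch_size"] == 8
```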

src/voxcpm/training/data.py (new file, 214 lines)

@@ -0,0 +1,214 @@
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import argbind
import torch
from datasets import Audio, Dataset, DatasetDict, load_dataset
from torch.utils.data import Dataset as TorchDataset
from ..model.voxcpm import VoxCPMConfig
from ..modules.audiovae import AudioVAE
from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_ID_COLUMN = "dataset_id"
@argbind.bind()
def load_audio_text_datasets(
train_manifest: str,
val_manifest: str = "",
text_column: str = DEFAULT_TEXT_COLUMN,
audio_column: str = DEFAULT_AUDIO_COLUMN,
dataset_id_column: str = DEFAULT_ID_COLUMN,
sample_rate: int = 16_000,
num_proc: int = 1,
) -> Tuple[Dataset, Optional[Dataset]]:
data_files = {"train": train_manifest}
if val_manifest:
data_files["validation"] = val_manifest
dataset_dict: DatasetDict = load_dataset("json", data_files=data_files)
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
if audio_column != DEFAULT_AUDIO_COLUMN:
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
if text_column != DEFAULT_TEXT_COLUMN:
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
if dataset_id_column and dataset_id_column in ds.column_names:
if dataset_id_column != DEFAULT_ID_COLUMN:
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
else:
ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
return ds
train_ds = prepare(dataset_dict["train"])
val_ds = prepare(dataset_dict["validation"]) if "validation" in dataset_dict else None
return train_ds, val_ds
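The manifests are plain JSON-lines files read by `load_dataset("json", ...)`; each row needs an audio path and a transcript, plus an optional integer dataset id. A tiny illustrative manifest writer (file names and field values are placeholders):

```python
import json

rows = [
    {"audio": "wavs/utt0001.wav", "text": "Hello world.", "dataset_id": 0},
    {"audio": "wavs/utt0002.wav", "text": "A second sentence.", "dataset_id": 0},
]
with open("train.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

train_ds, val_ds = load_audio_text_datasets(train_manifest="train.jsonl")  # val_ds is None here
```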
def compute_sample_lengths(
ds: Dataset,
audio_vae_fps: int = 25,
patch_size: int = 1,
) -> List[int]:
"""
Estimate each sample's approximate packed sequence length (text + audio) after the packer, used to filter over-long samples.
The logic mirrors AudioFeatureProcessingPacker / AudioVAE:
- text length: len(text_ids)
- audio length:
  duration(s) * audio_vae_fps -> approximate number of VAE frames t_vae
  t_seq = ceil(t_vae / patch_size)
- total sequence length is roughly: text_len + t_seq + 2
"""
lengths: List[int] = []
has_duration = "duration" in ds.column_names
for i in range(len(ds)):
item = ds[i]
text_len = len(item["text_ids"])
# Audio duration (avoid decoding where possible; prefer the manifest's duration column if present)
if has_duration:
duration = float(item["duration"])
else:
audio = item[DEFAULT_AUDIO_COLUMN]
duration = len(audio["array"]) / float(audio["sampling_rate"])
t_vae = math.ceil(duration * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2
lengths.append(total_len)
return lengths
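As a quick sanity check of the estimate: at 25 VAE frames per second, a 10-second clip with a 30-token transcript and patch_size=2 packs to roughly 157 positions.

```python
import math

duration_s, audio_vae_fps, patch_size, text_len = 10.0, 25, 2, 30
t_vae = math.ceil(duration_s * audio_vae_fps)   # 250 approximate VAE frames
t_seq = math.ceil(t_vae / patch_size)           # 125 latent patches
estimated_len = text_len + t_seq + 2            # 157, matching compute_sample_lengths
```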
class HFVoxCPMDataset(TorchDataset):
"""
Thin wrapper around a tokenized HuggingFace dataset that returns
PyTorch-friendly samples.
"""
def __init__(self, dataset: Dataset):
self.dataset = dataset
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx: int):
item = self.dataset[idx]
audio = item[DEFAULT_AUDIO_COLUMN]
return {
"text_ids": item["text_ids"],
"audio_array": audio["array"],
"audio_sampling_rate": audio["sampling_rate"],
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
"is_prompt": item.get("is_prompt", False),
}
@staticmethod
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
if not seqs:
return torch.empty(0)
max_len = max(seq.shape[0] for seq in seqs)
padded = []
for seq in seqs:
if seq.shape[0] < max_len:
pad_width = (0, max_len - seq.shape[0])
seq = torch.nn.functional.pad(seq, pad_width, value=pad_value)
padded.append(seq)
return torch.stack(padded)
@classmethod
def collate_fn(cls, batch: List[Dict]):
text_tensors = [torch.tensor(sample["text_ids"], dtype=torch.int32) for sample in batch]
audio_tensors = [torch.tensor(sample["audio_array"], dtype=torch.float32) for sample in batch]
dataset_ids = torch.tensor([sample["dataset_id"] for sample in batch], dtype=torch.int32)
is_prompts = [bool(sample.get("is_prompt", False)) for sample in batch]
text_padded = cls.pad_sequences(text_tensors, pad_value=-100)
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
return {
"text_tokens": text_padded,
"audio_tokens": audio_padded,
"task_ids": task_ids,
"dataset_ids": dataset_ids,
"is_prompts": is_prompts,
}
class BatchProcessor:
"""
Wraps ``AudioFeatureProcessingPacker`` so the training loop can mirror
the minicpm-audio mechanics.
"""
def __init__(
self,
*,
config: VoxCPMConfig,
audio_vae: AudioVAE,
dataset_cnt: int,
device: torch.device,
):
self.device = device
self.dataset_cnt = dataset_cnt
self.audio_vae = audio_vae
self.audio_vae.to(device)
self.packer = AudioFeatureProcessingPacker(
dataset_cnt=dataset_cnt,
max_len=config.max_length,
patch_size=config.patch_size,
feat_dim=config.feat_dim,
audio_vae=self.audio_vae,
)
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
audio_tokens = batch["audio_tokens"].to(self.device)
text_tokens = batch["text_tokens"].to(self.device)
task_ids = batch["task_ids"].to(self.device)
dataset_ids = batch["dataset_ids"].to(self.device)
packed = self.packer(
audio_tokens=audio_tokens,
text_tokens=text_tokens,
task_ids=task_ids,
dataset_ids=dataset_ids,
is_prompts=batch["is_prompts"],
)
return packed
def build_dataloader(
hf_dataset: Dataset,
*,
accelerator,
batch_size: int,
num_workers: int,
drop_last: bool = False,
) -> torch.utils.data.DataLoader:
torch_dataset = HFVoxCPMDataset(hf_dataset)
# Standard padding-based batching; Accelerator will attach DistributedSampler if needed.
return accelerator.prepare_dataloader(
torch_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
collate_fn=HFVoxCPMDataset.collate_fn,
drop_last=drop_last,
)


@@ -0,0 +1,289 @@
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from einops import rearrange
class AudioFeatureProcessingPacker:
"""
Adapted from the minicpm-audio training utilities. It converts raw text and
audio tokens into the packed multimodal representation required by VoxCPM.
"""
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101
self.audio_end_id = 102
# unused now
self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104
self.text_eos_token_id = 2
self.patch_size = patch_size
self.patch_len = audio_vae.hop_length * self.patch_size
self.feat_dim = feat_dim
self.dataset_cnt = max(dataset_cnt, 1)
self.max_len = max_len
self.audio_vae = audio_vae
self.process_functions = {"tts": self.process_tts_data}
self.task_id_map = {"tts": 1}
self.id_to_task = {idx: usage for usage, idx in self.task_id_map.items()}
# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
@staticmethod
def _first_pad_position(tokens: torch.Tensor):
positions = (tokens == -100).nonzero(as_tuple=True)
if positions[0].numel() == 0:
return None
return int(positions[0][0])
def unpad_text_tokens(self, tokens: torch.Tensor):
pad_pos = self._first_pad_position(tokens)
return tokens if pad_pos is None else tokens[:pad_pos]
def unpad_audio_tokens(self, tokens: torch.Tensor):
pad_pos = self._first_pad_position(tokens)
return tokens if pad_pos is None else tokens[:pad_pos]
def encode_audio(self, wav: torch.Tensor):
"""
Encode raw waveform into latent features using AudioVAE.
AudioVAE.encode expects shape [B, 1, T'] and returns [B, D, T].
We then transpose to [B, T, D] to match downstream expectations.
"""
wav = wav.unsqueeze(0) # [1, T]
wav = wav.unsqueeze(1) # [1, 1, T]
wav_len = wav.size(-1)
if wav_len % self.patch_len != 0:
padding_size = self.patch_len - wav_len % self.patch_len
wav = torch.nn.functional.pad(wav, (0, padding_size))
with torch.no_grad():
z = self.audio_vae.encode(wav, self.audio_vae.sample_rate) # [1, D, T']
feat = z.transpose(1, 2) # [1, T', D]
return feat
# ------------------------------------------------------------------ #
# Main entry point
# ------------------------------------------------------------------ #
def __call__(
self,
audio_tokens: torch.Tensor,
text_tokens: torch.Tensor,
task_ids: torch.Tensor,
dataset_ids: torch.Tensor,
is_prompts: List[bool],
) -> Dict[str, torch.Tensor]:
"""
Padding-based batching: each sample in the input batch is processed
independently and then padded to a common length (capped by ``max_len``).
The result tensors all have shape [B, T, ...].
"""
device = audio_tokens.device
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
dataset_cnt = max(self.dataset_cnt, max_dataset_id + 1)
text_tokens_list: List[torch.Tensor] = []
audio_feats_list: List[torch.Tensor] = []
text_mask_list: List[torch.Tensor] = []
audio_mask_list: List[torch.Tensor] = []
loss_mask_list: List[torch.Tensor] = []
labels_list: List[torch.Tensor] = []
audio_task_ids_list: List[torch.Tensor] = []
audio_dataset_ids_list: List[torch.Tensor] = []
lengths: List[int] = []
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
):
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
unpad_text_token = self.unpad_text_tokens(text_token)
usage = self.id_to_task[task_id]
(
packed_text,
audio_feat,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
audio_duration_consumed[dataset_idx] += audio_duration
text_token_consumed[dataset_idx] += text_token_count
audio_task_id = torch.zeros_like(audio_mask)
audio_task_id[audio_mask == 1] = self.task_id_map[usage]
audio_dataset_id = torch.zeros_like(audio_mask)
audio_dataset_id[audio_mask == 1] = dataset_idx + 1
text_tokens_list.append(packed_text)
text_mask_list.append(text_mask)
audio_feats_list.append(audio_feat)
audio_mask_list.append(audio_mask)
loss_mask_list.append(loss_mask)
labels_list.append(labels)
audio_task_ids_list.append(audio_task_id)
audio_dataset_ids_list.append(audio_dataset_id)
lengths.append(packed_text.shape[0])
# Determine padded length per batch (cap by self.max_len)
if lengths:
max_len = min(self.max_len, max(lengths))
else:
max_len = self.max_len
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
def pad_3d(x: torch.Tensor) -> torch.Tensor:
# x: [T, P, D]
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.zeros(
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
)
return torch.cat([x, pad], dim=0)
if lengths:
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
audio_task_ids_batch = torch.stack(
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
)
audio_dataset_ids_batch = torch.stack(
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
)
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
position_ids_list = []
for L in lengths:
L_clip = min(L, max_len)
pos = torch.arange(0, L_clip, device=device)
if L_clip < max_len:
pad = torch.zeros(max_len - L_clip, dtype=pos.dtype, device=device)
pos = torch.cat([pos, pad], dim=0)
position_ids_list.append(pos)
position_ids = torch.stack(position_ids_list, dim=0)
else:
# Empty batch fallback (shouldn't really happen)
text_tokens_batch = torch.zeros((0, self.max_len), dtype=torch.int32, device=device)
text_mask_batch = torch.zeros_like(text_tokens_batch)
audio_feats_batch = torch.zeros(
(0, self.max_len, self.patch_size, self.feat_dim), dtype=torch.float32, device=device
)
audio_mask_batch = torch.zeros_like(text_tokens_batch)
loss_mask_batch = torch.zeros_like(text_tokens_batch)
labels_batch = torch.zeros_like(text_tokens_batch)
audio_task_ids_batch = torch.zeros_like(text_tokens_batch)
audio_dataset_ids_batch = torch.zeros_like(text_tokens_batch)
position_ids = torch.zeros_like(text_tokens_batch)
audio_duration_consumed = audio_duration_consumed.to(torch.long)
text_token_consumed = text_token_consumed.to(torch.long)
return {
"text_tokens": text_tokens_batch,
"audio_feats": audio_feats_batch,
"text_mask": text_mask_batch,
"audio_mask": audio_mask_batch,
"loss_mask": loss_mask_batch,
"position_ids": position_ids,
"labels": labels_batch,
"audio_task_ids": audio_task_ids_batch,
"audio_dataset_ids": audio_dataset_ids_batch,
"audio_duration_consumed": audio_duration_consumed,
"text_token_consumed": text_token_consumed,
}
# ------------------------------------------------------------------ #
# Feature extraction helpers
# ------------------------------------------------------------------ #
def extract_audio_feats(self, audio_data: torch.Tensor):
audio_feats = self.encode_audio(audio_data)
if audio_feats.size(1) % self.patch_size != 0:
audio_feats_ = audio_feats.transpose(1, 2)
padding = nn.functional.pad(audio_feats_, (0, self.patch_size - audio_feats.size(1) % self.patch_size))
audio_feats = padding.transpose(1, 2)
audio_duration = audio_feats.size(1) / 25
audio_feats = rearrange(audio_feats, "b (t p) c -> b t p c", p=self.patch_size)
return audio_feats, audio_duration
def process_tts_data(self, audio_token: torch.Tensor, text_token: torch.Tensor, is_prompt: bool = False):
text_token_info = torch.cat(
[
text_token,
torch.tensor(
[self.audio_prompt_start_id if is_prompt else self.audio_start_id],
dtype=torch.int32,
device=text_token.device,
),
],
dim=-1,
)
text_token_count = len(text_token)
text_length = text_token_info.shape[0]
audio_feat_info, audio_duration = self.extract_audio_feats(audio_token)
audio_feat_info = audio_feat_info.squeeze(0)
audio_length = audio_feat_info.shape[0]
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
text_token_info = torch.cat(
[
text_token_info,
text_pad_token,
torch.tensor(
[self.audio_prompt_end_id if is_prompt else self.audio_end_id],
dtype=torch.int32,
device=text_token.device,
),
]
)
audio_pad_feat = torch.zeros(
(text_length, self.patch_size, audio_feat_info.size(-1)),
dtype=torch.float32,
device=text_token.device,
)
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
text_token.device
)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
torch.int32
).to(text_token.device)
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
labels[-2] = 1
return (
text_token_info,
audio_feat_info,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
)
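Spelled out, the packed layout produced above for one non-prompt sample is (a schematic summary of the code, not additional behaviour):

```python
# text stream : [ t_1 ... t_N, audio_start_id, 0 x audio_length, audio_end_id ]
# audio stream: [ zero patches x (N + 1),  real latent patches,  one zero patch ]
# text_mask   : [ 1 x (N + 1),             0 x audio_length,     1 ]
# audio_mask  : [ 0 x (N + 1),             1 x audio_length,     0 ]
# loss_mask   : [ 0 x (N + 1),             1 x audio_length,     0 ]   (all zeros for prompt samples)
# labels      : all zeros except labels[-2] = 1, the stop flag on the last audio patch
```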


@@ -0,0 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class TrainingState:
"""
Container that mirrors the object returned in the minicpm-audio training
loop. It holds persistent references to the model, optimizer, scheduler,
dataloaders and tracker.
"""
generator: object
optimizer: object
scheduler: object
train_loader: object
val_loader: object
tracker: object
batch_processor: object


@@ -0,0 +1,78 @@
from __future__ import annotations
import contextlib
import time
from pathlib import Path
from typing import Dict, Optional
class TrainingTracker:
"""
Lightweight tracker inspired by the minicpm-audio training workflow.
It keeps track of the current global step, prints rank-aware messages,
optionally writes to TensorBoard via a provided writer, and stores progress
in a logfile for later inspection.
"""
def __init__(
self,
*,
writer=None,
log_file: Optional[str] = None,
rank: int = 0,
):
self.writer = writer
self.log_file = Path(log_file) if log_file else None
if self.log_file:
self.log_file.parent.mkdir(parents=True, exist_ok=True)
self.rank = rank
self.step = 0
# Record the time of the last log to calculate the interval
self._last_log_time: float | None = None
# ------------------------------------------------------------------ #
# Logging helpers
# ------------------------------------------------------------------ #
def print(self, message: str):
if self.rank == 0:
print(message, flush=True)
if self.log_file:
with self.log_file.open("a", encoding="utf-8") as f:
f.write(message + "\n")
def log_metrics(self, metrics: Dict[str, float], split: str):
if self.rank == 0:
now = time.time()
dt_str = ""
if self._last_log_time is not None:
dt = now - self._last_log_time
dt_str = f", log interval: {dt:.2f}s"
self._last_log_time = now
formatted = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")
if self.writer is not None:
for key, value in metrics.items():
if isinstance(value, (int, float)):
self.writer.add_scalar(f"{split}/{key}", value, self.step)
def done(self, split: str, message: str):
self.print(f"[{split}] {message}")
# ------------------------------------------------------------------ #
# State dict
# ------------------------------------------------------------------ #
def state_dict(self):
return {"step": self.step}
def load_state_dict(self, state):
self.step = int(state.get("step", 0))
# ------------------------------------------------------------------ #
# Context manager compatibility (for parity with minicpm-audio code)
# ------------------------------------------------------------------ #
@contextlib.contextmanager
def live(self):
yield
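A hedged usage sketch of the tracker; the log path, step count, and metric values below are illustrative and not taken from the repository.

```python
# Toy loop showing the tracker API; writer=None skips TensorBoard logging.
tracker = TrainingTracker(writer=None, log_file="logs/train.log", rank=0)
with tracker.live():  # no-op context manager kept for API parity
    for step in range(1, 4):
        tracker.step = step
        tracker.log_metrics({"loss": 1.0 / step, "lr": 1e-5}, split="train")
tracker.done("train", "finished toy run")
```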

View File

@@ -3,41 +3,8 @@ import re
import regex
import inflect
from functools import partial
-from tn.chinese.normalizer import Normalizer as ZhNormalizer
-from tn.english.normalizer import Normalizer as EnNormalizer
+from wetext import Normalizer
-def normal_cut_sentence(text):
-# First replace commas inside parentheses with a placeholder marker
-text = re.sub(r'([(][^)]*)([,])([^)]*[)])', r'\1&&&\3', text)
-text = re.sub('([。!,?\?])([^’”])', r'\1\n\2', text)  # ordinary sentence-ending punctuation not followed by a closing quote
-text = re.sub('(\.{6})([^’”])', r'\1\n\2', text)  # English ellipsis not followed by a closing quote
-text = re.sub('(\…{2})([^’”])', r'\1\n\2', text)  # Chinese ellipsis not followed by a closing quote
-text = re.sub('([. ,。!;?\?\.{6}\…{2}][’”])([^’”])', r'\1\n\2', text)  # sentence-ending punctuation plus closing quote, not followed by another quote
-# Handle splitting of English sentences
-text = re.sub(r'([.,!?])([^’”\'"])', r'\1\n\2', text)  # period, exclamation or question mark not followed by a quote
-text = re.sub(r'([.!?][’”\'"])([^’”\'"])', r'\1\n\2', text)  # period, exclamation or question mark plus quote, then the following text
-text = re.sub(r'([(][^)]*)(&&&)([^)]*[)])', r'\1\3', text)
-text = [t for t in text.split("\n") if t]
-return text
-def cut_sentence_with_fix_length(text : str, length : int):
-sentences = normal_cut_sentence(text)
-cur_length = 0
-res = ""
-for sentence in sentences:
-if not sentence:
-continue
-if cur_length > length or cur_length + len(sentence) > length:
-yield res
-res = ""
-cur_length = 0
-res += sentence
-cur_length += len(sentence)
-if res:
-yield res
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
# whether the text contains Chinese characters
@@ -195,8 +162,8 @@ def clean_text(text):
class TextNormalizer:
def __init__(self, tokenizer=None):
self.tokenizer = tokenizer
-self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, remove_interjections=False, overwrite_cache=True)
-self.en_tn_model = EnNormalizer()
+self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
+self.en_tn_model = Normalizer(lang="en", operator="tn")
self.inflect_parser = inflect.engine()
def normalize(self, text, split=False):
@@ -207,38 +174,12 @@ class TextNormalizer:
text = text.replace("=", "等于")  # fix: "550 + 320 等于 870 千卡。" was being wrongly normalized to "五百五十加三百二十等于八七十千卡."
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text):  # avoid English hyphens being wrongly normalized to 减 ("minus")
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text)  # fix: "x-2" being normalized to "x负2"
-text = self.zh_tn_model.normalize(text)
-text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text)  # fix: "x-2" being normalized to "x负2"
text = self.zh_tn_model.normalize(text)
text = replace_blank(text)
text = replace_corner_mark(text)
text = remove_bracket(text)
-text = re.sub(r'[,]+$', '', text)
else:
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
if split is False:
return text
-if __name__ == "__main__":
-text_normalizer = TextNormalizer()
-text = r"""今天我们学习一元二次方程。一元二次方程的标准形式是:
-ax2+bx+c=0ax^2 + bx + c = 0ax2+bx+c=0
-其中aaa、bbb 和 ccc 是常数xxx 是变量。这个方程的解可以通过求根公式来找到。
-一元二次方程的解法有几种:
-- 因式分解法通过将方程因式分解来求解。我们首先尝试将方程表达成两个括号的形式解决方程的解。比如方程x25x+6=0x^2 - 5x + 6 = 0x25x+6=0可以因式分解为(x2)(x3)=0(x - 2)(x - 3) = 0(x2)(x3)=0因此根为2和3。
-- 配方法:通过配方将方程转化为完全平方的形式,从而解出。我们通过加上或减去适当的常数来完成这一过程,使得方程可以直接写成一个完全平方的形式。
-- 求根公式:我们可以使用求根公式直接求出方程的解。这个公式适用于所有的一元二次方程,即使我们无法通过因式分解或配方法来解决时,也能使用该公式。
-公式x=b±b24ac2ax = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}x=2ab±b24ac这个公式可以帮助我们求解任何一元二次方程的根。
-对于一元二次方程,我们需要了解判别式。判别式的作用是帮助我们判断方程的解的个数和性质。判别式 Δ\DeltaΔ 由下式给出:Δ=b24ac\Delta = b^2 - 4acΔ=b24ac 根据判别式的值,我们可以知道:
-- 如果 Δ>0\Delta > 0Δ>0方程有两个不相等的实数解。这是因为判别式大于0时根号内的值是正数所以我们可以得到两个不同的解。
-- 如果 Δ=0\Delta = 0Δ=0方程有一个实数解。这是因为根号内的值为零导致两个解相等也就是说方程有一个解。
-- 如果 Δ<0\Delta < 0Δ<0方程没有实数解。这意味着根号内的值是负数无法进行实数运算因此方程没有实数解可能有复数解。"""
-texts = ["这是一个公式 (a+b)³=a³+3a²b+3ab²+b³ S=(a×b)÷2", "这样的发展为AI仅仅作为“工具”这一观点提出了新的挑战", "550 + 320 = 870千卡。", "解一元二次方程3x^2+x-2=0", "你好啊"]
-texts = [text]
-for text in texts:
-text = text_normalizer.normalize(text)
-print(text)
-for t in cut_sentence_with_fix_length(text, 15):
-print(t)
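The change above swaps the `tn`-based Chinese/English normalizers for `wetext`. A minimal standalone sketch of the new construction, using the constructor arguments from the diff (the sample sentences are only illustrative):

```python
from wetext import Normalizer

zh_tn = Normalizer(lang="zh", operator="tn", remove_erhua=True)
en_tn = Normalizer(lang="en", operator="tn")

print(zh_tn.normalize("解一元二次方程3x^2+x-2=0"))  # digits and symbols are spelled out in Chinese
print(en_tn.normalize("It costs 45 dollars."))      # numeric tokens are verbalized in English
```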