21 Commits
1.0.4 ... main

Author SHA1 Message Date
刘鑫
aabda60833 add lora finetune data setting QA 2025-12-10 20:25:24 +08:00
刘鑫
a266c0a88d add lora finetune webUI; optimize lora save and load logic 2025-12-09 21:34:39 +08:00
Labmem-Zhouyx
0779a93697 Merge branch 'main' of https://github.com/OpenBMB/VoxCPM 2025-12-07 02:02:08 +08:00
Labmem-Zhouyx
a1f9d0c3b6 Update: release note 2025-12-07 01:59:53 +08:00
xliucs
aefba63f71 Merge pull request #98 from Ayin1412/main
Fix incorrect argument passing in the lora/ft test code
2025-12-06 17:38:19 +08:00
Ayin1412
58717d7d82 Fix incorrect argument passing in the lora/ft test code 2025-12-06 14:49:35 +08:00
Labmem-Zhouyx
1b0ff5693c Update: model parameters 2025-12-06 01:22:30 +08:00
Labmem-Zhouyx
762815a5b7 Update: user guides 2025-12-05 23:57:43 +08:00
Labmem-Zhouyx
5b13a35ea6 Update: gradio description 2025-12-05 23:47:35 +08:00
Labmem-Zhouyx
3ba727a615 Update: gradio description 2025-12-05 23:38:04 +08:00
Labmem-Zhouyx
a7a447b02a Merge branch 'dev_1.5'
# Conflicts:
#	README.md
#	docs/finetune.md
#	scripts/test_voxcpm_ft_infer.py
#	scripts/test_voxcpm_lora_infer.py
#	src/voxcpm/core.py
2025-12-05 22:38:03 +08:00
刘鑫
400f47a516 Modify lora inference api 2025-12-05 22:22:13 +08:00
Labmem-Zhouyx
b1f7593ae0 Update: default no denoise & normalize 2025-12-05 22:16:27 +08:00
Labmem-Zhouyx
6a5e713698 fix: streaming mode 2025-12-05 22:06:15 +08:00
Labmem-Zhouyx
3443dbb212 Update: VoxCPM1.5 and fine-tuning support 2025-12-05 21:04:51 +08:00
Labmem-Zhouyx
461ad7e506 Update: VoxCPM1.5 and fine-tuning support 2025-12-05 21:00:01 +08:00
Labmem-Zhouyx
d1bb6aaf41 update technical report 2025-09-30 10:47:39 +08:00
刘鑫
2eb4d39719 FX: Add MPS support 2025-09-28 21:06:35 +08:00
刘鑫
fbf8984d4e Merge branch 'main' into dev 2025-09-27 16:20:47 +08:00
刘鑫
961569e76d merge from main 2025-09-19 22:08:56 +08:00
刘鑫
f26a1ea2f7 Remove segment text logic 2025-09-18 12:01:26 +08:00
32 changed files with 4658 additions and 265 deletions

3
.gitignore vendored
View File

@@ -1,3 +1,4 @@
launch.json
__pycache__
voxcpm.egg-info
.DS_Store

187
README.md
View File

@@ -1,7 +1,11 @@
## 🎙️ VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning
[![Project Page](https://img.shields.io/badge/Project%20Page-GitHub-blue)](https://github.com/OpenBMB/VoxCPM/) [![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-OpenBMB-yellow)](https://huggingface.co/openbmb/VoxCPM-0.5B) [![ModelScope](https://img.shields.io/badge/ModelScope-OpenBMB-purple)](https://modelscope.cn/models/OpenBMB/VoxCPM-0.5B) [![Live Playground](https://img.shields.io/badge/Live%20PlayGround-Demo-orange)](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [![Samples](https://img.shields.io/badge/Page-Samples-red)](https://openbmb.github.io/VoxCPM-demopage)
[![Project Page](https://img.shields.io/badge/Project%20Page-GitHub-blue)](https://github.com/OpenBMB/VoxCPM/) [![Technical Report](https://img.shields.io/badge/Technical%20Report-Arxiv-red)](https://arxiv.org/abs/2509.24650)[![Live Playground](https://img.shields.io/badge/Live%20PlayGround-Demo-orange)](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [![Samples](https://img.shields.io/badge/Audio%20Samples-Page-green)](https://openbmb.github.io/VoxCPM-demopage)
#### VoxCPM1.5 Model Weights
[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-OpenBMB-yellow)](https://huggingface.co/openbmb/VoxCPM1.5) [![ModelScope](https://img.shields.io/badge/ModelScope-OpenBMB-purple)](https://modelscope.cn/models/OpenBMB/VoxCPM1.5)
@@ -16,6 +20,8 @@
</div>
## News
* [2025.12.05] 🎉 🎉 🎉 We Open Source the VoxCPM1.5 [weights](https://huggingface.co/openbmb/VoxCPM1.5)! The model now supports both full-parameter fine-tuning and efficient LoRA fine-tuning, empowering you to create your own tailored version. See [Release Notes](docs/release_note.md) for details.
* [2025.09.30] 🔥 🔥 🔥 We Release VoxCPM [Technical Report](https://arxiv.org/abs/2509.24650)!
* [2025.09.16] 🔥 🔥 🔥 We Open Source the VoxCPM-0.5B [weights](https://huggingface.co/openbmb/VoxCPM-0.5B)!
* [2025.09.16] 🎉 🎉 🎉 We Provide the [Gradio PlayGround](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) for VoxCPM-0.5B, try it now!
@@ -32,10 +38,22 @@ Unlike mainstream approaches that convert speech to discrete tokens, VoxCPM uses
### 🚀 Key Features
- **Context-Aware, Expressive Speech Generation** - VoxCPM comprehends text to infer and generate appropriate prosody, delivering speech with remarkable expressiveness and natural flow. Trained on a massive 1.8 million-hour bilingual corpus, it spontaneously adapts its speaking style to the content, producing highly fitting vocal expression.
- **True-to-Life Voice Cloning** - With only a short reference audio clip, VoxCPM performs accurate zero-shot voice cloning, capturing not only the speakers timbre but also fine-grained characteristics such as accent, emotional tone, rhythm, and pacing to create a faithful and natural replica.
- **True-to-Life Voice Cloning** - With only a short reference audio clip, VoxCPM performs accurate zero-shot voice cloning, capturing not only the speaker's timbre but also fine-grained characteristics such as accent, emotional tone, rhythm, and pacing to create a faithful and natural replica.
- **High-Efficiency Synthesis** - VoxCPM supports streaming synthesis with a Real-Time Factor (RTF) as low as 0.17 on a consumer-grade NVIDIA RTX 4090 GPU, making real-time applications possible.
### 📦 Model Versions
See [Release Notes](docs/release_note.md) for details
- **VoxCPM1.5** (Latest):
- Model Params: 800M
- Sampling rate of AudioVAE: 44100
- Token rate in LM Backbone: 6.25Hz (patch-size=4)
RTF on a single NVIDIA RTX 4090 GPU: ~0.15
- **VoxCPM-0.5B** (Original):
- Model Params: 640M
- Sampling rate of AudioVAE: 16000
- Token rate in LM Backbone: 12.5Hz (patch-size=2)
RTF on a single NVIDIA RTX 4090 GPU: 0.17
@@ -47,12 +65,18 @@ pip install voxcpm
```
### 1. Model Download (Optional)
By default, when you first run the script, the model will be downloaded automatically, but you can also download the model in advance.
- Download VoxCPM-0.5B
- Download VoxCPM1.5
```
from huggingface_hub import snapshot_download
snapshot_download("openbmb/VoxCPM1.5")
```
- Or Download VoxCPM-0.5B
```
from huggingface_hub import snapshot_download
snapshot_download("openbmb/VoxCPM-0.5B")
```
- Download ZipEnhancer and SenseVoice-Small. We use ZipEnhancer to enhance speech prompts and SenseVoice-Small for speech prompt ASR in the web demo.
```
from modelscope import snapshot_download
snapshot_download('iic/speech_zipenhancer_ans_multiloss_16k_base')
@@ -65,7 +89,7 @@ import soundfile as sf
import numpy as np
from voxcpm import VoxCPM
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
model = VoxCPM.from_pretrained("openbmb/VoxCPM1.5")
# Non-streaming
wav = model.generate(
@@ -74,14 +98,14 @@ wav = model.generate(
prompt_text=None, # optional: reference text
cfg_value=2.0, # LM guidance scale on LocDiT; higher gives closer adherence to the prompt, but too high may degrade quality
inference_timesteps=10, # LocDiT inference timesteps; higher for better quality, lower for faster speed
normalize=True, # enable external TN tool
denoise=True, # enable external Denoise tool
normalize=False, # enable external TN tool, but will disable native raw text support
denoise=False, # enable external Denoise tool, but it may cause some distortion and restrict the sampling rate to 16kHz
retry_badcase=True, # enable retrying for bad cases, e.g. runaway generation that fails to stop
retry_badcase_max_times=3, # maximum number of retries
retry_badcase_ratio_threshold=6.0, # maximum generated-length ratio for bad-case detection (simple but effective); can be raised for slow-paced speech
)
sf.write("output.wav", wav, 16000)
sf.write("output.wav", wav, model.tts_model.sample_rate)
print("saved: output.wav")
# Streaming
@@ -93,7 +117,7 @@ for chunk in model.generate_streaming(
chunks.append(chunk)
wav = np.concatenate(chunks)
sf.write("output_streaming.wav", wav, 16000)
sf.write("output_streaming.wav", wav, model.tts_model.sample_rate)
print("saved: output_streaming.wav")
```
@@ -110,14 +134,14 @@ voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, desi
--prompt-audio path/to/voice.wav \
--prompt-text "reference transcript" \
--output out.wav \
--denoise
# --denoise
# (Optional) Voice cloning (reference audio + transcript file)
voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
--prompt-audio path/to/voice.wav \
--prompt-file "/path/to/text-file" \
--output out.wav \
--denoise
# --denoise
# 3) Batch processing (one text per line)
voxcpm --input examples/input.txt --output-dir outs
@@ -125,7 +149,7 @@ voxcpm --input examples/input.txt --output-dir outs
voxcpm --input examples/input.txt --output-dir outs \
--prompt-audio path/to/voice.wav \
--prompt-text "reference transcript" \
--denoise
# --denoise
# 4) Inference parameters (quality/speed)
voxcpm --text "..." --output out.wav \
@@ -136,7 +160,7 @@ voxcpm --text "..." --output out.wav \
voxcpm --text "..." --output out.wav --model-path /path/to/VoxCPM_model_dir
# Or from Hugging Face (auto download/cache)
voxcpm --text "..." --output out.wav \
--hf-model-id openbmb/VoxCPM-0.5B --cache-dir ~/.cache/huggingface --local-files-only
--hf-model-id openbmb/VoxCPM1.5 --cache-dir ~/.cache/huggingface --local-files-only
# 6) Denoiser control
voxcpm --text "..." --output out.wav \
@@ -151,104 +175,53 @@ python -m voxcpm.cli --help
You can start the web UI by running `python app.py`, which allows you to perform Voice Cloning and Voice Creation.
### 5. Fine-tuning
VoxCPM1.5 supports both full fine-tuning (SFT) and LoRA fine-tuning, allowing you to train personalized voice models on your own data. See the [Fine-tuning Guide](docs/finetune.md) for detailed instructions.
## 👩‍🍳 A Voice Chef's Guide
Welcome to the VoxCPM kitchen! Follow this recipe to cook up perfect generated speech. Let's begin.
**Quick Start:**
```bash
# Full fine-tuning
python scripts/train_voxcpm_finetune.py \
--config_path conf/voxcpm_v1.5/voxcpm_finetune_all.yaml
---
### 🥚 Step 1: Prepare Your Base Ingredients (Content)
# LoRA fine-tuning
python scripts/train_voxcpm_finetune.py \
--config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml
```
First, choose how you'd like to input your text:
1. Regular Text (Classic Mode)
- ✅ Keep "Text Normalization" ON. Type naturally (e.g., "Hello, world! 123"). The system will automatically process numbers, abbreviations, and punctuation using WeTextProcessing library.
2. Phoneme Input (Native Mode)
- ❌ Turn "Text Normalization" OFF. Enter phoneme text like {HH AH0 L OW1} (EN) or {ni3}{hao3} (ZH) for precise pronunciation control. In this mode, VoxCPM also supports native understanding of other complex non-normalized text—try it out!
---
### 🍳 Step 2: Choose Your Flavor Profile (Voice Style)
This is the secret sauce that gives your audio its unique sound.
1. Cooking with a Prompt Speech (Following a Famous Recipe)
- A prompt speech provides the desired acoustic characteristics for VoxCPM. The speaker's timbre, speaking style, and even the background sounds and ambiance will be replicated.
- For a Clean, Studio-Quality Voice:
- ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone.
2. Cooking au Naturel (Letting the Model Improvise)
- If no reference is provided, VoxCPM becomes a creative chef! It will infer a fitting speaking style based on the text itself, thanks to the text-smartness of its foundation model, MiniCPM-4.
- Pro Tip: Challenge VoxCPM with any text—poetry, song lyrics, dramatic monologues—it may deliver some interesting results!
---
### 🧂 Step 3: The Final Seasoning (Fine-Tuning Your Results)
You're ready to serve! But for master chefs who want to tweak the flavor, here are two key spices.
- CFG Value (How Closely to Follow the Recipe)
- Default: A great starting point.
- Voice sounds strained or weird? Lower this value. It tells the model to be more relaxed and improvisational, great for expressive prompts.
- Need maximum clarity and adherence to the text? Raise it slightly to keep the model on a tighter leash.
- Inference Timesteps (Simmering Time: Quality vs. Speed)
- Need a quick snack? Use a lower number. Perfect for fast drafts and experiments.
- Cooking a gourmet meal? Use a higher number. This lets the model "simmer" longer, refining the audio for superior detail and naturalness.
---
Happy creating! 🎉 Start with the default settings and tweak from there to suit your project. The kitchen is yours!
## 📚 Documentation
- **[Usage Guide](docs/usage_guide.md)** - Detailed guide on how to use VoxCPM effectively, including text input modes, voice cloning tips, and parameter tuning
- **[Fine-tuning Guide](docs/finetune.md)** - Complete guide for fine-tuning VoxCPM models with SFT and LoRA
- **[Release Notes](docs/release_note.md)** - Version history and updates
- **[Performance Benchmarks](docs/performance.md)** - Detailed performance comparisons on public benchmarks
---
## 📚 More Information
### 🌟 Community Projects
We're excited to see the VoxCPM community growing! Here are some amazing projects and features built by our community:
- **[ComfyUI-VoxCPM](https://github.com/wildminder/ComfyUI-VoxCPM)** A VoxCPM extension for ComfyUI.
- **[ComfyUI-VoxCPMTTS](https://github.com/1038lab/ComfyUI-VoxCPMTTS)** A VoxCPM extension for ComfyUI.
- **[WebUI-VoxCPM](https://github.com/rsxdalv/tts_webui_extension.vox_cpm)** A template extension for TTS WebUI.
- **[PR: Streaming API Support (by AbrahamSanders)](https://github.com/OpenBMB/VoxCPM/pull/26)**
- **[VoxCPM-NanoVLLM](https://github.com/a710128/nanovllm-voxcpm)** NanoVLLM integration for VoxCPM, enabling faster, high-throughput inference on GPU.
- **[VoxCPM-ONNX](https://github.com/bluryar/VoxCPM-ONNX)** ONNX export for VoxCPM, enabling faster CPU inference.
- **[VoxCPMANE](https://github.com/0seba/VoxCPMANE)** VoxCPM TTS with Apple Neural Engine backend server.
- **[PR: LoRA finetune web UI (by Ayin1412)](https://github.com/OpenBMB/VoxCPM/pull/100)**
- **[voxcpm_rs](https://github.com/madushan1000/voxcpm_rs)** A re-implementation of VoxCPM-0.5B in Rust.
## 📊 Performance Highlights
VoxCPM achieves competitive results on public zero-shot TTS benchmarks:
### Seed-TTS-eval Benchmark
| Model | Parameters | Open-Source | test-EN | | test-ZH | | test-Hard | |
|------|------|------|:------------:|:--:|:------------:|:--:|:-------------:|:--:|
| | | | WER/%⬇ | SIM/%⬆| CER/%⬇| SIM/%⬆ | CER/%⬇ | SIM/%⬆ |
| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
| CosyVoice | 0.3B | ✅ | 4.29 | 60.9 | 3.63 | 72.3 | 11.75 | 70.9 |
| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | **6.83** | 72.4 |
| F5-TTS | 0.3B | ✅ | 2.00 | 67.0 | 1.53 | 76.0 | 8.67 | 71.3 |
| SparkTTS | 0.5B | ✅ | 3.14 | 57.3 | 1.54 | 66.0 | - | - |
| FireRedTTS | 0.5B | ✅ | 3.82 | 46.0 | 1.51 | 63.5 | 17.45 | 62.1 |
| FireRedTTS-2 | 1.5B | ✅ | 1.95 | 66.5 | 1.14 | 73.6 | - | - |
| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | **74.7** |
| OpenAudio-s1-mini | 0.5B | ✅ | 1.94 | 55.0 | 1.18 | 68.5 | - | - |
| IndexTTS2 | 1.5B | ✅ | 2.23 | 70.6 | 1.03 | 76.5 | - | - |
| VibeVoice | 1.5B | ✅ | 3.04 | 68.9 | 1.16 | 74.4 | - | - |
| HiggsAudio-v2 | 3B | ✅ | 2.44 | 67.7 | 1.50 | 74.0 | - | - |
| **VoxCPM** | 0.5B | ✅ | **1.85** | **72.9** | **0.93** | **77.2** | 8.87 | 73.0 |
### CV3-eval Benchmark
| Model | zh | en | hard-zh | | | hard-en | | |
|-------|:--:|:--:|:-------:|:--:|:--:|:-------:|:--:|:--:|
| | CER/%⬇ | WER/%⬇ | CER/%⬇ | SIM/%⬆ | DNSMOS⬆ | WER/%⬇ | SIM/%⬆ | DNSMOS⬆ |
| F5-TTS | 5.47 | 8.90 | - | - | - | - | - | - |
| SparkTTS | 5.15 | 11.0 | - | - | - | - | - | - |
| GPT-SoVits | 7.34 | 12.5 | - | - | - | - | - | - |
| CosyVoice2 | 4.08 | 6.32 | 12.58 | 72.6 | 3.81 | 11.96 | 66.7 | 3.95 |
| OpenAudio-s1-mini | 4.00 | 5.54 | 18.1 | 58.2 | 3.77 | 12.4 | 55.7 | 3.89 |
| IndexTTS2 | 3.58 | 4.45 | 12.8 | 74.6 | 3.65 | - | - | - |
| HiggsAudio-v2 | 9.54 | 7.89 | 41.0 | 60.2 | 3.39 | 10.3 | 61.8 | 3.68 |
| CosyVoice3-0.5B | 3.89 | 5.24 | 14.15 | 78.6 | 3.75 | 9.04 | 75.9 | 3.92 |
| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 78.5 | 3.79 | 10.55 | 76.1 | 3.95 |
| **VoxCPM** | **3.40** | **4.04** | 12.9 | 66.1 | 3.59 | **7.89** | 64.3 | 3.74 |
*Note: The projects are not officially maintained by OpenBMB.*
*Have you built something cool with VoxCPM? We'd love to feature it here! Please open an issue or pull request to add your project.*
### 📊 Performance Highlights
VoxCPM achieves competitive results on public zero-shot TTS benchmarks. See [Performance Benchmarks](docs/performance.md) for detailed comparison tables.
@@ -259,12 +232,15 @@ VoxCPM achieves competitive results on public zero-shot TTS benchmarks:
- Bilingual Model: VoxCPM is trained primarily on Chinese and English data. Performance on other languages is not guaranteed and may result in unpredictable or low-quality audio.
- This model is released for research and development purposes only. We do not recommend its use in production or commercial applications without rigorous testing and safety evaluations. Please use VoxCPM responsibly.
---
## 📝TO-DO List
## 📝 TO-DO List
Please stay tuned for updates!
- [ ] Release the VoxCPM technical report.
- [ ] Support higher sampling rate (next version).
- [x] Release the VoxCPM technical report.
- [x] Support higher sampling rate (44.1kHz in VoxCPM-1.5).
- [x] Support SFT and LoRA fine-tuning.
- [ ] Multilingual Support (besides ZH/EN).
- [ ] Controllable Speech Generation by Human Instruction.
@@ -294,16 +270,13 @@ This project is developed by the following institutions:
## 📚 Citation
The technical report is coming soon; please wait for the release 😊
If you find our model helpful, please consider citing our projects 📝 and starring us ⭐️!
```bib
@misc{voxcpm2025,
author = {{Yixuan Zhou, Guoyang Zeng, Xin Liu, Xiang Li, Renjie Yu, Ziyang Wang, Runchuan Ye, Weiyue Sun, Jiancheng Gui, Kehan Li, Zhiyong Wu, Zhiyuan Liu}},
title = {{VoxCPM}},
@article{voxcpm2025,
title = {VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning},
author = {Zhou, Yixuan and Zeng, Guoyang and Liu, Xin and Li, Xiang and Yu, Renjie and Wang, Ziyang and Ye, Runchuan and Sun, Weiyue and Gui, Jiancheng and Li, Kehan and Wu, Zhiyong and Liu, Zhiyuan},
journal = {arXiv preprint arXiv:2509.24650},
year = {2025},
publish = {\url{https://github.com/OpenBMB/VoxCPM}},
note = {GitHub repository}
}
```

30
app.py
View File

@@ -8,7 +8,7 @@ from funasr import AutoModel
from pathlib import Path
os.environ["TOKENIZERS_PARALLELISM"] = "false"
if os.environ.get("HF_REPO_ID", "").strip() == "":
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM-0.5B"
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
import voxcpm
@@ -29,7 +29,7 @@ class VoxCPMDemo:
# TTS model (lazy init)
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
self.default_local_model_dir = "./models/VoxCPM-0.5B"
self.default_local_model_dir = "./models/VoxCPM1.5"
# ---------- Model helpers ----------
def _resolve_model_dir(self) -> str:
@@ -108,7 +108,7 @@ class VoxCPMDemo:
normalize=do_normalize,
denoise=denoise,
)
return (16000, wav)
return (current_model.tts_model.sample_rate, wav)
# ---------- UI Builders ----------
@@ -172,22 +172,22 @@ def create_demo_interface(demo: VoxCPMDemo):
with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
gr.Markdown("""
### Prompt Speech Enhancement(参考语音降噪)
- **Enable** to remove background noise for a clean, studio-like voice, with an external ZipEnhancer component.
**启用**:通过 ZipEnhancer 组件消除背景噪音,获得更好的音质
- **Disable** to preserve the original audio's background atmosphere.
**禁用**:保留原始音频的背景环境声,如果想复刻相应声学环境
- **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
**启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz,限制克隆上限
- **Disable** to preserve the original audio's all information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
**禁用**:保留原始音频的全部信息(包括背景环境声),最高支持44.1kHz的音频复刻
### Text Normalization(文本正则化)
- **Enable** to process general text with an external WeTextProcessing component.
**启用**:使用 WeTextProcessing 组件,可处理常见文本。
- **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input ({HH AH0 L OW1}), try it!
**禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如 {da4}{jia1})和公式符号合成,尝试一下!
**启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理
- **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input (For Chinese, phonemes are converted using pinyin, {ni3}{hao3}; For English, phonemes are converted using CMUDict, {HH AH0 L OW1}), try it!
**禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3};英文转 CMUDict:{HH AH0 L OW1})和公式符号合成,尝试一下!
### CFG Value(CFG 值)
- **Lower CFG** if the voice prompt sounds strained or expressive.
**调低**:如果提示语音听起来不自然或过于夸张。
- **Higher CFG** for better adherence to the prompt speech style or input text.
**调高**:为更好地贴合提示音频的风格或输入文本。
- **Lower CFG** if the voice prompt sounds strained or expressive, or instability occurs with long text input.
**调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题
- **Higher CFG** for better adherence to the prompt speech style or input text, or instability occurs with too short text input.
**调高**:为更好地贴合提示音频的风格或输入文本,或者极短文本输入出现稳定性问题
### Inference Timesteps(推理时间步)
- **Lower** for faster synthesis speed.
@@ -267,7 +267,7 @@ def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error
demo = VoxCPMDemo()
interface = create_demo_interface(demo)
# Recommended to enable queue on Spaces for better throughput
interface.queue(max_size=10).launch(server_name=server_name, server_port=server_port, show_error=show_error)
interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)
if __name__ == "__main__":

conf/voxcpm_v1.5/voxcpm_finetune_all.yaml Normal file
View File

@@ -0,0 +1,21 @@
pretrained_path: /path/to/VoxCPM1.5/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 44100
batch_size: 16
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.00001
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192 # Token budget per batch. Example: with max_batch_tokens=8192 and batch_size=4, each sample can have at most 2048 tokens
save_path: /path/to/checkpoints/finetune_all
tensorboard: /path/to/logs/finetune_all
lambdas:
loss/diff: 1.0
loss/stop: 1.0

conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml Normal file
View File

@@ -0,0 +1,36 @@
pretrained_path: /path/to/VoxCPM1.5/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 44100
batch_size: 16
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.0001
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192 # Token budget per batch. Example: with max_batch_tokens=8192 and batch_size=4, each sample can have at most 2048 tokens
save_path: /path/to/checkpoints/finetune_lora
tensorboard: /path/to/logs/finetune_lora
lambdas:
loss/diff: 1.0
loss/stop: 1.0
# LoRA configuration
lora:
enable_lm: true
enable_dit: true
enable_proj: false
r: 32
alpha: 16
dropout: 0.0
# Distribution options (optional)
# - If distribute=false (default): save pretrained_path as base_model in lora_config.json
# - If distribute=true: save hf_model_id as base_model (hf_model_id is required)
# hf_model_id: "openbmb/VoxCPM1.5"
# distribute: true

View File

@@ -0,0 +1,21 @@
pretrained_path: /path/to/VoxCPM-0.5B/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 16000
batch_size: 16
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.00001
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192 # Token budget per batch. Example: with max_batch_tokens=8192 and batch_size=4, each sample can have at most 2048 tokens
save_path: /path/to/checkpoints/finetune_all
tensorboard: /path/to/logs/finetune_all
lambdas:
loss/diff: 1.0
loss/stop: 1.0

View File

@@ -0,0 +1,36 @@
pretrained_path: /path/to/VoxCPM-0.5B/
train_manifest: /path/to/train.jsonl
val_manifest: null
sample_rate: 16000
batch_size: 16
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.0001
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192 # Token budget per batch. Example: with max_batch_tokens=8192 and batch_size=4, each sample can have at most 2048 tokens
save_path: /path/to/checkpoints/finetune_lora
tensorboard: /path/to/logs/finetune_lora
lambdas:
loss/diff: 1.0
loss/stop: 1.0
# LoRA configuration
lora:
enable_lm: true
enable_dit: true
enable_proj: false
r: 32
alpha: 16
dropout: 0.0
# Distribution options (optional)
# - If distribute=false (default): save pretrained_path as base_model in lora_config.json
# - If distribute=true: save hf_model_id as base_model (hf_model_id is required)
# hf_model_id: "openbmb/VoxCPM-0.5B"
# distribute: true

468
docs/finetune.md Normal file
View File

@@ -0,0 +1,468 @@
# VoxCPM Fine-tuning Guide
This guide covers how to fine-tune VoxCPM models with two approaches: full fine-tuning and LoRA fine-tuning.
### 🎓 SFT (Supervised Fine-Tuning)
Full fine-tuning updates all model parameters. Suitable for:
- 📊 Large, specialized datasets
- 🔄 Cases where significant behavior changes are needed
### ⚡ LoRA Fine-tuning
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that:
- 🎯 Trains only a small number of additional parameters
- 💾 Significantly reduces memory requirements and training time
- 🔀 Supports multiple LoRA adapters with hot-swapping
## Table of Contents
- [Quick Start: WebUI](#quick-start-webui)
- [Data Preparation](#data-preparation)
- [Full Fine-tuning](#full-fine-tuning)
- [LoRA Fine-tuning](#lora-fine-tuning)
- [Inference](#inference)
- [LoRA Hot-swapping](#lora-hot-swapping)
- [FAQ](#faq)
---
## Quick Start: WebUI
For users who prefer a graphical interface, we provide `lora_ft_webui.py`, a comprehensive WebUI for training and inference:
### Launch WebUI
```bash
python lora_ft_webui.py
```
Then open `http://localhost:7860` in your browser.
### Features
- **🚀 Training Tab**: Configure and start LoRA training with an intuitive interface
- Set training parameters (learning rate, batch size, LoRA rank, etc.)
- Monitor training progress in real-time
- Resume training from existing checkpoints
- **🎵 Inference Tab**: Generate audio with trained models
- Automatic base model loading from LoRA checkpoint config
- Voice cloning with automatic ASR (reference text recognition)
- Hot-swap between multiple LoRA models
- Zero-shot TTS without reference audio
## Data Preparation
Training data should be prepared as a JSONL manifest file, with one sample per line:
```jsonl
{"audio": "path/to/audio1.wav", "text": "Transcript of audio 1."}
{"audio": "path/to/audio2.wav", "text": "Transcript of audio 2."}
{"audio": "path/to/audio3.wav", "text": "Optional duration field.", "duration": 3.5}
{"audio": "path/to/audio4.wav", "text": "Optional dataset_id for multi-dataset.", "dataset_id": 1}
```
### Required Fields
| Field | Description |
|-------|-------------|
| `audio` | Path to audio file (absolute or relative) |
| `text` | Corresponding transcript |
### Optional Fields
| Field | Description |
|-------|-------------|
| `duration` | Audio duration in seconds (speeds up sample filtering) |
| `dataset_id` | Dataset ID for multi-dataset training (default: 0) |
### Requirements
- Audio format: WAV
- Sample rate: 16kHz for VoxCPM-0.5B, 44.1kHz for VoxCPM1.5
- Text: Transcript matching the audio content
See `examples/train_data_example.jsonl` for a complete example.
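To catch formatting problems before training, you can sanity-check a manifest with a few lines of Python. A minimal sketch (the `validate_manifest` helper below is illustrative, not part of the VoxCPM codebase):
```python
import json

def validate_manifest(path: str) -> None:
    """Check that every line is valid JSON with the required fields."""
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            sample = json.loads(line)  # raises on malformed JSON
            for field in ("audio", "text"):
                assert field in sample, f"line {lineno}: missing required field '{field}'"
            if "duration" in sample:
                assert float(sample["duration"]) > 0, f"line {lineno}: invalid duration"

validate_manifest("examples/train_data_example.jsonl")
```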
---
## Full Fine-tuning
Full fine-tuning updates all model parameters. Suitable for large datasets or when significant behavior changes are needed.
### Configuration
Create `conf/voxcpm_v1.5/voxcpm_finetune_all.yaml`:
```yaml
pretrained_path: /path/to/VoxCPM1.5/
train_manifest: /path/to/train.jsonl
val_manifest: ""
sample_rate: 44100
batch_size: 16
grad_accum_steps: 1
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.00001 # Use smaller LR for full fine-tuning
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192
save_path: /path/to/checkpoints/finetune_all
tensorboard: /path/to/logs/finetune_all
lambdas:
loss/diff: 1.0
loss/stop: 1.0
```
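The `grad_accum_steps` option follows the usual gradient-accumulation pattern: gradients from several micro-batches are summed before a single optimizer step, raising the effective batch size without extra memory. A minimal PyTorch sketch of that pattern (toy model and data for illustration; this is not the actual loop in `scripts/train_voxcpm_finetune.py`):
```python
import torch
from torch import nn

# Toy stand-ins; the real trainer builds these from the YAML config above.
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
batches = [(torch.randn(16, 8), torch.randn(16, 1)) for _ in range(8)]

grad_accum_steps = 4  # effective batch size = batch_size * grad_accum_steps
optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    loss = nn.functional.mse_loss(model(x), y) / grad_accum_steps  # scale so accumulation averages correctly
    loss.backward()  # gradients accumulate across micro-batches
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```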
### Training
```bash
# Single GPU
python scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_all.yaml
# Multi-GPU
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_all.yaml
```
### Checkpoint Structure
Full fine-tuning saves a complete model directory that can be loaded directly:
```
checkpoints/finetune_all/
└── step_0002000/
├── model.safetensors # Model weights (excluding audio_vae)
├── config.json # Model config
├── audiovae.pth # Audio VAE weights
├── tokenizer.json # Tokenizer
├── tokenizer_config.json
├── special_tokens_map.json
├── optimizer.pth
└── scheduler.pth
```
---
## LoRA Fine-tuning
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that trains only a small number of additional parameters, significantly reducing memory requirements.
### Configuration
Create `conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml`:
```yaml
pretrained_path: /path/to/VoxCPM1.5/
train_manifest: /path/to/train.jsonl
val_manifest: ""
sample_rate: 44100
batch_size: 16
grad_accum_steps: 1
num_workers: 2
num_iters: 2000
log_interval: 10
valid_interval: 1000
save_interval: 1000
learning_rate: 0.0001 # LoRA can use larger LR
weight_decay: 0.01
warmup_steps: 100
max_steps: 2000
max_batch_tokens: 8192
save_path: /path/to/checkpoints/finetune_lora
tensorboard: /path/to/logs/finetune_lora
lambdas:
loss/diff: 1.0
loss/stop: 1.0
# LoRA configuration
lora:
enable_lm: true # Apply LoRA to Language Model
enable_dit: true # Apply LoRA to Diffusion Transformer
enable_proj: false # Apply LoRA to projection layers (optional)
r: 32 # LoRA rank (higher = more capacity)
alpha: 16 # LoRA alpha, scaling = alpha / r
dropout: 0.0
# Target modules
target_modules_lm: ["q_proj", "v_proj", "k_proj", "o_proj"]
target_modules_dit: ["q_proj", "v_proj", "k_proj", "o_proj"]
# Distribution options (optional)
# hf_model_id: "openbmb/VoxCPM1.5" # HuggingFace ID
# distribute: true # If true, save hf_model_id in lora_config.json
```
### LoRA Parameters
| Parameter | Description | Recommended |
|-----------|-------------|-------------|
| `enable_lm` | Apply LoRA to LM (language model) | `true` |
| `enable_dit` | Apply LoRA to DiT (diffusion model) | `true` (required for voice cloning) |
| `r` | LoRA rank (higher = more capacity) | 16-64 |
| `alpha` | Scaling factor, `scaling = alpha / r` (see the sketch below) | Usually `r/2` or `r` |
| `target_modules_*` | Layer names to add LoRA | attention layers |
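To build intuition for how `r` and `alpha` interact, here is a minimal sketch of a LoRA-augmented linear layer (illustrative only; this is not VoxCPM's internal implementation):
```python
import torch
from torch import nn

class LoRALinear(nn.Module):
    """A frozen base linear layer plus a trainable low-rank update."""
    def __init__(self, base: nn.Linear, r: int = 32, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # base weights stay frozen
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no effect before training
        self.scaling = alpha / r  # scaling = alpha / r, as in the table above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

layer = LoRALinear(nn.Linear(64, 64), r=32, alpha=16)
print(layer(torch.randn(2, 64)).shape)  # torch.Size([2, 64])
```
Because `lora_B` is zero-initialized, a freshly wrapped layer behaves exactly like the base model; training only moves the small `lora_A`/`lora_B` matrices.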
### Distribution Options (Optional)
| Parameter | Description | Default |
|-----------|-------------|---------|
| `hf_model_id` | HuggingFace model ID (e.g., `openbmb/VoxCPM1.5`) | `""` |
| `distribute` | If `true`, save `hf_model_id` as `base_model` in checkpoint; otherwise save local `pretrained_path` | `false` |
> **Note**: If `distribute: true`, `hf_model_id` is required.
### Training
```bash
# Single GPU
python scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml
# Multi-GPU
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml
```
### Checkpoint Structure
LoRA training saves LoRA parameters and configuration:
```
checkpoints/finetune_lora/
└── step_0002000/
├── lora_weights.safetensors # Only lora_A, lora_B parameters
├── lora_config.json # LoRA config + base model path
├── optimizer.pth
└── scheduler.pth
```
The `lora_config.json` contains:
```json
{
"base_model": "/path/to/VoxCPM1.5/",
"lora_config": {
"enable_lm": true,
"enable_dit": true,
"r": 32,
"alpha": 16,
...
}
}
```
The `base_model` field contains:
- Local path (default): when `distribute: false` or not set
- HuggingFace ID: when `distribute: true` (e.g., `"openbmb/VoxCPM1.5"`)
This allows loading LoRA checkpoints without the original training config file.
---
## Inference
### Full Fine-tuning Inference
The checkpoint directory is a complete model; load it directly:
```bash
python scripts/test_voxcpm_ft_infer.py \
--ckpt_dir /path/to/checkpoints/finetune_all/step_0002000 \
--text "Hello, this is the fine-tuned model." \
--output output.wav
```
With voice cloning:
```bash
python scripts/test_voxcpm_ft_infer.py \
--ckpt_dir /path/to/checkpoints/finetune_all/step_0002000 \
--text "This is voice cloning result." \
--prompt_audio /path/to/reference.wav \
--prompt_text "Reference audio transcript" \
--output cloned_output.wav
```
### LoRA Inference
LoRA inference only requires the checkpoint directory (base model path and LoRA config are read from `lora_config.json`):
```bash
python scripts/test_voxcpm_lora_infer.py \
--lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
--text "Hello, this is LoRA fine-tuned result." \
--output lora_output.wav
```
With voice cloning:
```bash
python scripts/test_voxcpm_lora_infer.py \
--lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
--text "This is voice cloning with LoRA." \
--prompt_audio /path/to/reference.wav \
--prompt_text "Reference audio transcript" \
--output cloned_output.wav
```
Override base model path (optional):
```bash
python scripts/test_voxcpm_lora_infer.py \
--lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
--base_model /path/to/another/VoxCPM1.5 \
--text "Use different base model." \
--output output.wav
```
---
## LoRA Hot-swapping
LoRA supports dynamic loading, unloading, and switching at inference time without reloading the entire model.
### API Reference
```python
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
# 1. Load model with LoRA structure and weights
lora_cfg = LoRAConfig(
enable_lm=True,
enable_dit=True,
r=32,
alpha=16,
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
)
model = VoxCPM.from_pretrained(
hf_model_id="openbmb/VoxCPM1.5", # or local path
load_denoiser=False, # Optional: disable denoiser for faster loading
optimize=True, # Enable torch.compile acceleration
lora_config=lora_cfg,
lora_weights_path="/path/to/lora_checkpoint",
)
# 2. Generate audio
audio = model.generate(
text="Hello, this is LoRA fine-tuned result.",
prompt_wav_path="/path/to/reference.wav", # Optional: for voice cloning
prompt_text="Reference audio transcript", # Optional: for voice cloning
)
# 3. Disable LoRA (use base model only)
model.set_lora_enabled(False)
# 4. Re-enable LoRA
model.set_lora_enabled(True)
# 5. Unload LoRA (reset weights to zero)
model.unload_lora()
# 6. Hot-swap to another LoRA
loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
# 7. Get current LoRA weights
lora_state = model.get_lora_state_dict()
```
### Simplified Usage (Load from lora_config.json)
If your checkpoint contains `lora_config.json` (saved by the training script), you can load everything automatically:
```python
import json
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
# Load config from checkpoint
lora_ckpt_dir = "/path/to/checkpoints/finetune_lora/step_0002000"
with open(f"{lora_ckpt_dir}/lora_config.json") as f:
lora_info = json.load(f)
base_model = lora_info["base_model"]
lora_cfg = LoRAConfig(**lora_info["lora_config"])
# Load model with LoRA
model = VoxCPM.from_pretrained(
hf_model_id=base_model,
lora_config=lora_cfg,
lora_weights_path=lora_ckpt_dir,
)
```
Or use the test script directly:
```bash
python scripts/test_voxcpm_lora_infer.py \
--lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
--text "Hello world"
```
### Method Reference
| Method | Description | torch.compile Compatible |
|--------|-------------|--------------------------|
| `load_lora(path)` | Load LoRA weights from file | ✅ |
| `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
| `get_lora_state_dict()` | Get current LoRA weights | ✅ |
| `lora_enabled` | Property: check if LoRA is configured | ✅ |
---
## FAQ
### 1. How Much Data is Needed for LoRA Fine-tuning to Converge to a Single Voice?
We have tested with 5 minutes and 10 minutes of data (all audio clips are 3-6s in length). In our experiments, both datasets converged to a single voice after 2000 training steps with default configurations. You can adjust the data amount and training configurations based on your available data and computational resources.
### 2. Out of Memory (OOM)
- Increase `grad_accum_steps` (gradient accumulation)
- Decrease `batch_size`
- Use LoRA fine-tuning instead of full fine-tuning
- Decrease `max_batch_tokens` to filter out long samples (see the arithmetic sketch below)
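These knobs interact roughly as follows (simple arithmetic under the batching scheme described above, not tied to any specific hardware):
```python
batch_size = 8        # halved from the default 16 to save memory
grad_accum_steps = 2  # doubled to compensate

effective_batch = batch_size * grad_accum_steps  # still 16 samples per optimizer step
max_batch_tokens = 8192
max_tokens_per_sample = max_batch_tokens // batch_size  # 1024 with this setting
print(effective_batch, max_tokens_per_sample)
```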
### 3. Poor LoRA Performance
- Increase `r` (LoRA rank)
- Adjust `alpha` (try `alpha = r/2` or `alpha = r`)
- Increase training steps
- Add more target modules
### 4. Training Not Converging
- Decrease `learning_rate`
- Increase `warmup_steps`
- Check data quality
### 5. LoRA Not Taking Effect at Inference
- Check that `lora_config.json` exists in the checkpoint directory
- Check `load_lora()` return value - `skipped_keys` should be empty
- Verify `set_lora_enabled(True)` is called
### 6. Checkpoint Loading Errors
- Full fine-tuning: checkpoint directory should contain `model.safetensors` (or `pytorch_model.bin`), `config.json`, `audiovae.pth`
- LoRA: checkpoint directory should contain:
- `lora_weights.safetensors` (or `lora_weights.ckpt`) - LoRA weights
- `lora_config.json` - LoRA config and base model path

46
docs/performance.md Normal file
View File

@@ -0,0 +1,46 @@
# 📊 Performance Highlights
VoxCPM achieves competitive results on public zero-shot TTS benchmarks.
## Seed-TTS-eval Benchmark
| Model | Parameters | Open-Source | test-EN | | test-ZH | | test-Hard | |
|------|------|------|:------------:|:--:|:------------:|:--:|:-------------:|:--:|
| | | | WER/%⬇ | SIM/%⬆| CER/%⬇| SIM/%⬆ | CER/%⬇ | SIM/%⬆ |
| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
| F5-TTS | 0.3B | ✅ | 2.00 | 67.0 | 1.53 | 76.0 | 8.67 | 71.3 |
| MaskGCT | 1B | ✅ | 2.62 | 71.7 | 2.27 | 77.4 | - | - |
| CosyVoice | 0.3B | ✅ | 4.29 | 60.9 | 3.63 | 72.3 | 11.75 | 70.9 |
| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | **6.83** | 72.4 |
| SparkTTS | 0.5B | ✅ | 3.14 | 57.3 | 1.54 | 66.0 | - | - |
| FireRedTTS | 0.5B | ✅ | 3.82 | 46.0 | 1.51 | 63.5 | 17.45 | 62.1 |
| FireRedTTS-2 | 1.5B | ✅ | 1.95 | 66.5 | 1.14 | 73.6 | - | - |
| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | **74.7** |
| OpenAudio-s1-mini | 0.5B | ✅ | 1.94 | 55.0 | 1.18 | 68.5 | 23.37 | 64.3 |
| IndexTTS2 | 1.5B | ✅ | 2.23 | 70.6 | 1.03 | 76.5 | 7.12 | 75.5 |
| VibeVoice | 1.5B | ✅ | 3.04 | 68.9 | 1.16 | 74.4 | - | - |
| HiggsAudio-v2 | 3B | ✅ | 2.44 | 67.7 | 1.50 | 74.0 | 55.07 | 65.6 |
| **VoxCPM** | 0.5B | ✅ | **1.85** | **72.9** | **0.93** | **77.2** | 8.87 | 73.0 |
## CV3-eval Benchmark
| Model | zh | en | hard-zh | | | hard-en | | |
|-------|:--:|:--:|:-------:|:--:|:--:|:-------:|:--:|:--:|
| | CER/%⬇ | WER/%⬇ | CER/%⬇ | SIM/%⬆ | DNSMOS⬆ | WER/%⬇ | SIM/%⬆ | DNSMOS⬆ |
| F5-TTS | 5.47 | 8.90 | - | - | - | - | - | - |
| SparkTTS | 5.15 | 11.0 | - | - | - | - | - | - |
| GPT-SoVits | 7.34 | 12.5 | - | - | - | - | - | - |
| CosyVoice2 | 4.08 | 6.32 | 12.58 | 72.6 | 3.81 | 11.96 | 66.7 | 3.95 |
| OpenAudio-s1-mini | 4.00 | 5.54 | 18.1 | 58.2 | 3.77 | 12.4 | 55.7 | 3.89 |
| IndexTTS2 | 3.58 | 4.45 | 12.8 | 74.6 | 3.65 | 8.78 | 74.5 | 3.80 |
| HiggsAudio-v2 | 9.54 | 7.89 | 41.0 | 60.2 | 3.39 | 10.3 | 61.8 | 3.68 |
| CosyVoice3-0.5B | 3.89 | 5.24 | 14.15 | 78.6 | 3.75 | 9.04 | 75.9 | 3.92 |
| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 78.5 | 3.79 | 10.55 | 76.1 | 3.95 |
| **VoxCPM** | **3.40** | **4.04** | 12.9 | 66.1 | 3.59 | **7.89** | 64.3 | 3.74 |

116
docs/release_note.md Normal file
View File

@@ -0,0 +1,116 @@
# VoxCPM1.5 Release Notes
**Release Date:** December 5, 2025
## 🎉 Overview
We're thrilled to introduce a major upgrade that improves the audio quality and efficiency of VoxCPM while maintaining the core capabilities of context-aware speech generation and zero-shot voice cloning.
| Feature | VoxCPM | VoxCPM1.5 |
|---------|------------|------------|
| **Audio VAE Sampling Rate** | 16kHz | 44.1kHz |
| **LM Token Rate** | 12.5Hz | 6.25Hz |
| **Patch Size** | 2 | 4 |
| **SFT Support** | ✅ | ✅ |
| **LoRA Support** | ✅ | ✅ |
## 🎵 Model Updates
### 🔊 AudioVAE Sampling Rate: 16kHz → 44.1kHz
The AudioVAE now supports 44.1kHz sampling rate, which allows the model to:
- 🎯 Clone better, preserving more high-frequency details and generate higher quality voice outputs
*Note: This upgrade enables higher quality generation when using high-quality reference audio, but does not guarantee that all generated audio will be high-fidelity. The output quality depends on the **prompt speech** quality.*
### ⚡ Token Rate: 12.5Hz → 6.25Hz
We reduced the token rate in LM backbone from 12.5Hz to 6.25Hz (LocEnc&LocDiT patch size increased from 2 to 4) while maintaining similar performance on evaluation benchmarks. This change:
- 💨 Reduces computational requirements for generating the same length of audio
- 📈 Provides a foundation for longer audio generation
- 🏗️ Paves the way for training larger models in the future
**Model Architecture Clarification**: The core architecture of VoxCPM1.5 remains unchanged from the technical report. The key modification is adjusting the patch size of the local modules (LocEnc & LocDiT) from 2 to 4, which reduces the LM processing rate from 12.5Hz to 6.25Hz. Since the local modules now need to handle longer contexts, we expanded their network depth, resulting in a slightly larger overall model parameter count.
**Generation Speed Clarification**: Although the model parameters have increased, VoxCPM1.5 only requires 6.25 tokens to generate 1 second of audio (compared to 12.5 tokens in the previous version). While the displayed generation speed (xx it/s) may appear slower, the actual Real-Time Factor (RTF = processing time / audio duration) shows no difference and may even be lower, i.e. faster.
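As a back-of-the-envelope illustration of why a lower it/s reading can still mean a lower RTF (the decoding speeds below are hypothetical, not measurements):
```python
audio_seconds = 10.0

# Tokens the LM must generate for 10 s of audio at each token rate.
tokens_v10 = 12.5 * audio_seconds  # VoxCPM-0.5B: 125 tokens
tokens_v15 = 6.25 * audio_seconds  # VoxCPM1.5:  62.5 tokens

# Hypothetical decoding speeds: v1.5 decodes fewer tokens per second
# (larger local modules), yet finishes the same audio sooner.
rtf_v10 = (tokens_v10 / 70.0) / audio_seconds  # 70 it/s -> RTF ~0.18
rtf_v15 = (tokens_v15 / 45.0) / audio_seconds  # 45 it/s -> RTF ~0.14
print(f"RTF v1.0 ~ {rtf_v10:.2f}, VoxCPM1.5 ~ {rtf_v15:.2f}")  # lower is faster
```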
## 🔧 Fine-tuning Support
We now support both full fine-tuning and LoRA fine-tuning; please see the [Fine-tuning Guide](finetune.md) for detailed instructions.
## 📚 Documentation
- Updated README with version comparison
- Added comprehensive fine-tuning guide
- Improved code comments and documentation
## 🙏 Our Thanks to You
This release wouldn't be possible without the incredible feedback, testing, and contributions from our open-source community. Thank you for helping shape VoxCPM1.5!
## 📞 Let's Build Together
Questions, ideas, or want to contribute?
- 🐛 Report an issue: [GitHub Issues on OpenBMB/VoxCPM](https://github.com/OpenBMB/VoxCPM/issues)
- 📖 Dig into the docs: Check the [docs/](../docs/) folder for guides and API details
Enjoy the richer sound and powerful new features of VoxCPM1.5 🎉
We can't wait to hear what you create next! 🥂
## 🚀 What We're Working On
We're continuously improving VoxCPM and working on exciting new features:
- 🌍 **Multilingual TTS Support**: We are actively developing support for languages beyond Chinese and English.
- 🎯 **Controllable Expressive Speech Generation**: We are researching controllable speech generation that allows fine-grained control over speech attributes (emotion, timbre, prosody, etc.) through natural language instructions.
- 🎵 **Universal Audio Generation Foundation**: We also hope to explore VoxCPM as a unified audio generation foundation model capable of joint generation of speech, music, and sound effects. However, this is a longer-term vision.
**📅 Next Release**: We plan to release the next version in Q1 2026, which will include significant improvements and new features. Stay tuned for updates! We're committed to making VoxCPM even more powerful and versatile.
## ❓ Frequently Asked Questions (FAQ)
### Q: Does VoxCPM support fine-tuning for personalized voice customization?
**A:** Yes! VoxCPM now supports both full fine-tuning (SFT) and efficient LoRA fine-tuning. You can train personalized voice models on your own data. Please refer to the [Fine-tuning Guide](finetune.md) for detailed instructions and examples.
### Q: Is 16kHz audio quality sufficient for my use case?
**A:** We have upgraded the AudioVAE to support 44.1kHz sampling rate in VoxCPM1.5, which provides higher quality audio output with better preservation of high-frequency details. This upgrade enables better voice cloning quality and more natural speech synthesis when using high-quality reference audio.
### Q: Has the stability issue been resolved?
**A:** We have made stability optimizations in VoxCPM1.5, including improvements to the inference code logic, training data, and model architecture. Based on community feedback, we collected some stability issues such as:
- Increased noise and reverberation
- Audio artifacts (e.g., howling/squealing)
- Unstable speaking rate (speeding up)
- Volume fluctuations (increases or decreases)
- Noise artifacts at the beginning and end of audio
- Synthesis issues with very short texts (e.g., "hello")
**What we've improved:**
- By adjusting inference code logic and optimizing training data, we have largely fixed the beginning/ending artifacts.
- By reducing the LM processing rate (12.5Hz → 6.25Hz), we have improved stability on longer speech generation cases.
**What remains:** We acknowledge that long speech stability issues have not been completely resolved. Particularly for highly expressive or complex reference speech, error accumulation during autoregressive generation can still occur. We will continue to analyze and optimize this in future versions.
### Q: Does VoxCPM plan to support multilingual TTS?
**A:** Currently, VoxCPM is primarily trained on Chinese and English data. We are actively researching and developing multilingual TTS support for more languages beyond Chinese and English. Please let us know what languages you'd like to see supported!
### Q: Does VoxCPM plan to support controllable generation (emotion, style, fine-grained control)?
**A:** Currently, VoxCPM only supports zero-shot voice cloning and context-aware speech generation. Direct control over specific speech attributes (emotion, style, fine-grained prosody) is limited. However, we are actively researching instruction-controllable expressive speech generation with fine-grained control capabilities, working towards a human instruction-to-speech generation model!
### Q: Does VoxCPM support different hardware chips (e.g., Ascend 910B, XPU, NPU)?
**A:** Currently, we have not yet adapted VoxCPM for different hardware chips. Our main focus remains on developing new model capabilities and improving stability. We encourage you to check if community developers have done similar work, and we warmly welcome everyone to contribute and promote such adaptations together!
These features are under active development, and we look forward to sharing updates in future releases!

55
docs/usage_guide.md Normal file
View File

@@ -0,0 +1,55 @@
# 👩‍🍳 A Voice Chef's Guide
Welcome to the VoxCPM kitchen! Follow this recipe to cook up perfect generated speech. Let's begin.
---
## 🥚 Step 1: Prepare Your Base Ingredients (Content)
First, choose how you'd like to input your text:
### 1. Regular Text (Classic Mode)
- ✅ Keep "Text Normalization" ON. Type naturally (e.g., "Hello, world! 123"). The system will automatically process numbers, abbreviations, and punctuation using the WeTextProcessing library.
### 2. Phoneme Input (Native Mode)
- ❌ Turn "Text Normalization" OFF. Enter phoneme text like `{HH AH0 L OW1}` (EN) or `{ni3}{hao3}` (ZH) for precise pronunciation control. In this mode, VoxCPM also supports native understanding of other complex non-normalized text—try it out!
- **Phoneme Conversion**: For Chinese, phonemes are converted using pinyin. For English, phonemes are converted using CMUDict. Please refer to the relevant documentation for more details.
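For example, the `{ni3}{hao3}` form for Chinese text can be produced programmatically (a sketch assuming the third-party `pypinyin` package, which is not a VoxCPM dependency):
```python
from pypinyin import Style, lazy_pinyin

def to_voxcpm_phonemes(text: str) -> str:
    """Convert Chinese text to the {pinyin-with-tone} phoneme format."""
    syllables = lazy_pinyin(text, style=Style.TONE3)  # e.g. ['ni3', 'hao3']
    return "".join("{" + s + "}" for s in syllables)

print(to_voxcpm_phonemes("你好"))  # {ni3}{hao3}
```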
---
## 🍳 Step 2: Choose Your Flavor Profile (Voice Style)
This is the secret sauce that gives your audio its unique sound.
### 1. Cooking with a Prompt Speech (Following a Famous Recipe)
- A prompt speech provides the desired acoustic characteristics for VoxCPM. The speaker's timbre, speaking style, and even the background sounds and ambiance will be replicated.
- **For a Clean, Denoised Voice:**
- ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
- **For High-Quality Audio Cloning (Up to 44.1kHz):**
- ❌ Disable "Prompt Speech Enhancement" to preserve all original audio information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
### 2. Cooking au Naturel (Letting the Model Improvise)
- If no reference is provided, VoxCPM becomes a creative chef! It will infer a fitting speaking style based on the text itself, thanks to the text-smartness of its foundation model, MiniCPM-4.
- **Pro Tip**: Challenge VoxCPM with any text—poetry, song lyrics, dramatic monologues—it may deliver some interesting results!
---
## 🧂 Step 3: The Final Seasoning (Fine-Tuning Your Results)
You're ready to serve! But for master chefs who want to tweak the flavor, here are two key spices.
### CFG Value (How Closely to Follow the Recipe)
- **Default**: A great starting point.
- **Voice sounds strained or weird?** Lower this value. It tells the model to be more relaxed and improvisational, great for expressive prompts.
- **Need maximum clarity and adherence to the text?** Raise it slightly to keep the model on a tighter leash.
- **Short sentences?** Consider increasing the CFG value for better clarity and adherence.
- **Long texts?** Consider lowering the CFG value to improve stability and naturalness over extended passages (see the sketch below).
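For curious chefs: under the hood, the CFG knob follows the standard classifier-free guidance recipe, pushing the model's conditional prediction away from its unconditional one. A one-function sketch (illustrative, not VoxCPM's exact code):
```python
import torch

def cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, cfg_value: float) -> torch.Tensor:
    """Standard classifier-free guidance combination."""
    return uncond + cfg_value * (cond - uncond)

# cfg_value = 1.0 reproduces the plain conditional prediction; larger values
# follow the text/prompt more strictly, at the risk of strained-sounding audio.
cond, uncond = torch.randn(4), torch.randn(4)
print(cfg_combine(cond, uncond, cfg_value=2.0))
```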
### Inference Timesteps (Simmering Time: Quality vs. Speed)
- **Need a quick snack?** Use a lower number. Perfect for fast drafts and experiments.
- **Cooking a gourmet meal?** Use a higher number. This lets the model "simmer" longer, refining the audio for superior detail and naturalness.
---
Happy creating! 🎉 Start with the default settings and tweak from there to suit your project. The kitchen is yours!

examples/train_data_example.jsonl Normal file
View File

@@ -0,0 +1,6 @@
{"audio": "examples/example.wav", "text": "This is an example audio transcript for training."}
{"audio": "/absolute/path/to/audio1.wav", "text": "You can use absolute paths for audio files."}
{"audio": "relative/path/to/audio2.wav", "text": "Or relative paths from the working directory."}
{"audio": "data/audio3.wav", "text": "Each line is a JSON object with audio path and text.", "duration": 3.5}
{"audio": "data/audio4.wav", "text": "Optional: add duration field to skip audio loading during filtering.", "duration": 2.8}
{"audio": "data/audio5.wav", "text": "Optional: add dataset_id for multi-dataset training.", "dataset_id": 1}

1253
lora_ft_webui.py Normal file

File diff suppressed because it is too large

pyproject.toml
View File

@@ -29,7 +29,7 @@ dependencies = [
"torchaudio>=2.5.0",
"transformers>=4.36.2",
"einops",
"gradio",
"gradio<6",
"inflect",
"addict",
"wetext",
@@ -42,7 +42,10 @@ dependencies = [
"sortedcontainers",
"soundfile",
"funasr",
"spaces"
"spaces",
"argbind",
"safetensors"
]
[project.optional-dependencies]

scripts/test_voxcpm_ft_infer.py Normal file
View File

@@ -0,0 +1,131 @@
#!/usr/bin/env python3
"""
Full finetune inference script (no LoRA).
Checkpoint directory contains complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.)
and can be loaded directly via VoxCPM.
Usage:
python scripts/test_voxcpm_ft_infer.py \
--ckpt_dir /path/to/checkpoints/step_0001000 \
--text "Hello, I am the finetuned VoxCPM." \
--output ft_test.wav
With voice cloning:
python scripts/test_voxcpm_ft_infer.py \
--ckpt_dir /path/to/checkpoints/step_0001000 \
--text "Hello, this is voice cloning result." \
--prompt_audio path/to/ref.wav \
--prompt_text "Reference audio transcript" \
--output ft_clone.wav
"""
import argparse
from pathlib import Path
import soundfile as sf
from voxcpm.core import VoxCPM
def parse_args():
parser = argparse.ArgumentParser("VoxCPM full-finetune inference test (no LoRA)")
parser.add_argument(
"--ckpt_dir",
type=str,
required=True,
help="Checkpoint directory (contains pytorch_model.bin, config.json, audiovae.pth, etc.)",
)
parser.add_argument(
"--text",
type=str,
required=True,
help="Target text to synthesize",
)
parser.add_argument(
"--prompt_audio",
type=str,
default="",
help="Optional: reference audio path for voice cloning",
)
parser.add_argument(
"--prompt_text",
type=str,
default="",
help="Optional: transcript of reference audio",
)
parser.add_argument(
"--output",
type=str,
default="ft_test.wav",
help="Output wav file path",
)
parser.add_argument(
"--cfg_value",
type=float,
default=2.0,
help="CFG scale (default: 2.0)",
)
parser.add_argument(
"--inference_timesteps",
type=int,
default=10,
help="Diffusion inference steps (default: 10)",
)
parser.add_argument(
"--max_len",
type=int,
default=600,
help="Max generation steps",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Enable text normalization",
)
return parser.parse_args()
def main():
args = parse_args()
# Load model from checkpoint directory (no denoiser)
print(f"[FT Inference] Loading model: {args.ckpt_dir}")
model = VoxCPM.from_pretrained(
hf_model_id=args.ckpt_dir,
load_denoiser=False,
optimize=True,
)
# Run inference
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
prompt_text = args.prompt_text if args.prompt_text else None
print(f"[FT Inference] Synthesizing: text='{args.text}'")
if prompt_wav_path:
print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
print(f"[FT Inference] Reference text: {prompt_text}")
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
normalize=args.normalize,
denoise=False,
)
# Save audio
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
if __name__ == "__main__":
main()

scripts/test_voxcpm_lora_infer.py Normal file
View File

@@ -0,0 +1,244 @@
#!/usr/bin/env python3
"""
LoRA inference test script.
Usage:
python scripts/test_voxcpm_lora_infer.py \
--lora_ckpt checkpoints/step_0002000 \
--text "Hello, this is LoRA finetuned result." \
--output lora_test.wav
With voice cloning:
python scripts/test_voxcpm_lora_infer.py \
--lora_ckpt checkpoints/step_0002000 \
--text "This is voice cloning result." \
--prompt_audio path/to/ref.wav \
--prompt_text "Reference audio transcript" \
--output lora_clone.wav
Note: The script reads base_model path and lora_config from lora_config.json
in the checkpoint directory (saved automatically during training).
"""
import argparse
import json
from pathlib import Path
import soundfile as sf
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
def parse_args():
parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
parser.add_argument(
"--lora_ckpt",
type=str,
required=True,
help="LoRA checkpoint directory (contains lora_weights.safetensors and lora_config.json)",
)
parser.add_argument(
"--base_model",
type=str,
default="",
help="Optional: override base model path (default: read from lora_config.json)",
)
parser.add_argument(
"--text",
type=str,
required=True,
help="Target text to synthesize",
)
parser.add_argument(
"--prompt_audio",
type=str,
default="",
help="Optional: reference audio path for voice cloning",
)
parser.add_argument(
"--prompt_text",
type=str,
default="",
help="Optional: transcript of reference audio",
)
parser.add_argument(
"--output",
type=str,
default="lora_test.wav",
help="Output wav file path",
)
parser.add_argument(
"--cfg_value",
type=float,
default=2.0,
help="CFG scale (default: 2.0)",
)
parser.add_argument(
"--inference_timesteps",
type=int,
default=10,
help="Diffusion inference steps (default: 10)",
)
parser.add_argument(
"--max_len",
type=int,
default=600,
help="Max generation steps",
)
parser.add_argument(
"--normalize",
action="store_true",
help="Enable text normalization",
)
return parser.parse_args()
def main():
args = parse_args()
# 1. Check LoRA checkpoint directory
ckpt_dir = Path(args.lora_ckpt)
if not ckpt_dir.exists():
raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
# 2. Load lora_config.json from checkpoint
lora_config_path = ckpt_dir / "lora_config.json"
if not lora_config_path.exists():
raise FileNotFoundError(
f"lora_config.json not found in {ckpt_dir}. "
"Make sure the checkpoint was saved with the updated training script."
)
with open(lora_config_path, "r", encoding="utf-8") as f:
lora_info = json.load(f)
# Get base model path (command line arg overrides config)
pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
if not pretrained_path:
raise ValueError("base_model not found in lora_config.json and --base_model not provided")
# Get LoRA config
lora_cfg_dict = lora_info.get("lora_config", {})
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
print(f"Loaded config from: {lora_config_path}")
print(f" Base model: {pretrained_path}")
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None")
# 3. Load model with LoRA (no denoiser)
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
print(f" LoRA weights: {ckpt_dir}")
model = VoxCPM.from_pretrained(
hf_model_id=pretrained_path,
load_denoiser=False,
optimize=True,
lora_config=lora_cfg,
lora_weights_path=str(ckpt_dir),
)
# 4. Synthesize audio
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
prompt_text = args.prompt_text if args.prompt_text else None
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
print(f"\n[2/2] Starting synthesis tests...")
# === Test 1: With LoRA ===
print(f"\n [Test 1] Synthesize with LoRA...")
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
normalize=args.normalize,
denoise=False,
)
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 2: Disable LoRA (via set_lora_enabled) ===
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
model.set_lora_enabled(False)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
normalize=args.normalize,
denoise=False,
)
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 3: Re-enable LoRA ===
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
model.set_lora_enabled(True)
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
normalize=args.normalize,
denoise=False,
)
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 4: Unload LoRA (reset_lora_weights) ===
print(f"\n [Test 4] Unload LoRA (unload_lora)...")
model.unload_lora()
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
normalize=args.normalize,
denoise=False,
)
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
# === Test 5: Hot-reload LoRA (load_lora) ===
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
loaded, skipped = model.load_lora(ckpt_dir)
print(f" Reloaded {len(loaded)} parameters")
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
normalize=args.normalize,
denoise=False,
)
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
print(f"\n[Done] All tests completed!")
print(f" - with_lora: {lora_output}")
print(f" - lora_disabled: {disabled_output}")
print(f" - lora_reenabled: {reenabled_output}")
print(f" - lora_reset: {reset_output}")
print(f" - lora_reloaded: {reload_output}")
if __name__ == "__main__":
main()
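For reference, the lora_config.json consumed above is the one emitted by save_checkpoint in the training script below. Its shape is roughly the following sketch (field values are illustrative; base_model is a local path, or an HF repo id when training ran with distribute=True, and the real file also serializes the full LoRAConfig including the target_modules lists):

import json

# Illustrative lora_config.json contents, mirroring what save_checkpoint writes.
lora_info = {
    "base_model": "openbmb/VoxCPM1.5",
    "lora_config": {
        "enable_lm": True,
        "enable_dit": True,
        "enable_proj": False,
        "r": 32,
        "alpha": 16,
        "dropout": 0.0,
    },
}
print(json.dumps(lora_info, indent=2, ensure_ascii=False))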

View File

@@ -0,0 +1,510 @@
#!/usr/bin/env python3
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "src"))
import contextlib
from typing import Dict, Optional
import argbind
import torch
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
import signal
import os
try:
from safetensors.torch import save_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
print("Warning: safetensors not available, will use pytorch format")
from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import (
Accelerator,
BatchProcessor,
TrainingTracker,
build_dataloader,
load_audio_text_datasets,
)
@argbind.bind(without_prefix=True)
def train(
pretrained_path: str,
train_manifest: str,
val_manifest: str = "",
sample_rate: int = 16_000,
batch_size: int = 1,
grad_accum_steps: int = 1,
num_workers: int = 2,
num_iters: int = 100_000,
log_interval: int = 100,
valid_interval: int = 1_000,
save_interval: int = 10_000,
learning_rate: float = 1e-4,
weight_decay: float = 1e-2,
warmup_steps: int = 1_000,
max_steps: int = 100_000,
max_batch_tokens: int = 0,
save_path: str = "checkpoints",
tensorboard: str = "",
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
lora: dict = None,
config_path: str = "",
# Distribution options (for LoRA checkpoints)
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
):
_ = config_path
# Validate distribution options
if lora is not None and distribute and not hf_model_id:
raise ValueError("hf_model_id is required when distribute=True")
accelerator = Accelerator(amp=True)
save_dir = Path(save_path)
tb_dir = Path(tensorboard) if tensorboard else save_dir / "logs"
# Only create directories on rank 0 to avoid race conditions
if accelerator.rank == 0:
save_dir.mkdir(parents=True, exist_ok=True)
tb_dir.mkdir(parents=True, exist_ok=True)
accelerator.barrier() # Wait for directory creation
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
tokenizer = base_model.text_tokenizer
train_ds, val_ds = load_audio_text_datasets(
train_manifest=train_manifest,
val_manifest=val_manifest,
sample_rate=sample_rate,
)
def tokenize(batch):
text_list = batch["text"]
text_ids = [tokenizer(text) for text in text_list]
return {"text_ids": text_ids}
train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
if val_ds is not None:
val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
num_train_samples = len(train_ds)
# ------------------------------------------------------------------ #
# Optional: filter samples by estimated token count to avoid OOM
# Enabled when max_batch_tokens > 0:
# max_sample_len = max_batch_tokens // batch_size
# Samples exceeding this length will be dropped
# ------------------------------------------------------------------ #
if max_batch_tokens and max_batch_tokens > 0:
from voxcpm.training.data import compute_sample_lengths
audio_vae_fps = base_model.audio_vae.sample_rate / base_model.audio_vae.hop_length
est_lengths = compute_sample_lengths(
train_ds,
audio_vae_fps=audio_vae_fps,
patch_size=base_model.config.patch_size,
)
max_sample_len = max_batch_tokens // batch_size if batch_size > 0 else max(est_lengths)
keep_indices = [i for i, L in enumerate(est_lengths) if L <= max_sample_len]
if len(keep_indices) < len(train_ds) and accelerator.rank == 0:
tracker.print(
f"Filtering {len(train_ds) - len(keep_indices)} / {len(train_ds)} "
f"training samples longer than {max_sample_len} tokens "
f"(max_batch_tokens={max_batch_tokens})."
)
train_ds = train_ds.select(keep_indices)
train_loader = build_dataloader(
train_ds,
accelerator=accelerator,
batch_size=batch_size,
num_workers=num_workers,
drop_last=True,
)
val_loader = (
build_dataloader(
val_ds,
accelerator=accelerator,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
if val_ds is not None
else None
)
batch_processor = BatchProcessor(
config=base_model.config,
audio_vae=base_model.audio_vae,
dataset_cnt=dataset_cnt,
device=accelerator.device,
)
del base_model.audio_vae
model = accelerator.prepare_model(base_model)
unwrapped_model = accelerator.unwrap(model)
unwrapped_model.train()
# Only print param info on rank 0 to avoid cluttered output
if accelerator.rank == 0:
for name, param in model.named_parameters():
print(name, param.requires_grad)
optimizer = AdamW(
(p for p in model.parameters() if p.requires_grad),
lr=learning_rate,
weight_decay=weight_decay,
)
# Cosine + warmup scheduler from transformers:
# - num_warmup_steps: warmup steps
# - num_training_steps: total training steps (outer step count)
total_training_steps = max_steps if max_steps > 0 else num_iters
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_training_steps,
)
# Try to load checkpoint and resume training
start_step = 0
if accelerator.rank == 0:
start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
# Broadcast start_step to all processes
if hasattr(accelerator, 'all_reduce'):
start_step_tensor = torch.tensor(start_step, device=accelerator.device)
accelerator.all_reduce(start_step_tensor)
start_step = int(start_step_tensor.item())
if start_step > 0 and accelerator.rank == 0:
tracker.print(f"Resuming training from step {start_step}")
# Resume tracker for signal handler to read current step
resume = {"step": start_step}
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume):
try:
cur_step = int(_resume.get("step", start_step))
except Exception:
cur_step = start_step
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
try:
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
print("Checkpoint saved. Exiting.")
except Exception as e:
print(f"Error saving checkpoint on signal: {e}")
os._exit(0)
signal.signal(signal.SIGTERM, _signal_handler)
signal.signal(signal.SIGINT, _signal_handler)
# Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
grad_accum_steps = max(int(grad_accum_steps), 1)
data_epoch = 0
train_iter = iter(train_loader)
def get_next_batch():
"""Get next batch, handles epoch boundary and DistributedSampler."""
nonlocal train_iter, data_epoch
try:
return next(train_iter)
except StopIteration:
data_epoch += 1
# Key: set DistributedSampler epoch to ensure different data order each epoch
sampler = getattr(train_loader, 'sampler', None)
if hasattr(sampler, 'set_epoch'):
sampler.set_epoch(data_epoch)
train_iter = iter(train_loader)
return next(train_iter)
with tracker.live():
for step in range(start_step, num_iters):
# update resume step so signal handler can save current progress
resume["step"] = step
tracker.step = step
optimizer.zero_grad(set_to_none=True)
# Gradient accumulation: accumulate gradients over micro-batches before optimizer step
loss_dict = {}
for micro_step in range(grad_accum_steps):
batch = get_next_batch()
processed = batch_processor(batch)
# Only sync gradients on the last micro-batch
# Use no_sync() for intermediate steps to reduce communication overhead
is_last_micro_step = (micro_step == grad_accum_steps - 1)
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
with sync_context:
with accelerator.autocast(dtype=torch.bfloat16):
outputs = model(
processed["text_tokens"],
processed["text_mask"],
processed["audio_feats"],
processed["audio_mask"],
processed["loss_mask"],
processed["position_ids"],
processed["labels"],
progress=step / max(1, num_iters),
)
total_loss = 0.0
for key, value in outputs.items():
if key.startswith("loss/"):
weight = lambdas.get(key, 1.0)
loss_value = value * weight / grad_accum_steps
total_loss = total_loss + loss_value
# Record raw loss from last micro-batch for logging
loss_dict[key] = value.detach()
# Accumulate gradients (normalized by grad_accum_steps)
accelerator.backward(total_loss)
# After all micro-batches, do unscale / grad_norm / step
scaler = getattr(accelerator, "scaler", None)
if scaler is not None:
scaler.unscale_(optimizer)
# Use large max_norm to only compute grad_norm without actual clipping
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)
accelerator.step(optimizer)
accelerator.update()
scheduler.step()
if step % log_interval == 0:
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
# Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
loss_values["epoch"] = float(epoch)
loss_values["grad_norm"] = float(grad_norm)
tracker.log_metrics(loss_values, split="train")
if val_loader is not None and step % valid_interval == 0 and step != 0:
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)
if step % save_interval == 0 and accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
if accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path, hf_model_id, distribute)
if writer:
writer.close()
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
model.eval()
losses = []
num_batches = 0
max_val_batches = 10
with torch.no_grad():
for batch in val_loader:
if num_batches >= max_val_batches:
break
processed = batch_processor(batch)
with accelerator.autocast(dtype=torch.bfloat16):
outputs = model(
processed["text_tokens"],
processed["text_mask"],
processed["audio_feats"],
processed["audio_mask"],
processed["loss_mask"],
processed["position_ids"],
processed["labels"],
progress=0.0,
sample_generate=False,
)
total = 0.0
for key, value in outputs.items():
if key.startswith("loss/"):
total += lambdas.get(key, 1.0) * value
losses.append(total.detach())
num_batches += 1
if losses:
mean_loss = torch.stack(losses).mean()
# All-reduce validation loss across processes for global average
accelerator.all_reduce(mean_loss)
tracker.log_metrics({"loss": mean_loss.item()}, split="val")
model.train()
def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
"""
Load the latest checkpoint if it exists.
Returns the step number to resume from, or 0 if no checkpoint found.
"""
latest_folder = save_dir / "latest"
if not latest_folder.exists():
return 0
unwrapped = model.module if hasattr(model, "module") else model
lora_cfg = unwrapped.lora_config
# Load model weights
if lora_cfg is not None:
# LoRA: load lora_weights
lora_weights_path = latest_folder / "lora_weights.safetensors"
if not lora_weights_path.exists():
lora_weights_path = latest_folder / "lora_weights.ckpt"
if lora_weights_path.exists():
if lora_weights_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(lora_weights_path))
else:
ckpt = torch.load(lora_weights_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
# Load only lora weights
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded LoRA weights from {lora_weights_path}")
else:
# Full finetune: load model.safetensors or pytorch_model.bin
model_path = latest_folder / "model.safetensors"
if not model_path.exists():
model_path = latest_folder / "pytorch_model.bin"
if model_path.exists():
if model_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(model_path))
else:
ckpt = torch.load(model_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded model weights from {model_path}")
# Load optimizer state
optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
print(f"Loaded optimizer state from {optimizer_path}")
# Load scheduler state
scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
print(f"Loaded scheduler state from {scheduler_path}")
# Try to infer step from checkpoint folders
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
if step_folders:
steps = [int(d.name.split("_")[1]) for d in step_folders]
resume_step = max(steps)
print(f"Resuming from step {resume_step}")
return resume_step
return 0
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
"""
Save checkpoint with different strategies for full finetune vs LoRA:
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
"""
import shutil
save_dir.mkdir(parents=True, exist_ok=True)
tag = "latest" if step == 0 else f"step_{step:07d}"
folder = save_dir / tag
folder.mkdir(parents=True, exist_ok=True)
unwrapped = model.module if hasattr(model, "module") else model
full_state = unwrapped.state_dict()
lora_cfg = unwrapped.lora_config
if lora_cfg is not None:
# LoRA finetune: save only lora_A/lora_B weights
state_dict = {k: v for k, v in full_state.items() if "lora_" in k}
if SAFETENSORS_AVAILABLE:
save_file(state_dict, folder / "lora_weights.safetensors")
else:
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
# Save LoRA config and base model path to a separate JSON file
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
import json
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
lora_info = {
"base_model": base_model_to_save,
"lora_config": lora_cfg.model_dump() if hasattr(lora_cfg, "model_dump") else vars(lora_cfg),
}
with open(folder / "lora_config.json", "w", encoding="utf-8") as f:
json.dump(lora_info, f, indent=2, ensure_ascii=False)
else:
# Full finetune: save non-vae weights to model.safetensors
state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")}
if SAFETENSORS_AVAILABLE:
save_file(state_dict, folder / "model.safetensors")
else:
torch.save({"state_dict": state_dict}, folder / "pytorch_model.bin")
# Copy config files from pretrained path
if pretrained_path:
pretrained_dir = Path(pretrained_path)
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
for fname in files_to_copy:
src = pretrained_dir / fname
if src.exists():
shutil.copy2(src, folder / fname)
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
# Update (or create) a `latest` symlink pointing to the most recent checkpoint folder
latest_link = save_dir / "latest"
try:
if latest_link.exists() or latest_link.is_symlink():
# remove existing link or directory
if latest_link.is_dir() and not latest_link.is_symlink():
shutil.rmtree(latest_link)
else:
latest_link.unlink()
# Create a symlink pointing to the new folder
os.symlink(str(folder), str(latest_link))
except Exception:
# If symlink creation fails (e.g., on Windows or permission issues), fall back to copying
try:
if latest_link.exists():
if latest_link.is_dir():
shutil.rmtree(latest_link)
else:
latest_link.unlink()
shutil.copytree(folder, latest_link)
except Exception:
print(f"Warning: failed to update latest checkpoint link at {latest_link}")
if __name__ == "__main__":
from voxcpm.training.config import load_yaml_config
args = argbind.parse_args()
config_file = args.get("config_path")
# If YAML config provided, use YAML args to call train
if config_file:
yaml_args = load_yaml_config(config_file)
train(**yaml_args)
else:
# Otherwise use command line args (parsed by argbind)
with argbind.scope(args):
train()
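Because train() is bound through argbind, the same parameters can be supplied on the command line, through a YAML file passed via --config_path, or by calling it directly. A minimal LoRA run might look like the sketch below; every path is a placeholder, and the manifest format is whatever load_audio_text_datasets expects:

# Hypothetical direct invocation of train(); all paths are placeholders.
train(
    pretrained_path="/path/to/VoxCPM1.5",
    train_manifest="/path/to/train_manifest",
    batch_size=1,
    grad_accum_steps=8,
    learning_rate=1e-4,
    warmup_steps=1_000,
    save_path="checkpoints",
    lora={"enable_lm": True, "enable_dit": True, "r": 32, "alpha": 16},
)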

View File

@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
"ZIPENHANCER_MODEL_PATH", None
)
# Build LoRA config if lora_path is provided
lora_config = None
lora_weights_path = getattr(args, "lora_path", None)
if lora_weights_path:
from voxcpm.model.voxcpm import LoRAConfig
lora_config = LoRAConfig(
enable_lm=getattr(args, "lora_enable_lm", True),
enable_dit=getattr(args, "lora_enable_dit", True),
enable_proj=getattr(args, "lora_enable_proj", False),
r=getattr(args, "lora_r", 32),
alpha=getattr(args, "lora_alpha", 16),
dropout=getattr(args, "lora_dropout", 0.0),
)
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
# Load from local path if provided
if getattr(args, "model_path", None):
try:
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
voxcpm_model_path=args.model_path,
zipenhancer_model_path=zipenhancer_path,
enable_denoiser=not getattr(args, "no_denoiser", False),
lora_config=lora_config,
lora_weights_path=lora_weights_path,
)
print("Model loaded (local).")
return model
@@ -69,11 +87,13 @@ def load_model(args) -> VoxCPM:
# Otherwise, try from_pretrained (Hub); exit on failure
try:
model = VoxCPM.from_pretrained(
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM-0.5B"),
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM1.5"),
load_denoiser=not getattr(args, "no_denoiser", False),
zipenhancer_model_id=zipenhancer_path,
cache_dir=getattr(args, "cache_dir", None),
local_files_only=getattr(args, "local_files_only", False),
lora_config=lora_config,
lora_weights_path=lora_weights_path,
)
print("Model loaded (from_pretrained).")
return model
@@ -120,11 +140,11 @@ def cmd_clone(args):
)
# Save audio
sf.write(str(output_path), audio_array, 16000)
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}")
# Stats
duration = len(audio_array) / 16000
duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s")
@@ -152,11 +172,11 @@ def cmd_synthesize(args):
)
# Save audio
sf.write(str(output_path), audio_array, 16000)
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}")
# Stats
duration = len(audio_array) / 16000
duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s")
@@ -198,9 +218,9 @@ def cmd_batch(args):
denoise=args.denoise and prompt_audio_path is not None
)
output_file = output_dir / f"output_{i:03d}.wav"
sf.write(str(output_file), audio_array, 16000)
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
duration = len(audio_array) / 16000
duration = len(audio_array) / model.tts_model.sample_rate
print(f" Saved: {output_file} ({duration:.2f}s)")
success_count += 1
@@ -250,12 +270,21 @@ Examples:
# Model loading parameters
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM-0.5B", help="Hugging Face repo id (e.g., openbmb/VoxCPM-0.5B)")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (e.g., openbmb/VoxCPM1.5 or openbmb/VoxCPM-0.5B)")
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
# LoRA parameters
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
return parser

View File

@@ -2,9 +2,9 @@ import os
import re
import tempfile
import numpy as np
from typing import Generator
from typing import Generator, Optional
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel
from .model.voxcpm import VoxCPMModel, LoRAConfig
class VoxCPM:
def __init__(self,
@@ -12,6 +12,8 @@ class VoxCPM:
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser : bool = True,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
"""Initialize VoxCPM TTS pipeline.
@@ -23,28 +25,53 @@ class VoxCPM:
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
"""
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
# If lora_weights_path is provided but no lora_config, create a default one
if lora_weights_path is not None and lora_config is None:
lora_config = LoRAConfig(
enable_lm=True,
enable_dit=True,
enable_proj=False,
)
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
# Load LoRA weights if path is provided
if lora_weights_path is not None:
print(f"Loading LoRA weights from: {lora_weights_path}")
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
self.text_normalizer = None
if enable_denoiser and zipenhancer_model_path is not None:
from .zipenhancer import ZipEnhancer
self.denoiser = ZipEnhancer(zipenhancer_model_path)
else:
self.denoiser = None
print("Warm up VoxCPMModel...")
self.tts_model.generate(
target_text="Hello, this is the first test sentence.",
max_len=10,
)
if optimize:
print("Warm up VoxCPMModel...")
self.tts_model.generate(
target_text="Hello, this is the first test sentence.",
max_len=10,
)
@classmethod
def from_pretrained(cls,
hf_model_id: str = "openbmb/VoxCPM-0.5B",
hf_model_id: str = "openbmb/VoxCPM1.5",
load_denoiser: bool = True,
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -52,11 +79,18 @@ class VoxCPM:
Args:
hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path.
load_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
zipenhancer_model_id: Denoiser model id or path for ModelScope
acoustic noise suppression.
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created with
enable_lm=True and enable_dit=True.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
containing lora_weights.ckpt). If provided, LoRA weights will be loaded
after model initialization.
Kwargs:
Additional keyword arguments passed to the ``VoxCPM`` constructor.
@@ -87,6 +121,9 @@ class VoxCPM:
voxcpm_model_path=local_path,
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
optimize=optimize,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
**kwargs,
)
@@ -102,9 +139,10 @@ class VoxCPM:
prompt_text : str = None,
cfg_value : float = 2.0,
inference_timesteps : int = 10,
max_length : int = 4096,
normalize : bool = True,
denoise : bool = True,
min_len : int = 2,
max_len : int = 4096,
normalize : bool = False,
denoise : bool = False,
retry_badcase : bool = True,
retry_badcase_max_times : int = 3,
retry_badcase_ratio_threshold : float = 6.0,
@@ -124,7 +162,7 @@ class VoxCPM:
prompt_text: Text content corresponding to the prompt audio.
cfg_value: Guidance scale for the generation model.
inference_timesteps: Number of inference steps.
max_length: Maximum token length during generation.
max_len: Maximum token length during generation.
normalize: Whether to run text normalization before generation.
denoise: Whether to denoise the prompt audio if a denoiser is
available.
@@ -174,8 +212,8 @@ class VoxCPM:
generate_result = self.tts_model._generate_with_prompt_cache(
target_text=text,
prompt_cache=fixed_prompt_cache,
min_len=2,
max_len=max_length,
min_len=min_len,
max_len=max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
@@ -192,4 +230,52 @@ class VoxCPM:
try:
os.unlink(temp_prompt_wav_path)
except OSError:
pass
# ------------------------------------------------------------------ #
# LoRA Interface (delegated to VoxCPMModel)
# ------------------------------------------------------------------ #
def load_lora(self, lora_weights_path: str) -> tuple:
"""Load LoRA weights from a checkpoint file.
Args:
lora_weights_path: Path to LoRA weights (.pth file or directory
containing lora_weights.ckpt).
Returns:
tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
Raises:
RuntimeError: If model was not initialized with LoRA config.
"""
if self.tts_model.lora_config is None:
raise RuntimeError(
"Cannot load LoRA weights: model was not initialized with LoRA config. "
"Please reinitialize with lora_config or lora_weights_path parameter."
)
return self.tts_model.load_lora_weights(lora_weights_path)
def unload_lora(self):
"""Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
self.tts_model.reset_lora_weights()
def set_lora_enabled(self, enabled: bool):
"""Enable or disable LoRA layers without unloading weights.
Args:
enabled: If True, LoRA layers are active; if False, only base model is used.
"""
self.tts_model.set_lora_enabled(enabled)
def get_lora_state_dict(self) -> dict:
"""Get current LoRA parameters state dict.
Returns:
dict: State dict containing all LoRA parameters (lora_A, lora_B).
"""
return self.tts_model.get_lora_state_dict()
@property
def lora_enabled(self) -> bool:
"""Check if LoRA is currently configured."""
return self.tts_model.lora_config is not None
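Taken together, the new constructor arguments and the delegated methods above enable a hot-swap workflow along the following lines (repo id and checkpoint path are placeholders):

from voxcpm.core import VoxCPM

# A default LoRAConfig is auto-created when only lora_weights_path is given.
model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    load_denoiser=False,
    lora_weights_path="checkpoints/step_0002000",
)
audio = model.generate(text="Testing with LoRA.")  # adapter active

model.set_lora_enabled(False)                # bypass LoRA, weights stay loaded
model.set_lora_enabled(True)                 # adapter active again
model.unload_lora()                          # reset lora_A/lora_B to initial state
model.load_lora("checkpoints/step_0002000")  # hot-reload from disk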

View File

@@ -19,19 +19,27 @@ limitations under the License.
"""
import os
from typing import Tuple, Union, Generator, List
from typing import Tuple, Union, Generator, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import warnings
from einops import rearrange
from pydantic import BaseModel
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
from tqdm import tqdm
from transformers import LlamaTokenizerFast
from ..modules.audiovae import AudioVAE
from ..modules.audiovae import AudioVAE, AudioVAEConfig
from ..modules.layers import ScalarQuantizationLayer
from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
@@ -66,10 +74,31 @@ class VoxCPMConfig(BaseModel):
encoder_config: VoxCPMEncoderConfig
dit_config: VoxCPMDitConfig
audio_vae_config: Optional[AudioVAEConfig] = None
max_length: int = 4096
device: str = "cuda"
dtype: str = "bfloat16"
dit_mean_mode: bool = False
class LoRAConfig(BaseModel):
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
r: int = 8
alpha: int = 16
dropout: float = 0.0
# Target linear layer names for LM & DiT (matched by attribute name)
target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
# Projection layer attribute names to find on VoxCPMModel
target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
VoxCPMConfig.model_rebuild()
class VoxCPMModel(nn.Module):
@@ -78,18 +107,24 @@ class VoxCPMModel(nn.Module):
config: VoxCPMConfig,
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE,
lora_config: LoRAConfig = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.feat_dim = config.feat_dim
self.patch_size = config.patch_size
self.device = config.device
if not torch.cuda.is_available():
self.device = "cpu"
if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
# Text-Semantic LM
self.base_lm = MiniCPMModel(config.lm_config)
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
self.audio_start_token = 101
@@ -100,7 +135,7 @@ class VoxCPMModel(nn.Module):
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
residual_lm_config.vocab_size = 0
self.residual_lm = MiniCPMModel(residual_lm_config)
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
# Local Encoder
encoder_config = config.lm_config.model_copy(deep=True)
@@ -124,6 +159,7 @@ class VoxCPMModel(nn.Module):
in_channels=config.feat_dim,
cfm_params=config.dit_config.cfm_config,
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
mean_mode=config.dit_mean_mode,
)
# Projection layers
@@ -132,7 +168,7 @@ class VoxCPMModel(nn.Module):
config.lm_config.hidden_size,
config.scalar_quantization_latent_dim,
config.scalar_quantization_scale
)
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
@@ -141,17 +177,46 @@ class VoxCPMModel(nn.Module):
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
self.stop_actn = nn.SiLU()
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
self.stop_loss = nn.CrossEntropyLoss(reduction="none")
# Audio VAE
self.audio_vae = audio_vae
self.chunk_size = audio_vae.chunk_size
self.sample_rate = audio_vae.sample_rate
if self.lora_config is not None:
self._apply_lora()
def _apply_lora(self):
"""注入 LoRA 到 LM / DiT / 投影层"""
cfg = self.lora_config
lora_kwargs = dict(r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)
# LM: base_lm + residual_lm
if cfg.enable_lm:
for lm in [self.base_lm, self.residual_lm]:
apply_lora_to_named_linear_modules(
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
)
# DiT: feat_decoder.estimator
if cfg.enable_dit:
apply_lora_to_named_linear_modules(
self.feat_decoder.estimator, target_submodule_names=cfg.target_modules_dit, **lora_kwargs
)
# Projection layers
if cfg.enable_proj:
from ..modules.layers.lora import LoRALinear
for attr_name in cfg.target_proj_modules:
module = getattr(self, attr_name, None)
if isinstance(module, nn.Linear):
setattr(self, attr_name, LoRALinear(base=module, **lora_kwargs))
def optimize(self, disable: bool = False):
if disable:
return self
try:
if disable:
raise ValueError("Optimization disabled by user")
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
@@ -160,17 +225,111 @@ class VoxCPMModel(nn.Module):
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
except Exception as e:
print(f"Error: {e}")
print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
self.base_lm.forward_step = self.base_lm.forward_step
self.residual_lm.forward_step = self.residual_lm.forward_step
self.feat_encoder_step = self.feat_encoder
self.feat_decoder.estimator = self.feat_decoder.estimator
print(f"Warning: torch.compile disabled - {e}")
return self
def forward(
self,
text_tokens: torch.Tensor,
text_mask: torch.Tensor,
audio_feats: torch.Tensor,
audio_mask: torch.Tensor,
loss_mask: torch.Tensor,
position_ids: torch.Tensor,
labels: torch.Tensor,
*,
progress: float = 0.0,
sample_generate: bool = False,
):
del position_ids # not used yet
text_tokens = text_tokens.to(self.device, dtype=torch.long)
text_mask = text_mask.to(self.device, dtype=self._dtype())
audio_feats = audio_feats.to(self.device, dtype=self._dtype())
audio_mask = audio_mask.to(self.device, dtype=self._dtype())
loss_mask = loss_mask.to(self.device, dtype=self._dtype())
labels = labels.to(self.device, dtype=torch.long)
B, T, P, D = audio_feats.shape
feat_embed = self.feat_encoder(audio_feats)
feat_embed = self.enc_to_lm_proj(feat_embed)
scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
if not getattr(self.config.lm_config, "use_mup", False):
scale_emb = 1.0
text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
enc_outputs = enc_outputs.to(self._dtype())
enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)
residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
residual_outputs = residual_outputs.to(self._dtype())
residual_hidden = torch.cat(
(torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
dim=1,
)
dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")
# Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
target_dtype = self._dtype()
feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
feat_cond = torch.cat(
(torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
dim=1,
)
feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")
loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)
diff_loss = self.feat_decoder.compute_loss(
feat_gt.transpose(1, 2).contiguous(),
dit_hidden,
cond=feat_cond.transpose(1, 2).contiguous(),
tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
progress=progress,
)
stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
denom = torch.clamp(loss_mask.sum(), min=1.0)
stop_loss = (stop_losses * loss_mask).sum() / denom
feat_pred = None
if sample_generate:
feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
feat_pred_seq = self.feat_decoder(
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10,
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)
return {
"loss/diff": diff_loss,
"loss/stop": stop_loss,
"feat_gt": feat_gt_tensor,
"feat_pred": feat_pred,
}
def _dtype(self):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
@@ -234,25 +393,25 @@ class VoxCPMModel(nn.Module):
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
padding_size = patch_len - audio.size(1) % patch_len
audio = torch.nn.functional.pad(audio, (padding_size, 0))
# (B, D, T)
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
audio_feat = audio_feat.view(
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
audio_length = audio_feat.size(0)
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
text_token = torch.cat([text_token, text_pad_token])
@@ -271,7 +430,7 @@ class VoxCPMModel(nn.Module):
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
audio_mask = audio_mask.unsqueeze(0).to(self.device)
target_text_length = len(self.text_tokenizer(target_text))
@@ -284,7 +443,7 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -310,7 +469,6 @@ class VoxCPMModel(nn.Module):
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
yield decode_audio
@torch.inference_mode()
@@ -327,13 +485,11 @@ class VoxCPMModel(nn.Module):
prompt_wav_path: prompt audio path (required)
Returns:
prompt_cache: dict with text tokens and audio features
prompt_cache: dict with prompt_text (raw text) and audio features.
Text tokenization will be done during generation for consistency.
"""
if not prompt_text or not prompt_wav_path:
raise ValueError("prompt_text and prompt_wav_path are required")
# build text tokens
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
# load audio
audio, sr = torchaudio.load(prompt_wav_path)
@@ -346,7 +502,9 @@ class VoxCPMModel(nn.Module):
patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
padding_size = patch_len - audio.size(1) % patch_len
audio = torch.nn.functional.pad(audio, (padding_size, 0))
# extract audio features
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
@@ -356,10 +514,9 @@ class VoxCPMModel(nn.Module):
-1,
self.patch_size,
).permute(1, 2, 0) # (D, T, P)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
# build prompt cache
# build prompt cache - only save raw text and audio features
prompt_cache = {
"text_token": text_token,
"prompt_text": prompt_text,
"audio_feat": audio_feat,
}
@@ -369,7 +526,7 @@ class VoxCPMModel(nn.Module):
def merge_prompt_cache(
self,
original_cache: dict,
new_text_token: torch.Tensor,
new_text: str,
new_audio_feat: torch.Tensor,
):
"""
@@ -377,38 +534,42 @@ class VoxCPMModel(nn.Module):
Args:
original_cache: original prompt cache
new_text_token: newly generated text tokens
new_text: newly generated text
new_audio_feat: newly generated audio features
Returns:
merged_cache: merged cache
merged_cache: merged cache with prompt_text and audio_feat
"""
if original_cache is None:
return {
"text_token": new_text_token,
"prompt_text": new_text,
"audio_feat": new_audio_feat,
}
original_text_token = original_cache["text_token"]
original_prompt_text = original_cache["prompt_text"]
original_audio_feat = original_cache["audio_feat"]
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
# Merge text by concatenation
merged_prompt_text = original_prompt_text + new_text
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
# build new cache
merged_cache = {
"text_token": merged_text_token,
"prompt_text": merged_prompt_text,
"audio_feat": merged_audio_feat,
}
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
@torch.inference_mode()
def _generate_with_prompt_cache(
self,
@@ -449,14 +610,14 @@ class VoxCPMModel(nn.Module):
retry_badcase = False
# get prompt from cache
if prompt_cache is None:
prompt_text_token = torch.empty(0, dtype=torch.int32)
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
text = target_text
else:
prompt_text_token = prompt_cache["text_token"]
prompt_audio_feat = prompt_cache["audio_feat"]
# build target text tokens
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
prompt_text = prompt_cache["prompt_text"]
text = prompt_text + target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat(
[
text_token,
@@ -468,6 +629,8 @@ class VoxCPMModel(nn.Module):
],
dim=-1,
)
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
audio_length = prompt_audio_feat.size(0)
text_length = text_token.shape[0]
@@ -484,7 +647,7 @@ class VoxCPMModel(nn.Module):
text_token = text_token.unsqueeze(0).to(self.device)
text_mask = text_mask.unsqueeze(0).to(self.device)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
audio_mask = audio_mask.unsqueeze(0).to(self.device)
# run inference
@@ -497,7 +660,7 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -526,7 +689,6 @@ class VoxCPMModel(nn.Module):
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
yield (
decode_audio,
@@ -552,6 +714,7 @@ class VoxCPMModel(nn.Module):
inference_timesteps: int = 10,
cfg_value: float = 2.0,
streaming: bool = False,
streaming_prefix_len: int = 3,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""Core inference method for audio generation.
@@ -624,7 +787,7 @@ class VoxCPMModel(nn.Module):
1, 2
) # [b, p, d]
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.enc_to_lm_proj(curr_embed)
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
@@ -632,8 +795,9 @@ class VoxCPMModel(nn.Module):
if streaming:
# return the last three predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
@@ -652,35 +816,138 @@ class VoxCPMModel(nn.Module):
if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True):
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae = AudioVAE()
audio_vae_config = getattr(config, 'audio_vae_config', None)
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
vae_state_dict = torch.load(
os.path.join(path, "audiovae.pth"),
map_location="cpu",
weights_only=True,
)["state_dict"]
model = cls(config, tokenizer, audio_vae)
lm_dtype = get_dtype(config.dtype)
model = model.to(lm_dtype)
model = cls(config, tokenizer, audio_vae, lora_config)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
else: # training mode
for name, param in model.named_parameters():
if "audio_vae" in name: # freeze VAE weights
param.requires_grad = False
continue
if lora_config is not None:
if "lora" not in name: # freeze non-LoRA weights
param.requires_grad = False
model.audio_vae = model.audio_vae.to(torch.float32)
model_state_dict = torch.load(
os.path.join(path, "pytorch_model.bin"),
map_location="cpu",
weights_only=True,
)["state_dict"]
# Try to load from safetensors first, fallback to pytorch_model.bin
safetensors_path = os.path.join(path, "model.safetensors")
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading model from safetensors: {safetensors_path}")
model_state_dict = load_file(safetensors_path)
elif os.path.exists(pytorch_model_path):
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
checkpoint = torch.load(
pytorch_model_path,
map_location="cpu",
weights_only=True,
)
model_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
)
for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val
model.load_state_dict(model_state_dict, strict=True)
# LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
model.load_state_dict(model_state_dict, strict=False)
if training:
return model
return model.to(model.device).eval().optimize(disable=not optimize)
# ------------------------------------------------------------------ #
# LoRA Weight Management
# ------------------------------------------------------------------ #
def _iter_lora_modules(self):
"""Iterate over all LoRA modules."""
from ..modules.layers.lora import LoRALinear
for module in self.modules():
if isinstance(module, LoRALinear):
yield module
def load_lora_weights(self, lora_path: str, device: str = None):
"""
Load LoRA weights from file, supports calling after torch.compile.
Uses named_parameters() to handle compile's _orig_mod wrapper.
Supports both safetensors and pytorch formats.
Args:
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
device: Target device, defaults to model's current device
Returns:
tuple: (loaded_keys, skipped_keys)
"""
from pathlib import Path
device = device or self.device
lora_path = Path(lora_path)
# Try safetensors first, then fallback to .ckpt
if lora_path.is_dir():
safetensors_file = lora_path / "lora_weights.safetensors"
ckpt_file = lora_path / "lora_weights.ckpt"
else:
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
# Load from safetensors if available
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
elif ckpt_file and ckpt_file.exists():
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
)
# Build param mapping (handle torch.compile's _orig_mod prefix)
model_params = dict(self.named_parameters())
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
loaded_keys, skipped_keys = [], []
for key, value in state_dict.items():
target_key = key if key in model_params else key_mapping.get(key)
if target_key:
model_params[target_key].data.copy_(value.to(device))
loaded_keys.append(key)
else:
skipped_keys.append(key)
return loaded_keys, skipped_keys
def set_lora_enabled(self, enabled: bool):
"""Enable/disable all LoRA layers."""
for module in self._iter_lora_modules():
module.set_enabled(enabled)
def reset_lora_weights(self):
"""Reset all LoRA weights (A: kaiming, B: zeros), effectively unloading LoRA."""
for module in self._iter_lora_modules():
module.reset_lora_parameters()
def get_lora_state_dict(self) -> dict:
"""Get all LoRA parameters (lora_A/lora_B)."""
return {name: param.data.clone()
for name, param in self.named_parameters()
if "lora_" in name}

View File

@@ -1 +1 @@
from .audio_vae import AudioVAE
from .audio_vae import AudioVAE, AudioVAEConfig

View File

@@ -1,11 +1,12 @@
import math
from typing import List, Union
from typing import List, Union, Optional
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from pydantic import BaseModel
def WNConv1d(*args, **kwargs):
@@ -266,6 +267,17 @@ class CausalDecoder(nn.Module):
return self.model(x)
class AudioVAEConfig(BaseModel):
encoder_dim: int = 128
encoder_rates: List[int] = [2, 5, 8, 8]
latent_dim: int = 64
decoder_dim: int = 1536
decoder_rates: List[int] = [8, 8, 5, 2]
depthwise: bool = True
sample_rate: int = 16000
use_noise_block: bool = False
class AudioVAE(nn.Module):
"""
Args:
@@ -273,17 +285,23 @@ class AudioVAE(nn.Module):
def __init__(
self,
encoder_dim: int = 128,
encoder_rates: List[int] = [2, 5, 8, 8],
latent_dim: int = 64,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 5, 2],
depthwise: bool = True,
sample_rate: int = 16000,
use_noise_block: bool = False,
config: Optional[AudioVAEConfig] = None,
):
        # If no config is provided, use the default configuration
if config is None:
config = AudioVAEConfig()
super().__init__()
encoder_dim = config.encoder_dim
encoder_rates = config.encoder_rates
latent_dim = config.latent_dim
decoder_dim = config.decoder_dim
decoder_rates = config.decoder_rates
depthwise = config.depthwise
sample_rate = config.sample_rate
use_noise_block = config.use_noise_block
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim

View File

@@ -0,0 +1,133 @@
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoRALinear(nn.Module):
"""
    LoRA linear layer: holds weight/bias directly, keeping the same state_dict key layout as nn.Linear.
    state_dict layout:
    - weight: original weight (same as nn.Linear)
    - bias: original bias (same as nn.Linear)
    - lora_A: LoRA low-rank matrix A
    - lora_B: LoRA low-rank matrix B
    The benefit of this design: pretrained weights load without any key remapping.
"""
def __init__(
self,
base: nn.Linear,
r: int,
alpha: float = 1.0,
dropout: float = 0.0,
):
super().__init__()
assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."
self.in_features = base.in_features
self.out_features = base.out_features
self.r = r
self.alpha = alpha
self._base_scaling = alpha / r if r > 0 else 0.0
        # Store scaling in a buffer so changing its value does not trigger a torch.compile recompilation
        # persistent=False keeps it out of the state_dict, avoiding a missing key on load
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
        # Hold weight and bias directly (taken over from the original Linear)
        self.weight = base.weight
        self.bias = base.bias  # may be None
        # LoRA parameters
if r > 0:
self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
else:
self.register_parameter("lora_A", None)
self.register_parameter("lora_B", None)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base linear computation
result = F.linear(x, self.weight, self.bias)
if self.r <= 0 or self.lora_A is None:
return result
# LoRA: result + dropout(x @ A^T @ B^T) * scaling
lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
return result + self.dropout(lora_out) * self.scaling
def reset_lora_parameters(self):
"""重置 LoRA 参数到初始状态"""
if self.r > 0 and self.lora_A is not None:
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def set_enabled(self, enabled: bool):
"""启用/禁用 LoRA通过 scaling 控制,兼容 torch.compile"""
# 使用 fill_ 原地修改 buffer 值,不会触发重编译
self.scaling.fill_(self._base_scaling if enabled else 0.0)
@property
def enabled(self) -> bool:
return self.scaling.item() != 0.0
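# ------------------------------------------------------------------ #
# Sanity-check sketch (illustrative, not part of the library): right after
# wrapping, the adapter is a no-op because lora_B starts at zeros, so the
# LoRALinear output matches the wrapped nn.Linear exactly.
# ------------------------------------------------------------------ #
def _lora_noop_check():
    torch.manual_seed(0)
    base = nn.Linear(16, 8)
    x = torch.randn(4, 16)
    ref = base(x)                         # output before wrapping
    layer = LoRALinear(base, r=4, alpha=8.0)
    assert torch.allclose(layer(x), ref)  # lora_B == 0 => identical output
    layer.set_enabled(False)              # scaling buffer -> 0.0
    assert torch.allclose(layer(x), ref)  # still the base path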
def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
"""
    Given a full name like 'layers.0.self_attn.q_proj', return its parent module (i.e., the module one level above q_proj).
"""
parts = name.split(".")
if len(parts) == 1:
return root
parent = root
for p in parts[:-1]:
if not hasattr(parent, p):
return None
parent = getattr(parent, p)
return parent
def apply_lora_to_named_linear_modules(
root: nn.Module,
*,
target_submodule_names: list[str],
r: int,
alpha: float,
dropout: float,
) -> None:
"""
    Inject LoRA into every nn.Linear under the given module whose name ends
    with one of target_submodule_names.
    For example, with target_submodule_names=["q_proj", "v_proj"], every
    nn.Linear named *.q_proj / *.v_proj is replaced by a LoRALinear.
"""
for full_name, module in list(root.named_modules()):
if not isinstance(module, nn.Linear):
continue
short_name = full_name.split(".")[-1]
if short_name not in target_submodule_names:
continue
parent = _get_parent_module(root, full_name)
if parent is None:
continue
        # Replace the original Linear with a LoRALinear
lora_layer = LoRALinear(
base=module,
r=r,
alpha=alpha,
dropout=dropout,
)
setattr(parent, short_name, lora_layer)
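# ------------------------------------------------------------------ #
# Injection sketch (illustrative): wrap the q_proj/v_proj linears of a toy
# module. Only the named targets are replaced; k_proj stays a plain
# nn.Linear. ToyAttn is hypothetical, not a VoxCPM component.
# ------------------------------------------------------------------ #
def _injection_demo():
    class ToyAttn(nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj = nn.Linear(32, 32)
            self.k_proj = nn.Linear(32, 32)
            self.v_proj = nn.Linear(32, 32)

    attn = ToyAttn()
    apply_lora_to_named_linear_modules(
        attn, target_submodule_names=["q_proj", "v_proj"], r=8, alpha=16.0, dropout=0.0
    )
    assert isinstance(attn.q_proj, LoRALinear)
    assert not isinstance(attn.k_proj, LoRALinear)
    # Typical fine-tuning setup: only lora_A / lora_B stay trainable.
    trainable = sorted(n for n, _ in attn.named_parameters() if "lora_" in n)
    assert trainable == ["q_proj.lora_A", "q_proj.lora_B", "v_proj.lora_A", "v_proj.lora_B"]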

View File

@@ -1,20 +1,29 @@
from typing import List, Tuple
import torch
from typing import List
from .local_dit import VoxCPMLocDiT
import math
import torch.nn.functional as F
from torch.func import jvp
from pydantic import BaseModel
from .local_dit import VoxCPMLocDiT
class CfmConfig(BaseModel):
sigma_min: float = 1e-06
sigma_min: float = 1e-6
solver: str = "euler"
t_scheduler: str = "log-norm"
training_cfg_rate: float = 0.1
inference_cfg_rate: float = 1.0
reg_loss_type: str = "l1"
ratio_r_neq_t_range: Tuple[float, float] = (0.25, 0.75)
noise_cond_prob_range: Tuple[float, float] = (0.0, 0.0)
noise_cond_scale: float = 0.0
class UnifiedCFM(torch.nn.Module):
def __init__(
self,
in_channels,
in_channels: int,
cfm_params: CfmConfig,
estimator: VoxCPMLocDiT,
mean_mode: bool = False,
@@ -23,12 +32,21 @@ class UnifiedCFM(torch.nn.Module):
self.solver = cfm_params.solver
self.sigma_min = cfm_params.sigma_min
self.t_scheduler = cfm_params.t_scheduler
self.training_cfg_rate = cfm_params.training_cfg_rate
self.inference_cfg_rate = cfm_params.inference_cfg_rate
self.reg_loss_type = cfm_params.reg_loss_type
self.ratio_r_neq_t_range = cfm_params.ratio_r_neq_t_range
self.noise_cond_prob_range = cfm_params.noise_cond_prob_range
self.noise_cond_scale = cfm_params.noise_cond_scale
self.in_channels = in_channels
self.mean_mode = mean_mode
        # Swap in a different estimator architecture here if needed
self.estimator = estimator
# ------------------------------------------------------------------ #
# Inference
# ------------------------------------------------------------------ #
@torch.inference_mode()
def forward(
self,
@@ -41,33 +59,25 @@ class UnifiedCFM(torch.nn.Module):
sway_sampling_coef: float = 1.0,
use_cfg_zero_star: bool = True,
):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
n_timesteps (int): number of diffusion steps
            cond: prefix condition features (see solve_euler)
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
Returns:
            sample: generated latent patch
                shape: (batch_size, in_channels, patch_size)
"""
b, c = mu.shape
b, _ = mu.shape
t = patch_size
z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
# Sway sampling strategy
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)
return self.solve_euler(
x=z,
t_span=t_span,
mu=mu,
cond=cond,
cfg_value=cfg_value,
use_cfg_zero_star=use_cfg_zero_star,
)
def optimized_scale(self, positive_flat, negative_flat):
def optimized_scale(self, positive_flat: torch.Tensor, negative_flat: torch.Tensor):
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
st_star = dot_product / squared_norm
return st_star
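    # Note (illustrative): optimized_scale returns the CFG-Zero*-style
    # least-squares coefficient st* = <pos, neg> / (||neg||^2 + 1e-8) per
    # batch row, i.e. the scale that best explains the conditional
    # prediction as a multiple of the unconditional one. For example, if
    # positive_flat were exactly 2 * negative_flat, st_star would be ~2.0.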
@@ -80,24 +90,13 @@ class UnifiedCFM(torch.nn.Module):
cfg_value: float = 1.0,
use_cfg_zero_star: bool = True,
):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
cond: condition -- prefix prompt
cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
"""
t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
sol = []
zero_init_steps = max(1, int(len(t_span) * 0.04))
for step in range(1, len(t_span)):
if use_cfg_zero_star and step <= zero_init_steps:
dphi_dt = 0.
dphi_dt = torch.zeros_like(x)
else:
# Classifier-Free Guidance inference introduced in VoiceBox
b = x.size(0)
@@ -105,7 +104,7 @@ class UnifiedCFM(torch.nn.Module):
mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2 * b, self.in_channels, cond.size(2)], device=x.device, dtype=x.dtype)
x_in[:b], x_in[b:] = x, x
mu_in[:b] = mu
t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
@@ -135,3 +134,98 @@ class UnifiedCFM(torch.nn.Module):
dt = t - t_span[step + 1]
return sol[-1]
# ------------------------------------------------------------------ #
# Training loss
# ------------------------------------------------------------------ #
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
weights = 1.0 / ((losses + epsilon).pow(p))
if mask is not None:
weights = weights * mask
return weights.detach()
def sample_r_t(self, x: torch.Tensor, mu: float = -0.4, sigma: float = 1.0, ratio_r_neq_t: float = 0.0):
batch_size = x.shape[0]
if self.t_scheduler == "log-norm":
s_r = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
s_t = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
r = torch.sigmoid(s_r)
t = torch.sigmoid(s_t)
elif self.t_scheduler == "uniform":
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
else:
raise ValueError(f"Unsupported t_scheduler: {self.t_scheduler}")
mask = torch.rand(batch_size, device=x.device, dtype=x.dtype) < ratio_r_neq_t
r, t = torch.where(
mask,
torch.stack([torch.min(r, t), torch.max(r, t)], dim=0),
torch.stack([t, t], dim=0),
)
return r.squeeze(), t.squeeze()
def compute_loss(
self,
x1: torch.Tensor,
mu: torch.Tensor,
cond: torch.Tensor | None = None,
tgt_mask: torch.Tensor | None = None,
progress: float = 0.0,
):
b, _, _ = x1.shape
if self.training_cfg_rate > 0:
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
mu = mu * cfg_mask.view(-1, 1)
if cond is None:
cond = torch.zeros_like(x1)
noisy_mask = torch.rand(b, device=x1.device) > (
1.0
- (
self.noise_cond_prob_range[0]
+ progress * (self.noise_cond_prob_range[1] - self.noise_cond_prob_range[0])
)
)
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
ratio_r_neq_t = (
self.ratio_r_neq_t_range[0]
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
if self.mean_mode
else 0.0
)
r, t = self.sample_r_t(x1, ratio_r_neq_t=ratio_r_neq_t)
r_ = r.detach().clone()
t_ = t.detach().clone()
z = torch.randn_like(x1)
y = (1 - t_.view(-1, 1, 1)) * x1 + t_.view(-1, 1, 1) * z
v = z - x1
def model_fn(z_sample, r_sample, t_sample):
return self.estimator(z_sample, mu, t_sample, cond, dt=t_sample - r_sample)
if self.mean_mode:
v_r = torch.zeros_like(r)
v_t = torch.ones_like(t)
from torch.backends.cuda import sdp_kernel
with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
u_pred, dudt = jvp(model_fn, (y, r, t), (v, v_r, v_t))
u_tgt = v - (t_ - r_).view(-1, 1, 1) * dudt
else:
u_pred = model_fn(y, r, t)
u_tgt = v
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
if tgt_mask is not None:
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
loss = (weights * losses).sum() / torch.sum(tgt_mask)
else:
loss = losses.mean()
return loss
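# ------------------------------------------------------------------ #
# Scheduler note (illustrative, standalone): the "log-norm" branch of
# sample_r_t draws t = sigmoid(s) with s ~ N(mu=-0.4, sigma=1.0), which
# concentrates timesteps around sigmoid(-0.4) ≈ 0.40. A quick check:
#
#   import torch
#   s = torch.randn(100_000) - 0.4
#   t = torch.sigmoid(s)
#   print(t.mean())                   # ≈ 0.42
#   print((t < 0.5).float().mean())   # ≈ 0.66: most mass below t = 0.5
# ------------------------------------------------------------------ #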

View File

@@ -153,7 +153,12 @@ class MiniCPMAttention(nn.Module):
cos, sin = position_emb
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# ref: https://github.com/pytorch/pytorch/issues/163597
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
# ref: https://github.com/pytorch/pytorch/issues/163597
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
query_states = query_states.contiguous()
key_cache = key_cache.contiguous()
value_cache = value_cache.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_cache,

View File

@@ -0,0 +1,28 @@
"""
Training utilities for VoxCPM fine-tuning.
This package mirrors the training mechanics used in the minicpm-audio
tooling while relying solely on local audio-text datasets managed via
the HuggingFace ``datasets`` library.
"""
from .accelerator import Accelerator
from .tracker import TrainingTracker
from .data import (
load_audio_text_datasets,
HFVoxCPMDataset,
build_dataloader,
BatchProcessor,
)
from .state import TrainingState
__all__ = [
"Accelerator",
"TrainingTracker",
"HFVoxCPMDataset",
"BatchProcessor",
"TrainingState",
"load_audio_text_datasets",
"build_dataloader",
]

View File

@@ -0,0 +1,166 @@
from __future__ import annotations
import contextlib
import os
import random
import typing
import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data
from torch.nn.parallel import DistributedDataParallel
class Accelerator:
"""
Simplified accelerator that mirrors the behaviour of the minicpm-audio
training utilities. It initializes a distributed process group when
``torchrun`` is used and exposes helpers for AMP, gradient scaling and
preparing models/dataloaders for DDP.
"""
def __init__(self, amp: bool = False, seed: int = 42):
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
if self.world_size > 1 and not dist.is_initialized():
dist.init_process_group("nccl", init_method="env://")
self.rank = dist.get_rank() if dist.is_initialized() else 0
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
self.amp = amp
# Set random seed to ensure model initialization consistency
self._set_seed(seed)
class DummyScaler:
def step(self, optimizer):
optimizer.step()
def scale(self, loss):
return loss
def unscale_(self, optimizer):
return optimizer
def update(self):
pass
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
self.device_ctx = (
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
self._ddp_model = None # For no_sync support
def _set_seed(self, seed: int):
"""Set random seed to ensure model initialization consistency across multiple GPUs"""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def __enter__(self):
if self.device_ctx is not None:
self.device_ctx.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.device_ctx is not None:
self.device_ctx.__exit__(exc_type, exc_value, traceback)
def barrier(self):
"""Synchronize all processes"""
if dist.is_initialized():
dist.barrier()
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.AVG):
"""All-reduce tensor across processes"""
if dist.is_initialized():
dist.all_reduce(tensor, op=op)
return tensor
# ------------------------------------------------------------------ #
# Model helpers
# ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs):
        if hasattr(model, 'device'):  # keep the model's device attribute in sync so tensors it creates land on the right device
model.device = self.device
model = model.to(self.device)
if self.world_size > 1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
self._ddp_model = model # Save DDP model reference for no_sync support
return model
@contextlib.contextmanager
def no_sync(self):
"""
        Context manager that skips gradient synchronization during gradient
        accumulation. Use it on every micro-batch except the last one.
"""
if self._ddp_model is not None:
with self._ddp_model.no_sync():
yield
else:
yield
@property
def device(self):
if torch.cuda.is_available():
return torch.device("cuda", self.local_rank)
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
# ------------------------------------------------------------------ #
# AMP helpers
# ------------------------------------------------------------------ #
def autocast(self, *args, **kwargs):
return torch.amp.autocast("cuda", enabled=self.amp, *args, **kwargs)
def backward(self, loss: torch.Tensor):
self.scaler.scale(loss).backward()
def step(self, optimizer: torch.optim.Optimizer):
self.scaler.step(optimizer)
def update(self):
self.scaler.update()
# ------------------------------------------------------------------ #
# Data helpers
# ------------------------------------------------------------------ #
def prepare_dataloader(
self,
dataset: typing.Iterable,
*,
batch_size: int,
num_workers: int = 0,
shuffle: bool = True,
collate_fn=None,
drop_last: bool = False,
) -> torch.utils.data.DataLoader:
if self.world_size > 1:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
)
shuffle = False
else:
sampler = None
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
drop_last=drop_last,
pin_memory=True,
)
@staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model
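# ------------------------------------------------------------------ #
# Loop sketch (illustrative): how the helpers above compose into a training
# step with gradient accumulation. `model`, `dataset`, and `collate` are
# placeholders, not VoxCPM objects; the Accelerator calls are the methods
# defined in this class.
# ------------------------------------------------------------------ #
def _train_loop_sketch(model, dataset, collate, accum_steps=4):
    acc = Accelerator(amp=True, seed=42)
    model = acc.prepare_model(model)
    loader = acc.prepare_dataloader(dataset, batch_size=8, collate_fn=collate)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for i, batch in enumerate(loader):
        sync = (i + 1) % accum_steps == 0
        ctx = contextlib.nullcontext() if sync else acc.no_sync()
        with ctx, acc.autocast():
            loss = model(**batch)  # placeholder forward returning a scalar loss
        acc.backward(loss / accum_steps)
        if sync:
            acc.step(optim)
            acc.update()
            optim.zero_grad(set_to_none=True)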

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import argbind
import yaml
from pathlib import Path
from typing import Dict, Any
def load_yaml_config(path: str | Path) -> Dict[str, Any]:
"""
Load a YAML configuration file into a dictionary suitable for argbind.
"""
path = Path(path)
with path.open("r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Configuration file {path} must contain a top-level mapping.")
return data
def parse_args_with_config(config_path: str | Path | None = None):
"""
Helper to unify CLI arguments and YAML configuration.
Usage mirrors minicpm-audio:
args = parse_args_with_config("conf/voxcpm/finetune.yml")
with argbind.scope(args):
...
"""
cli_args = argbind.parse_args()
if config_path is None:
return cli_args
yaml_args = load_yaml_config(config_path)
with argbind.scope(cli_args):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args)
return cli_args
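# ------------------------------------------------------------------ #
# Config sketch (illustrative): a minimal YAML consumed by the helper above.
# argbind resolves keys of the form "function.argument"; the keys below are
# hypothetical bindings, not the shipped finetune schema — see the repo's
# conf/ files for the real one.
#
#   # conf/example.yml
#   load_audio_text_datasets.train_manifest: data/train.jsonl
#   load_audio_text_datasets.sample_rate: 16000
#
#   args = parse_args_with_config("conf/example.yml")
#   with argbind.scope(args):
#       ...  # argbind-bound functions now pick up the YAML values
# ------------------------------------------------------------------ #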

214
src/voxcpm/training/data.py Normal file
View File

@@ -0,0 +1,214 @@
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import argbind
import torch
from datasets import Audio, Dataset, DatasetDict, load_dataset
from torch.utils.data import Dataset as TorchDataset
from ..model.voxcpm import VoxCPMConfig
from ..modules.audiovae import AudioVAE
from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_ID_COLUMN = "dataset_id"
@argbind.bind()
def load_audio_text_datasets(
train_manifest: str,
val_manifest: str = "",
text_column: str = DEFAULT_TEXT_COLUMN,
audio_column: str = DEFAULT_AUDIO_COLUMN,
dataset_id_column: str = DEFAULT_ID_COLUMN,
sample_rate: int = 16_000,
num_proc: int = 1,
) -> Tuple[Dataset, Optional[Dataset]]:
data_files = {"train": train_manifest}
if val_manifest:
data_files["validation"] = val_manifest
dataset_dict: DatasetDict = load_dataset("json", data_files=data_files)
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
            # Cast to Audio so decoding/resampling are handled during training.
            # For length calculation, prefer a raw path or a 'duration' field when
            # available: HF datasets do not compute durations for an Audio column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
if audio_column != DEFAULT_AUDIO_COLUMN:
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
if text_column != DEFAULT_TEXT_COLUMN:
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
if dataset_id_column and dataset_id_column in ds.column_names:
if dataset_id_column != DEFAULT_ID_COLUMN:
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
else:
ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
return ds
train_ds = prepare(dataset_dict["train"])
val_ds = prepare(dataset_dict["validation"]) if "validation" in dataset_dict else None
return train_ds, val_ds
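# ------------------------------------------------------------------ #
# Manifest sketch (illustrative): load_audio_text_datasets reads JSON-lines
# files via load_dataset("json", ...), so each line is one sample. Field
# names follow the defaults above; "duration" is optional but lets
# compute_sample_lengths (below) skip audio decoding.
#
#   {"audio": "wavs/0001.wav", "text": "Hello there.", "dataset_id": 0, "duration": 2.31}
#   {"audio": "wavs/0002.wav", "text": "A second sample.", "dataset_id": 0}
# ------------------------------------------------------------------ #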
def compute_sample_lengths(
ds: Dataset,
audio_vae_fps: int = 25,
patch_size: int = 1,
) -> List[int]:
"""
    Estimate the approximate packed sequence length (text + audio) of each
    sample, used to filter overly long samples.
    The logic matches AudioFeatureProcessingPacker / AudioVAE:
    - text length: len(text_ids)
    - audio length:
        duration(s) * audio_vae_fps -> approximate VAE frame count t_vae
        t_seq = ceil(t_vae / patch_size)
    - total sequence length is roughly: text_len + t_seq + 2
"""
lengths: List[int] = []
has_duration = "duration" in ds.column_names
for i in range(len(ds)):
item = ds[i]
text_len = len(item["text_ids"])
        # Audio duration (avoid decoding where possible; prefer the manifest's 'duration' column if present)
if has_duration:
duration = float(item["duration"])
else:
audio = item[DEFAULT_AUDIO_COLUMN]
duration = len(audio["array"]) / float(audio["sampling_rate"])
t_vae = math.ceil(duration * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2
lengths.append(total_len)
return lengths
class HFVoxCPMDataset(TorchDataset):
"""
Thin wrapper around a tokenized HuggingFace dataset that returns
PyTorch-friendly samples.
"""
def __init__(self, dataset: Dataset):
self.dataset = dataset
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx: int):
item = self.dataset[idx]
audio = item[DEFAULT_AUDIO_COLUMN]
return {
"text_ids": item["text_ids"],
"audio_array": audio["array"],
"audio_sampling_rate": audio["sampling_rate"],
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
"is_prompt": item.get("is_prompt", False),
}
@staticmethod
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
if not seqs:
return torch.empty(0)
max_len = max(seq.shape[0] for seq in seqs)
padded = []
for seq in seqs:
if seq.shape[0] < max_len:
pad_width = (0, max_len - seq.shape[0])
seq = torch.nn.functional.pad(seq, pad_width, value=pad_value)
padded.append(seq)
return torch.stack(padded)
@classmethod
def collate_fn(cls, batch: List[Dict]):
text_tensors = [torch.tensor(sample["text_ids"], dtype=torch.int32) for sample in batch]
audio_tensors = [torch.tensor(sample["audio_array"], dtype=torch.float32) for sample in batch]
dataset_ids = torch.tensor([sample["dataset_id"] for sample in batch], dtype=torch.int32)
is_prompts = [bool(sample.get("is_prompt", False)) for sample in batch]
text_padded = cls.pad_sequences(text_tensors, pad_value=-100)
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
return {
"text_tokens": text_padded,
"audio_tokens": audio_padded,
"task_ids": task_ids,
"dataset_ids": dataset_ids,
"is_prompts": is_prompts,
}
class BatchProcessor:
"""
Wraps ``AudioFeatureProcessingPacker`` so the training loop can mirror
the minicpm-audio mechanics.
"""
def __init__(
self,
*,
config: VoxCPMConfig,
audio_vae: AudioVAE,
dataset_cnt: int,
device: torch.device,
):
self.device = device
self.dataset_cnt = dataset_cnt
self.audio_vae = audio_vae
self.audio_vae.to(device)
self.packer = AudioFeatureProcessingPacker(
dataset_cnt=dataset_cnt,
max_len=config.max_length,
patch_size=config.patch_size,
feat_dim=config.feat_dim,
audio_vae=self.audio_vae,
)
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
audio_tokens = batch["audio_tokens"].to(self.device)
text_tokens = batch["text_tokens"].to(self.device)
task_ids = batch["task_ids"].to(self.device)
dataset_ids = batch["dataset_ids"].to(self.device)
packed = self.packer(
audio_tokens=audio_tokens,
text_tokens=text_tokens,
task_ids=task_ids,
dataset_ids=dataset_ids,
is_prompts=batch["is_prompts"],
)
return packed
def build_dataloader(
hf_dataset: Dataset,
*,
accelerator,
batch_size: int,
num_workers: int,
drop_last: bool = False,
) -> torch.utils.data.DataLoader:
torch_dataset = HFVoxCPMDataset(hf_dataset)
# Standard padding-based batching; Accelerator will attach DistributedSampler if needed.
return accelerator.prepare_dataloader(
torch_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
collate_fn=HFVoxCPMDataset.collate_fn,
drop_last=drop_last,
)
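# ------------------------------------------------------------------ #
# Wiring sketch (illustrative): how this module's pieces compose. The step
# that adds the "text_ids" column (tokenization) lives in the training
# script and is elided here; everything else uses the APIs defined above.
#
#   train_ds, _ = load_audio_text_datasets(train_manifest="data/train.jsonl")
#   # ... map the model tokenizer over train_ds to add "text_ids" ...
#   loader = build_dataloader(train_ds, accelerator=acc, batch_size=8, num_workers=4)
#   processor = BatchProcessor(config=cfg, audio_vae=vae, dataset_cnt=1, device=acc.device)
#   for batch in loader:
#       packed = processor(batch)  # text_tokens / audio_feats / masks for VoxCPM
# ------------------------------------------------------------------ #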

View File

@@ -0,0 +1,289 @@
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from einops import rearrange
class AudioFeatureProcessingPacker:
"""
Adapted from the minicpm-audio training utilities. It converts raw text and
audio tokens into the packed multimodal representation required by VoxCPM.
"""
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101
self.audio_end_id = 102
# unused now
self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104
self.text_eos_token_id = 2
self.patch_size = patch_size
self.patch_len = audio_vae.hop_length * self.patch_size
self.feat_dim = feat_dim
self.dataset_cnt = max(dataset_cnt, 1)
self.max_len = max_len
self.audio_vae = audio_vae
self.process_functions = {"tts": self.process_tts_data}
self.task_id_map = {"tts": 1}
self.id_to_task = {idx: usage for usage, idx in self.task_id_map.items()}
# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
@staticmethod
def _first_pad_position(tokens: torch.Tensor):
positions = (tokens == -100).nonzero(as_tuple=True)
if positions[0].numel() == 0:
return None
return int(positions[0][0])
def unpad_text_tokens(self, tokens: torch.Tensor):
pad_pos = self._first_pad_position(tokens)
return tokens if pad_pos is None else tokens[:pad_pos]
def unpad_audio_tokens(self, tokens: torch.Tensor):
pad_pos = self._first_pad_position(tokens)
return tokens if pad_pos is None else tokens[:pad_pos]
def encode_audio(self, wav: torch.Tensor):
"""
Encode raw waveform into latent features using AudioVAE.
AudioVAE.encode expects shape [B, 1, T'] and returns [B, D, T].
We then transpose to [B, T, D] to match downstream expectations.
"""
wav = wav.unsqueeze(0) # [1, T]
wav = wav.unsqueeze(1) # [1, 1, T]
wav_len = wav.size(-1)
if wav_len % self.patch_len != 0:
padding_size = self.patch_len - wav_len % self.patch_len
wav = torch.nn.functional.pad(wav, (0, padding_size))
with torch.no_grad():
z = self.audio_vae.encode(wav, self.audio_vae.sample_rate) # [1, D, T']
feat = z.transpose(1, 2) # [1, T', D]
return feat
# ------------------------------------------------------------------ #
# Main entry point
# ------------------------------------------------------------------ #
def __call__(
self,
audio_tokens: torch.Tensor,
text_tokens: torch.Tensor,
task_ids: torch.Tensor,
dataset_ids: torch.Tensor,
is_prompts: List[bool],
) -> Dict[str, torch.Tensor]:
"""
Padding-based batching: each sample in the input batch is processed
independently and then padded to a common length (capped by ``max_len``).
The result tensors all have shape [B, T, ...].
"""
device = audio_tokens.device
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
dataset_cnt = max(self.dataset_cnt, max_dataset_id + 1)
text_tokens_list: List[torch.Tensor] = []
audio_feats_list: List[torch.Tensor] = []
text_mask_list: List[torch.Tensor] = []
audio_mask_list: List[torch.Tensor] = []
loss_mask_list: List[torch.Tensor] = []
labels_list: List[torch.Tensor] = []
audio_task_ids_list: List[torch.Tensor] = []
audio_dataset_ids_list: List[torch.Tensor] = []
lengths: List[int] = []
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
):
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
unpad_text_token = self.unpad_text_tokens(text_token)
usage = self.id_to_task[task_id]
(
packed_text,
audio_feat,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
audio_duration_consumed[dataset_idx] += audio_duration
text_token_consumed[dataset_idx] += text_token_count
audio_task_id = torch.zeros_like(audio_mask)
audio_task_id[audio_mask == 1] = self.task_id_map[usage]
audio_dataset_id = torch.zeros_like(audio_mask)
audio_dataset_id[audio_mask == 1] = dataset_idx + 1
text_tokens_list.append(packed_text)
text_mask_list.append(text_mask)
audio_feats_list.append(audio_feat)
audio_mask_list.append(audio_mask)
loss_mask_list.append(loss_mask)
labels_list.append(labels)
audio_task_ids_list.append(audio_task_id)
audio_dataset_ids_list.append(audio_dataset_id)
lengths.append(packed_text.shape[0])
# Determine padded length per batch (cap by self.max_len)
if lengths:
max_len = min(self.max_len, max(lengths))
else:
max_len = self.max_len
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
def pad_3d(x: torch.Tensor) -> torch.Tensor:
# x: [T, P, D]
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.zeros(
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
)
return torch.cat([x, pad], dim=0)
if lengths:
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
audio_task_ids_batch = torch.stack(
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
)
audio_dataset_ids_batch = torch.stack(
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
)
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
position_ids_list = []
for L in lengths:
L_clip = min(L, max_len)
pos = torch.arange(0, L_clip, device=device)
if L_clip < max_len:
pad = torch.zeros(max_len - L_clip, dtype=pos.dtype, device=device)
pos = torch.cat([pos, pad], dim=0)
position_ids_list.append(pos)
position_ids = torch.stack(position_ids_list, dim=0)
else:
# Empty batch fallback (shouldn't really happen)
text_tokens_batch = torch.zeros((0, self.max_len), dtype=torch.int32, device=device)
text_mask_batch = torch.zeros_like(text_tokens_batch)
audio_feats_batch = torch.zeros(
(0, self.max_len, self.patch_size, self.feat_dim), dtype=torch.float32, device=device
)
audio_mask_batch = torch.zeros_like(text_tokens_batch)
loss_mask_batch = torch.zeros_like(text_tokens_batch)
labels_batch = torch.zeros_like(text_tokens_batch)
audio_task_ids_batch = torch.zeros_like(text_tokens_batch)
audio_dataset_ids_batch = torch.zeros_like(text_tokens_batch)
position_ids = torch.zeros_like(text_tokens_batch)
audio_duration_consumed = audio_duration_consumed.to(torch.long)
text_token_consumed = text_token_consumed.to(torch.long)
return {
"text_tokens": text_tokens_batch,
"audio_feats": audio_feats_batch,
"text_mask": text_mask_batch,
"audio_mask": audio_mask_batch,
"loss_mask": loss_mask_batch,
"position_ids": position_ids,
"labels": labels_batch,
"audio_task_ids": audio_task_ids_batch,
"audio_dataset_ids": audio_dataset_ids_batch,
"audio_duration_consumed": audio_duration_consumed,
"text_token_consumed": text_token_consumed,
}
# ------------------------------------------------------------------ #
# Feature extraction helpers
# ------------------------------------------------------------------ #
def extract_audio_feats(self, audio_data: torch.Tensor):
audio_feats = self.encode_audio(audio_data)
if audio_feats.size(1) % self.patch_size != 0:
audio_feats_ = audio_feats.transpose(1, 2)
padding = nn.functional.pad(audio_feats_, (0, self.patch_size - audio_feats.size(1) % self.patch_size))
audio_feats = padding.transpose(1, 2)
audio_duration = audio_feats.size(1) / 25
audio_feats = rearrange(audio_feats, "b (t p) c -> b t p c", p=self.patch_size)
return audio_feats, audio_duration
def process_tts_data(self, audio_token: torch.Tensor, text_token: torch.Tensor, is_prompt: bool = False):
text_token_info = torch.cat(
[
text_token,
torch.tensor(
[self.audio_prompt_start_id if is_prompt else self.audio_start_id],
dtype=torch.int32,
device=text_token.device,
),
],
dim=-1,
)
text_token_count = len(text_token)
text_length = text_token_info.shape[0]
audio_feat_info, audio_duration = self.extract_audio_feats(audio_token)
audio_feat_info = audio_feat_info.squeeze(0)
audio_length = audio_feat_info.shape[0]
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
text_token_info = torch.cat(
[
text_token_info,
text_pad_token,
torch.tensor(
[self.audio_prompt_end_id if is_prompt else self.audio_end_id],
dtype=torch.int32,
device=text_token.device,
),
]
)
audio_pad_feat = torch.zeros(
(text_length, self.patch_size, audio_feat_info.size(-1)),
dtype=torch.float32,
device=text_token.device,
)
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
text_token.device
)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
torch.int32
).to(text_token.device)
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
labels[-2] = 1
return (
text_token_info,
audio_feat_info,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
)
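# ------------------------------------------------------------------ #
# Layout sketch (illustrative): what process_tts_data returns for one
# sample, read off the masks above. T = text length incl. the start id,
# A = number of audio patches; labels[-2] == 1 marks the stop position
# on the final audio frame.
#
#   position:    [0 .. T-1]         [T .. T+A-1]         [T+A]
#   text token:  text + start_id    zeros                end_id
#   audio feat:  zero patches       VAE patches          zero patch
#   text_mask:   1 .............. 1 0 ................ 0 1
#   audio_mask:  0 .............. 0 1 ................ 1 0
#   loss_mask:   0 .............. 0 1 (0 if prompt)      0
#   labels:      0 .............. 0 0 ........... 0    1 0
# ------------------------------------------------------------------ #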

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class TrainingState:
"""
Container that mirrors the object returned in the minicpm-audio training
loop. It holds persistent references to the model, optimizer, scheduler,
dataloaders and tracker.
"""
generator: object
optimizer: object
scheduler: object
train_loader: object
val_loader: object
tracker: object
batch_processor: object

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import contextlib
import time
from pathlib import Path
from typing import Dict, Optional
class TrainingTracker:
"""
    Lightweight tracker inspired by the minicpm-audio training workflow.
It keeps track of the current global step, prints rank-aware messages,
optionally writes to TensorBoard via a provided writer, and stores progress
in a logfile for later inspection.
"""
def __init__(
self,
*,
writer=None,
log_file: Optional[str] = None,
rank: int = 0,
):
self.writer = writer
self.log_file = Path(log_file) if log_file else None
if self.log_file:
self.log_file.parent.mkdir(parents=True, exist_ok=True)
self.rank = rank
self.step = 0
# Record the time of the last log to calculate the interval
self._last_log_time: float | None = None
# ------------------------------------------------------------------ #
# Logging helpers
# ------------------------------------------------------------------ #
def print(self, message: str):
if self.rank == 0:
print(message, flush=True)
if self.log_file:
with self.log_file.open("a", encoding="utf-8") as f:
f.write(message + "\n")
def log_metrics(self, metrics: Dict[str, float], split: str):
if self.rank == 0:
now = time.time()
dt_str = ""
if self._last_log_time is not None:
dt = now - self._last_log_time
dt_str = f", log interval: {dt:.2f}s"
self._last_log_time = now
formatted = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")
if self.writer is not None:
for key, value in metrics.items():
if isinstance(value, (int, float)):
self.writer.add_scalar(f"{split}/{key}", value, self.step)
def done(self, split: str, message: str):
self.print(f"[{split}] {message}")
# ------------------------------------------------------------------ #
# State dict
# ------------------------------------------------------------------ #
def state_dict(self):
return {"step": self.step}
def load_state_dict(self, state):
self.step = int(state.get("step", 0))
# ------------------------------------------------------------------ #
# Context manager compatibility (for parity with minicpm-audio code)
# ------------------------------------------------------------------ #
@contextlib.contextmanager
def live(self):
yield
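# ------------------------------------------------------------------ #
# Usage sketch (illustrative): the tracker's logging surface. `writer` is
# any object with an add_scalar method (e.g. a TensorBoard SummaryWriter);
# passing None keeps the tracker print/logfile-only.
#
#   tracker = TrainingTracker(writer=None, log_file="logs/train.log", rank=0)
#   for step in range(100):
#       tracker.step = step
#       tracker.log_metrics({"loss": 0.123, "lr": 1e-4}, split="train")
#   tracker.done("train", "finished")
# ------------------------------------------------------------------ #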