Mirror of https://github.com/OpenBMB/VoxCPM (synced 2025-12-12 19:58:12 +00:00)
Compare commits: 40 commits
.gitignore (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
launch.json
__pycache__
voxcpm.egg-info
.DS_Store
README.md (215 lines changed)
@@ -1,14 +1,27 @@
|
||||
## 🎙️ VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning
|
||||
|
||||
|
||||
[](https://github.com/OpenBMB/VoxCPM/) [](https://huggingface.co/openbmb/VoxCPM-0.5B) [](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [](https://thuhcsi.github.io/VoxCPM/)
|
||||
[](https://github.com/OpenBMB/VoxCPM/) [](https://arxiv.org/abs/2509.24650)[](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [](https://openbmb.github.io/VoxCPM-demopage)
|
||||
|
||||
#### VoxCPM1.5 Model Weights
|
||||
|
||||
[](https://huggingface.co/openbmb/VoxCPM1.5) [](https://modelscope.cn/models/OpenBMB/VoxCPM1.5)
|
||||
|
||||
|
||||
|
||||
<div align="center">
|
||||
<img src="assets/voxcpm_logo.png" alt="VoxCPM Logo" width="40%">
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
👋 Contact us on [WeChat](assets/wechat.png)
|
||||
|
||||
</div>
|
||||
|
||||
## News
|
||||
* [2025.12.05] 🎉 🎉 🎉 We Open Source the VoxCPM1.5 [weights](https://huggingface.co/openbmb/VoxCPM1.5)! The model now supports both full-parameter fine-tuning and efficient LoRA fine-tuning, empowering you to create your own tailored version. See [Release Notes](docs/release_note.md) for details.
|
||||
* [2025.09.30] 🔥 🔥 🔥 We Release VoxCPM [Technical Report](https://arxiv.org/abs/2509.24650)!
|
||||
* [2025.09.16] 🔥 🔥 🔥 We Open Source the VoxCPM-0.5B [weights](https://huggingface.co/openbmb/VoxCPM-0.5B)!
|
||||
* [2025.09.16] 🎉 🎉 🎉 We Provide the [Gradio PlayGround](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) for VoxCPM-0.5B, try it now!
|
||||
|
||||
@@ -25,15 +38,22 @@ Unlike mainstream approaches that convert speech to discrete tokens, VoxCPM uses
|
||||
|
||||
### 🚀 Key Features
|
||||
- **Context-Aware, Expressive Speech Generation** - VoxCPM comprehends text to infer and generate appropriate prosody, delivering speech with remarkable expressiveness and natural flow. Trained on a massive 1.8-million-hour bilingual corpus, it spontaneously adapts its speaking style to the content, producing highly fitting vocal expression.
|
||||
- **True-to-Life Voice Cloning** - With only a short reference audio clip, VoxCPM performs accurate zero-shot voice cloning, capturing not only the speaker’s timbre but also fine-grained characteristics such as accent, emotional tone, rhythm, and pacing to create a faithful and natural replica.
|
||||
- **True-to-Life Voice Cloning** - With only a short reference audio clip, VoxCPM performs accurate zero-shot voice cloning, capturing not only the speaker's timbre but also fine-grained characteristics such as accent, emotional tone, rhythm, and pacing to create a faithful and natural replica.
|
||||
- **High-Efficiency Synthesis** - VoxCPM supports streaming synthesis with a Real-Time Factor (RTF) as low as 0.17 on a consumer-grade NVIDIA RTX 4090 GPU, making it suitable for real-time applications.
|
||||
|
||||
### 📦 Model Versions
|
||||
See [Release Notes](docs/release_note.md) for details
|
||||
- **VoxCPM1.5** (Latest):
|
||||
- Model Params: 750M
|
||||
- Sampling rate of AudioVAE: 44100
|
||||
- Token rate in LM Backbone: 6.25Hz (patch-size=4)
|
||||
- RTF in a single NVIDIA-RTX 4090 GPU: ~0.15
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
- **VoxCPM-0.5B** (Original):
|
||||
- Model Params: 600M
|
||||
- Sampling rate of AudioVAE: 16000
|
||||
- Token rate in LM Backbone: 12.5Hz (patch-size=2)
|
||||
- RTF in a single NVIDIA-RTX 4090 GPU: 0.17
|
||||
|
||||
|
||||
|
||||
@@ -45,12 +65,18 @@ pip install voxcpm
|
||||
```
|
||||
### 1. Model Download (Optional)
|
||||
By default, the model is downloaded automatically the first time you run the script, but you can also download it in advance.
|
||||
- Download VoxCPM-0.5B
|
||||
- Download VoxCPM1.5
|
||||
```
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("openbmb/VoxCPM-0.5B",local_files_only=local_files_only)
|
||||
snapshot_download("openbmb/VoxCPM1.5")
|
||||
```
|
||||
- Download ZipEnhancer and SenseVoice-Small. We use ZipEnhancer to enhance speech prompts and SenseVoice-Small for speech prompt ASR in the web demo.
|
||||
|
||||
- Or Download VoxCPM-0.5B
|
||||
```
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("openbmb/VoxCPM-0.5B")
|
||||
```
|
||||
- Download ZipEnhancer and SenseVoice-Small. We use ZipEnhancer to enhance speech prompts and SenseVoice-Small for speech prompt ASR in the web demo.
|
||||
```
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download('iic/speech_zipenhancer_ans_multiloss_16k_base')
|
||||
@@ -60,25 +86,39 @@ By default, when you first run the script, the model will be downloaded automati
|
||||
### 2. Basic Usage
|
||||
```python
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
from voxcpm import VoxCPM
|
||||
|
||||
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
|
||||
model = VoxCPM.from_pretrained("openbmb/VoxCPM1.5")
|
||||
|
||||
# Non-streaming
|
||||
wav = model.generate(
|
||||
text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
|
||||
prompt_wav_path=None, # optional: path to a prompt speech for voice cloning
|
||||
prompt_text=None, # optional: reference text
|
||||
cfg_value=2.0, # LM guidance on LocDiT; higher gives better adherence to the prompt but may sound strained
|
||||
inference_timesteps=10, # LocDiT inference timesteps; higher for better quality, lower for faster speed
|
||||
normalize=True, # enable external TN tool
|
||||
denoise=True, # enable external Denoise tool
|
||||
normalize=False, # enable external TN tool, but will disable native raw text support
|
||||
denoise=False, # enable external Denoise tool, but it may cause some distortion and restrict the sampling rate to 16kHz
|
||||
retry_badcase=True, # enable retrying for bad cases (e.g., generation that fails to stop)
|
||||
retry_badcase_max_times=3, # maximum number of retries
|
||||
retry_badcase_ratio_threshold=6.0, # length-ratio threshold for bad-case detection (simple but effective); increase it for slow-paced speech
|
||||
)
|
||||
|
||||
sf.write("output.wav", wav, 16000)
|
||||
sf.write("output.wav", wav, model.tts_model.sample_rate)
|
||||
print("saved: output.wav")
|
||||
|
||||
# Streaming
|
||||
chunks = []
|
||||
for chunk in model.generate_streaming(
|
||||
text = "Streaming text to speech is easy with VoxCPM!",
|
||||
# supports same args as above
|
||||
):
|
||||
chunks.append(chunk)
|
||||
wav = np.concatenate(chunks)
|
||||
|
||||
sf.write("output_streaming.wav", wav, model.tts_model.sample_rate)
|
||||
print("saved: output_streaming.wav")
|
||||
```
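For live playback instead of writing a file, the chunks from `generate_streaming` can be fed straight into an audio output stream. The sketch below is illustrative only and assumes the third-party `sounddevice` package, which is not a VoxCPM dependency:

```python
# Hypothetical real-time playback of the streaming generator above.
# Assumes the third-party `sounddevice` package (not a VoxCPM dependency).
import numpy as np
import sounddevice as sd

sr = model.tts_model.sample_rate
with sd.OutputStream(samplerate=sr, channels=1, dtype="float32") as stream:
    for chunk in model.generate_streaming(text="Streaming text to speech is easy with VoxCPM!"):
        # each chunk is a 1-D float waveform; reshape to (frames, channels) for the stream
        stream.write(np.asarray(chunk, dtype=np.float32).reshape(-1, 1))
```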
|
||||
|
||||
### 3. CLI Usage
|
||||
@@ -94,7 +134,14 @@ voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, desi
|
||||
--prompt-audio path/to/voice.wav \
|
||||
--prompt-text "reference transcript" \
|
||||
--output out.wav \
|
||||
--denoise
|
||||
# --denoise
|
||||
|
||||
# (Optional) Voice cloning (reference audio + transcript file)
|
||||
voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
|
||||
--prompt-audio path/to/voice.wav \
|
||||
--prompt-file "/path/to/text-file" \
|
||||
--output out.wav \
|
||||
# --denoise
|
||||
|
||||
# 3) Batch processing (one text per line)
|
||||
voxcpm --input examples/input.txt --output-dir outs
|
||||
@@ -102,7 +149,7 @@ voxcpm --input examples/input.txt --output-dir outs
|
||||
voxcpm --input examples/input.txt --output-dir outs \
|
||||
--prompt-audio path/to/voice.wav \
|
||||
--prompt-text "reference transcript" \
|
||||
--denoise
|
||||
# --denoise
|
||||
|
||||
# 4) Inference parameters (quality/speed)
|
||||
voxcpm --text "..." --output out.wav \
|
||||
@@ -113,7 +160,7 @@ voxcpm --text "..." --output out.wav \
|
||||
voxcpm --text "..." --output out.wav --model-path /path/to/VoxCPM_model_dir
|
||||
# Or from Hugging Face (auto download/cache)
|
||||
voxcpm --text "..." --output out.wav \
|
||||
--hf-model-id openbmb/VoxCPM-0.5B --cache-dir ~/.cache/huggingface --local-files-only
|
||||
--hf-model-id openbmb/VoxCPM1.5 --cache-dir ~/.cache/huggingface --local-files-only
|
||||
|
||||
# 6) Denoiser control
|
||||
voxcpm --text "..." --output out.wav \
|
||||
@@ -128,104 +175,51 @@ python -m voxcpm.cli --help
|
||||
|
||||
You can start the web UI by running `python app.py`, which lets you perform Voice Cloning and Voice Creation.
|
||||
|
||||
### 5. Fine-tuning
|
||||
|
||||
VoxCPM1.5 supports both full fine-tuning (SFT) and LoRA fine-tuning, allowing you to train personalized voice models on your own data. See the [Fine-tuning Guide](docs/finetune.md) for detailed instructions.
|
||||
|
||||
## 👩🍳 A Voice Chef's Guide
|
||||
Welcome to the VoxCPM kitchen! Follow this recipe to cook up perfect generated speech. Let’s begin.
|
||||
**Quick Start:**
|
||||
```bash
|
||||
# Full fine-tuning
|
||||
python scripts/train_voxcpm_finetune.py \
|
||||
--config_path conf/voxcpm_v1.5/voxcpm_finetune_all.yaml
|
||||
|
||||
---
|
||||
### 🥚 Step 1: Prepare Your Base Ingredients (Content)
|
||||
# LoRA fine-tuning
|
||||
python scripts/train_voxcpm_finetune.py \
|
||||
--config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml
|
||||
```
|
||||
|
||||
First, choose how you'd like to input your text:
|
||||
1. Regular Text (Classic Mode)
|
||||
- ✅ Keep "Text Normalization" ON. Type naturally (e.g., "Hello, world! 123"). The system will automatically process numbers, abbreviations, and punctuation using WeTextProcessing library.
|
||||
2. Phoneme Input (Native Mode)
|
||||
- ❌ Turn "Text Normalization" OFF. Enter phoneme text like {HH AH0 L OW1} (EN) or {ni3}{hao3} (ZH) for precise pronunciation control. In this mode, VoxCPM also supports native understanding of other complex non-normalized text—try it out!
|
||||
|
||||
---
|
||||
### 🍳 Step 2: Choose Your Flavor Profile (Voice Style)
|
||||
|
||||
This is the secret sauce that gives your audio its unique sound.
|
||||
1. Cooking with a Prompt Speech (Following a Famous Recipe)
|
||||
- A prompt speech provides the desired acoustic characteristics for VoxCPM. The speaker's timbre, speaking style, and even the background sounds and ambiance will be replicated.
|
||||
- For a Clean, Studio-Quality Voice:
|
||||
- ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone.
|
||||
2. Cooking au Naturel (Letting the Model Improvise)
|
||||
- If no reference is provided, VoxCPM becomes a creative chef! It will infer a fitting speaking style based on the text itself, thanks to the text-smartness of its foundation model, MiniCPM-4.
|
||||
- Pro Tip: Challenge VoxCPM with any text—poetry, song lyrics, dramatic monologues—it may deliver some interesting results!
|
||||
|
||||
---
|
||||
### 🧂 Step 3: The Final Seasoning (Fine-Tuning Your Results)
|
||||
You're ready to serve! But for master chefs who want to tweak the flavor, here are two key spices.
|
||||
- CFG Value (How Closely to Follow the Recipe)
|
||||
- Default: A great starting point.
|
||||
- Voice sounds strained or weird? Lower this value. It tells the model to be more relaxed and improvisational, great for expressive prompts.
|
||||
- Need maximum clarity and adherence to the text? Raise it slightly to keep the model on a tighter leash.
|
||||
- Inference Timesteps (Simmering Time: Quality vs. Speed)
|
||||
- Need a quick snack? Use a lower number. Perfect for fast drafts and experiments.
|
||||
- Cooking a gourmet meal? Use a higher number. This lets the model "simmer" longer, refining the audio for superior detail and naturalness.
|
||||
|
||||
---
|
||||
Happy creating! 🎉 Start with the default settings and tweak from there to suit your project. The kitchen is yours!
|
||||
## 📚 Documentation
|
||||
|
||||
- **[Usage Guide](docs/usage_guide.md)** - Detailed guide on how to use VoxCPM effectively, including text input modes, voice cloning tips, and parameter tuning
|
||||
- **[Fine-tuning Guide](docs/finetune.md)** - Complete guide for fine-tuning VoxCPM models with SFT and LoRA
|
||||
- **[Release Notes](docs/release_note.md)** - Version history and updates
|
||||
- **[Performance Benchmarks](docs/performance.md)** - Detailed performance comparisons on public benchmarks
|
||||
|
||||
---
|
||||
|
||||
## 📚 More Information
|
||||
|
||||
### 🌟 Community Projects
|
||||
We're excited to see the VoxCPM community growing! Here are some amazing projects and features built by our community:
|
||||
- **[ComfyUI-VoxCPM](https://github.com/wildminder/ComfyUI-VoxCPM)** A VoxCPM extension for ComfyUI.
|
||||
- **[ComfyUI-VoxCPMTTS](https://github.com/1038lab/ComfyUI-VoxCPMTTS)** A VoxCPM extension for ComfyUI.
|
||||
- **[WebUI-VoxCPM](https://github.com/rsxdalv/tts_webui_extension.vox_cpm)** A template extension for TTS WebUI.
|
||||
- **[PR: Streaming API Support (by AbrahamSanders)](https://github.com/OpenBMB/VoxCPM/pull/26)**
|
||||
- **[VoxCPM-NanoVLLM](https://github.com/a710128/nanovllm-voxcpm)** NanoVLLM integration for VoxCPM for faster, high-throughput inference on GPU.
|
||||
- **[VoxCPM-ONNX](https://github.com/bluryar/VoxCPM-ONNX)** ONNX export for VoxCPM, enabling faster CPU inference.
|
||||
- **[VoxCPMANE](https://github.com/0seba/VoxCPMANE)** VoxCPM TTS with Apple Neural Engine backend server.
|
||||
|
||||
## 📊 Performance Highlights
|
||||
|
||||
VoxCPM achieves competitive results on public zero-shot TTS benchmarks:
|
||||
|
||||
### Seed-TTS-eval Benchmark
|
||||
|
||||
| Model | Parameters | Open-Source | test-EN | | test-ZH | | test-Hard | |
|
||||
|------|------|------|:------------:|:--:|:------------:|:--:|:-------------:|:--:|
|
||||
| | | | WER/%⬇ | SIM/%⬆| CER/%⬇| SIM/%⬆ | CER/%⬇ | SIM/%⬆ |
|
||||
| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
|
||||
| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
|
||||
| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
|
||||
| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
|
||||
| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
|
||||
| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
|
||||
| CosyVoice | 0.3B | ✅ | 4.29 | 60.9 | 3.63 | 72.3 | 11.75 | 70.9 |
|
||||
| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | **6.83** | 72.4 |
|
||||
| F5-TTS | 0.3B | ✅ | 2.00 | 67.0 | 1.53 | 76.0 | 8.67 | 71.3 |
|
||||
| SparkTTS | 0.5B | ✅ | 3.14 | 57.3 | 1.54 | 66.0 | - | - |
|
||||
| FireRedTTS | 0.5B | ✅ | 3.82 | 46.0 | 1.51 | 63.5 | 17.45 | 62.1 |
|
||||
| FireRedTTS-2 | 1.5B | ✅ | 1.95 | 66.5 | 1.14 | 73.6 | - | - |
|
||||
| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | **74.7** |
|
||||
| OpenAudio-s1-mini | 0.5B | ✅ | 1.94 | 55.0 | 1.18 | 68.5 | - | - |
|
||||
| IndexTTS2 | 1.5B | ✅ | 2.23 | 70.6 | 1.03 | 76.5 | - | - |
|
||||
| VibeVoice | 1.5B | ✅ | 3.04 | 68.9 | 1.16 | 74.4 | - | - |
|
||||
| HiggsAudio-v2 | 3B | ✅ | 2.44 | 67.7 | 1.50 | 74.0 | - | - |
|
||||
| **VoxCPM** | 0.5B | ✅ | **1.85** | **72.9** | **0.93** | **77.2** | 8.87 | 73.0 |
|
||||
|
||||
|
||||
### CV3-eval Benchmark
|
||||
|
||||
| Model | zh | en | hard-zh | | | hard-en | | |
|
||||
|-------|:--:|:--:|:-------:|:--:|:--:|:-------:|:--:|:--:|
|
||||
| | CER/%⬇ | WER/%⬇ | CER/%⬇ | SIM/%⬆ | DNSMOS⬆ | WER/%⬇ | SIM/%⬆ | DNSMOS⬆ |
|
||||
| F5-TTS | 5.47 | 8.90 | - | - | - | - | - | - |
|
||||
| SparkTTS | 5.15 | 11.0 | - | - | - | - | - | - |
|
||||
| GPT-SoVits | 7.34 | 12.5 | - | - | - | - | - | - |
|
||||
| CosyVoice2 | 4.08 | 6.32 | 12.58 | 72.6 | 3.81 | 11.96 | 66.7 | 3.95 |
|
||||
| OpenAudio-s1-mini | 4.00 | 5.54 | 18.1 | 58.2 | 3.77 | 12.4 | 55.7 | 3.89 |
|
||||
| IndexTTS2 | 3.58 | 4.45 | 12.8 | 74.6 | 3.65 | - | - | - |
|
||||
| HiggsAudio-v2 | 9.54 | 7.89 | 41.0 | 60.2 | 3.39 | 10.3 | 61.8 | 3.68 |
|
||||
| CosyVoice3-0.5B | 3.89 | 5.24 | 14.15 | 78.6 | 3.75 | 9.04 | 75.9 | 3.92 |
|
||||
| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 78.5 | 3.79 | 10.55 | 76.1 | 3.95 |
|
||||
| **VoxCPM** | **3.40** | **4.04** | 12.9 | 66.1 | 3.59 | **7.89** | 64.3 | 3.74 |
|
||||
|
||||
|
||||
|
||||
|
||||
*Note: The projects are not officially maintained by OpenBMB.*
|
||||
|
||||
|
||||
|
||||
*Have you built something cool with VoxCPM? We'd love to feature it here! Please open an issue or pull request to add your project.*
|
||||
|
||||
### 📊 Performance Highlights
|
||||
|
||||
VoxCPM achieves competitive results on public zero-shot TTS benchmarks. See [Performance Benchmarks](docs/performance.md) for detailed comparison tables.
|
||||
|
||||
|
||||
|
||||
@@ -236,6 +230,16 @@ VoxCPM achieves competitive results on public zero-shot TTS benchmarks:
|
||||
- Bilingual Model: VoxCPM is trained primarily on Chinese and English data. Performance on other languages is not guaranteed and may result in unpredictable or low-quality audio.
|
||||
- This model is released for research and development purposes only. We do not recommend its use in production or commercial applications without rigorous testing and safety evaluations. Please use VoxCPM responsibly.
|
||||
|
||||
---
|
||||
|
||||
## 📝 TO-DO List
|
||||
Please stay tuned for updates!
|
||||
- [x] Release the VoxCPM technical report.
|
||||
- [x] Support higher sampling rate (44.1kHz in VoxCPM1.5).
|
||||
- [x] Support SFT and LoRA fine-tuning.
|
||||
- [ ] Multilingual Support (besides ZH/EN).
|
||||
- [ ] Controllable Speech Generation by Human Instruction.
|
||||
|
||||
|
||||
|
||||
## 📄 License
|
||||
@@ -258,6 +262,8 @@ This project is developed by the following institutions:
|
||||
- <img src="assets/thuhcsi_logo.png" width="28px"> [THUHCSI](https://github.com/thuhcsi)
|
||||
|
||||
|
||||
## ⭐ Star History
|
||||
[](https://star-history.com/#OpenBMB/VoxCPM&Date)
|
||||
|
||||
|
||||
## 📚 Citation
|
||||
@@ -265,11 +271,10 @@ This project is developed by the following institutions:
|
||||
If you find our model helpful, please consider citing our projects 📝 and starring us ⭐️!
|
||||
|
||||
```bib
|
||||
@misc{voxcpm2025,
|
||||
author = {{Yixuan Zhou, Guoyang Zeng, Xin Liu, Xiang Li, Renjie Yu, Ziyang Wang, Runchuan Ye, Weiyue Sun, Jiancheng Gui, Kehan Li, Zhiyong Wu, Zhiyuan Liu}},
|
||||
title = {{VoxCPM}},
|
||||
@article{voxcpm2025,
|
||||
title = {VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning},
|
||||
author = {Zhou, Yixuan and Zeng, Guoyang and Liu, Xin and Li, Xiang and Yu, Renjie and Wang, Ziyang and Ye, Runchuan and Sun, Weiyue and Gui, Jiancheng and Li, Kehan and Wu, Zhiyong and Liu, Zhiyuan},
|
||||
journal = {arXiv preprint arXiv:2509.24650},
|
||||
year = {2025},
|
||||
publish = {\url{https://github.com/OpenBMB/VoxCPM}},
|
||||
note = {GitHub repository}
|
||||
}
|
||||
```
|
||||
|
||||
app.py (17 lines changed)
@@ -8,7 +8,7 @@ from funasr import AutoModel
|
||||
from pathlib import Path
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM-0.5B"
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
|
||||
|
||||
import voxcpm
|
||||
|
||||
@@ -29,7 +29,7 @@ class VoxCPMDemo:
|
||||
|
||||
# TTS model (lazy init)
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.default_local_model_dir = "./models/VoxCPM-0.5B"
|
||||
self.default_local_model_dir = "./models/VoxCPM1.5"
|
||||
|
||||
# ---------- Model helpers ----------
|
||||
def _resolve_model_dir(self) -> str:
|
||||
@@ -108,7 +108,7 @@ class VoxCPMDemo:
|
||||
normalize=do_normalize,
|
||||
denoise=denoise,
|
||||
)
|
||||
return (16000, wav)
|
||||
return (current_model.tts_model.sample_rate, wav)
|
||||
|
||||
|
||||
# ---------- UI Builders ----------
|
||||
@@ -170,7 +170,7 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
|
||||
# Pro Tips
|
||||
with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
|
||||
gr.Markdown(f"""
|
||||
gr.Markdown("""
|
||||
### Prompt Speech Enhancement|参考语音降噪
|
||||
- **Enable** to remove background noise for a clean, studio-like voice, with an external ZipEnhancer component.
|
||||
**启用**:通过 ZipEnhancer 组件消除背景噪音,获得更好的音质。
|
||||
@@ -194,10 +194,6 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
**调低**:合成速度更快。
|
||||
- **Higher** for better synthesis quality.
|
||||
**调高**:合成质量更佳。
|
||||
|
||||
### Long Text (e.g., >5 min speech)|长文本 (如 >5分钟的合成语音)
|
||||
While VoxCPM can handle long texts directly, we recommend using empty lines to break very long content into paragraphs; the model will then synthesize each paragraph individually.
|
||||
虽然 VoxCPM 支持直接生成长文本,但如果目标文本过长,我们建议使用换行符将内容分段;模型将对每个段落分别合成。
|
||||
""")
|
||||
|
||||
# Main controls
|
||||
@@ -206,7 +202,7 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
prompt_wav = gr.Audio(
|
||||
sources=["upload", 'microphone'],
|
||||
type="filepath",
|
||||
label="Prompt Speech",
|
||||
label="Prompt Speech (Optional, or let VoxCPM improvise)",
|
||||
value="./examples/example.wav",
|
||||
)
|
||||
DoDenoisePromptAudio = gr.Checkbox(
|
||||
@@ -244,14 +240,13 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
info="Default processing splits text on \\n into paragraphs; each is synthesized as a chunk and then concatenated into the final audio."
|
||||
)
|
||||
with gr.Row():
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="We use WeTextPorcessing library to normalize the input text."
|
||||
info="We use wetext library to normalize the input text."
|
||||
)
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
|
||||
assets/wechat.png (new binary file, 9.5 KiB; not shown)
conf/voxcpm_v1.5/voxcpm_finetune_all.yaml (new file, 21 lines)
@@ -0,0 +1,21 @@
|
||||
pretrained_path: /path/to/VoxCPM1.5/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: null
|
||||
sample_rate: 44100
|
||||
batch_size: 16
|
||||
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
|
||||
num_workers: 2
|
||||
num_iters: 2000
|
||||
log_interval: 10
|
||||
valid_interval: 1000
|
||||
save_interval: 1000
|
||||
learning_rate: 0.00001
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 100
|
||||
max_steps: 2000
|
||||
max_batch_tokens: 8192 # Example: a batch can contain at most 8192 tokens; with batch_size=4, each sample could have up to 2048 tokens
|
||||
save_path: /path/to/checkpoints/finetune_all
|
||||
tensorboard: /path/to/logs/finetune_all
|
||||
lambdas:
|
||||
loss/diff: 1.0
|
||||
loss/stop: 1.0
|
||||
conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml (new file, 28 lines)
@@ -0,0 +1,28 @@
|
||||
pretrained_path: /path/to/VoxCPM1.5/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: null
|
||||
sample_rate: 44100
|
||||
batch_size: 16
|
||||
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
|
||||
num_workers: 2
|
||||
num_iters: 2000
|
||||
log_interval: 10
|
||||
valid_interval: 1000
|
||||
save_interval: 1000
|
||||
learning_rate: 0.0001
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 100
|
||||
max_steps: 2000
|
||||
max_batch_tokens: 8192 # Example: a batch can contain at most 8192 tokens; with batch_size=4, each sample could have up to 2048 tokens
|
||||
save_path: /path/to/checkpoints/finetune_lora
|
||||
tensorboard: /path/to/logs/finetune_lora
|
||||
lambdas:
|
||||
loss/diff: 1.0
|
||||
loss/stop: 1.0
|
||||
lora:
|
||||
enable_lm: true
|
||||
enable_dit: true
|
||||
enable_proj: false
|
||||
r: 32
|
||||
alpha: 16
|
||||
dropout: 0.0
|
||||
conf/voxcpm_v1/voxcpm_finetune_all.yaml (new file, 21 lines)
@@ -0,0 +1,21 @@
|
||||
pretrained_path: /path/to/VoxCPM-0.5B/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: null
|
||||
sample_rate: 16000
|
||||
batch_size: 16
|
||||
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
|
||||
num_workers: 2
|
||||
num_iters: 2000
|
||||
log_interval: 10
|
||||
valid_interval: 1000
|
||||
save_interval: 1000
|
||||
learning_rate: 0.00001
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 100
|
||||
max_steps: 2000
|
||||
max_batch_tokens: 8192 # Example: a batch can contain at most 8192 tokens; with batch_size=4, each sample could have up to 2048 tokens
|
||||
save_path: /path/to/checkpoints/finetune_all
|
||||
tensorboard: /path/to/logs/finetune_all
|
||||
lambdas:
|
||||
loss/diff: 1.0
|
||||
loss/stop: 1.0
|
||||
conf/voxcpm_v1/voxcpm_finetune_lora.yaml (new file, 28 lines)
@@ -0,0 +1,28 @@
|
||||
pretrained_path: /path/to/VoxCPM-0.5B/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: null
|
||||
sample_rate: 16000
|
||||
batch_size: 16
|
||||
grad_accum_steps: 1 # Gradient accumulation steps, >1 can increase effective batch size without increasing memory
|
||||
num_workers: 2
|
||||
num_iters: 2000
|
||||
log_interval: 10
|
||||
valid_interval: 1000
|
||||
save_interval: 1000
|
||||
learning_rate: 0.0001
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 100
|
||||
max_steps: 2000
|
||||
max_batch_tokens: 8192 # Example: a batch can contain at most 8192 tokens; with batch_size=4, each sample could have up to 2048 tokens
|
||||
save_path: /path/to/checkpoints/finetune_lora
|
||||
tensorboard: /path/to/logs/finetune_lora
|
||||
lambdas:
|
||||
loss/diff: 1.0
|
||||
loss/stop: 1.0
|
||||
lora:
|
||||
enable_lm: true
|
||||
enable_dit: true
|
||||
enable_proj: false
|
||||
r: 32
|
||||
alpha: 16
|
||||
dropout: 0.0
|
||||
docs/finetune.md (new file, 375 lines)
@@ -0,0 +1,375 @@
|
||||
# VoxCPM Fine-tuning Guide
|
||||
|
||||
This guide covers how to fine-tune VoxCPM models with two approaches: full fine-tuning and LoRA fine-tuning.
|
||||
|
||||
### 🎓 SFT (Supervised Fine-Tuning)
|
||||
|
||||
Full fine-tuning updates all model parameters. Suitable for:
|
||||
- 📊 Large, specialized datasets
|
||||
- 🔄 Cases where significant behavior changes are needed
|
||||
|
||||
### ⚡ LoRA Fine-tuning
|
||||
|
||||
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that:
|
||||
- 🎯 Trains only a small number of additional parameters
|
||||
- 💾 Significantly reduces memory requirements and training time
|
||||
- 🔀 Supports multiple LoRA adapters with hot-swapping
|
||||
|
||||
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Full Fine-tuning](#full-fine-tuning)
|
||||
- [LoRA Fine-tuning](#lora-fine-tuning)
|
||||
- [Inference](#inference)
|
||||
- [LoRA Hot-swapping](#lora-hot-swapping)
|
||||
- [FAQ](#faq)
|
||||
|
||||
---
|
||||
|
||||
## Data Preparation
|
||||
|
||||
Training data should be prepared as a JSONL manifest file, with one sample per line:
|
||||
|
||||
```jsonl
|
||||
{"audio": "path/to/audio1.wav", "text": "Transcript of audio 1."}
|
||||
{"audio": "path/to/audio2.wav", "text": "Transcript of audio 2."}
|
||||
{"audio": "path/to/audio3.wav", "text": "Optional duration field.", "duration": 3.5}
|
||||
{"audio": "path/to/audio4.wav", "text": "Optional dataset_id for multi-dataset.", "dataset_id": 1}
|
||||
```
|
||||
|
||||
### Required Fields
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `audio` | Path to audio file (absolute or relative) |
|
||||
| `text` | Corresponding transcript |
|
||||
|
||||
### Optional Fields
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `duration` | Audio duration in seconds (speeds up sample filtering) |
|
||||
| `dataset_id` | Dataset ID for multi-dataset training (default: 0) |
|
||||
|
||||
### Requirements
|
||||
|
||||
- Audio format: WAV
|
||||
- Sample rate: 16kHz for VoxCPM-0.5B, 44.1kHz for VoxCPM1.5
|
||||
- Text: Transcript matching the audio content
|
||||
|
||||
See `examples/train_data_example.jsonl` for a complete example.
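If you need to assemble such a manifest from raw recordings, a minimal sketch is shown below. It is not part of the repository, and the folder layout (one `.txt` transcript next to each `.wav` file) is an assumption for illustration:

```python
# Hypothetical helper: build train.jsonl from paired foo.wav / foo.txt files.
# The directory layout and naming convention are assumptions, not a VoxCPM requirement.
import json
from pathlib import Path

import soundfile as sf

audio_dir = Path("data/my_voice")
with open("train.jsonl", "w", encoding="utf-8") as out:
    for wav in sorted(audio_dir.glob("*.wav")):
        text = wav.with_suffix(".txt").read_text(encoding="utf-8").strip()
        record = {
            "audio": str(wav),
            "text": text,
            "duration": round(sf.info(str(wav)).duration, 2),  # optional, speeds up filtering
        }
        out.write(json.dumps(record, ensure_ascii=False) + "\n")
```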
|
||||
|
||||
---
|
||||
|
||||
## Full Fine-tuning
|
||||
|
||||
Full fine-tuning updates all model parameters. Suitable for large datasets or when significant behavior changes are needed.
|
||||
|
||||
### Configuration
|
||||
|
||||
Create `conf/voxcpm_v1.5/voxcpm_finetune_all.yaml`:
|
||||
|
||||
```yaml
|
||||
pretrained_path: /path/to/VoxCPM1.5/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: ""
|
||||
|
||||
sample_rate: 44100
|
||||
batch_size: 16
|
||||
grad_accum_steps: 1
|
||||
num_workers: 2
|
||||
num_iters: 2000
|
||||
log_interval: 10
|
||||
valid_interval: 1000
|
||||
save_interval: 1000
|
||||
|
||||
learning_rate: 0.00001 # Use smaller LR for full fine-tuning
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 100
|
||||
max_steps: 2000
|
||||
max_batch_tokens: 8192
|
||||
|
||||
save_path: /path/to/checkpoints/finetune_all
|
||||
tensorboard: /path/to/logs/finetune_all
|
||||
|
||||
lambdas:
|
||||
loss/diff: 1.0
|
||||
loss/stop: 1.0
|
||||
```
|
||||
|
||||
### Training
|
||||
|
||||
```bash
|
||||
# Single GPU
|
||||
python scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_all.yaml
|
||||
|
||||
# Multi-GPU
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
|
||||
scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_all.yaml
|
||||
```
|
||||
|
||||
### Checkpoint Structure
|
||||
|
||||
Full fine-tuning saves a complete model directory that can be loaded directly:
|
||||
|
||||
```
|
||||
checkpoints/finetune_all/
|
||||
└── step_0002000/
|
||||
├── model.safetensors # Model weights (excluding audio_vae)
|
||||
├── config.json # Model config
|
||||
├── audiovae.pth # Audio VAE weights
|
||||
├── tokenizer.json # Tokenizer
|
||||
├── tokenizer_config.json
|
||||
├── special_tokens_map.json
|
||||
├── optimizer.pth
|
||||
└── scheduler.pth
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## LoRA Fine-tuning
|
||||
|
||||
LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that trains only a small number of additional parameters, significantly reducing memory requirements.
|
||||
|
||||
### Configuration
|
||||
|
||||
Create `conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml`:
|
||||
|
||||
```yaml
|
||||
pretrained_path: /path/to/VoxCPM1.5/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: ""
|
||||
|
||||
sample_rate: 44100
|
||||
batch_size: 16
|
||||
grad_accum_steps: 1
|
||||
num_workers: 2
|
||||
num_iters: 2000
|
||||
log_interval: 10
|
||||
valid_interval: 1000
|
||||
save_interval: 1000
|
||||
|
||||
learning_rate: 0.0001 # LoRA can use larger LR
|
||||
weight_decay: 0.01
|
||||
warmup_steps: 100
|
||||
max_steps: 2000
|
||||
max_batch_tokens: 8192
|
||||
|
||||
save_path: /path/to/checkpoints/finetune_lora
|
||||
tensorboard: /path/to/logs/finetune_lora
|
||||
|
||||
lambdas:
|
||||
loss/diff: 1.0
|
||||
loss/stop: 1.0
|
||||
|
||||
# LoRA configuration
|
||||
lora:
|
||||
enable_lm: true # Apply LoRA to Language Model
|
||||
enable_dit: true # Apply LoRA to Diffusion Transformer
|
||||
enable_proj: false # Apply LoRA to projection layers (optional)
|
||||
|
||||
r: 32 # LoRA rank (higher = more capacity)
|
||||
alpha: 16 # LoRA alpha, scaling = alpha / r
|
||||
dropout: 0.0
|
||||
|
||||
# Target modules
|
||||
target_modules_lm: ["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
target_modules_dit: ["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
```
|
||||
|
||||
### LoRA Parameters
|
||||
|
||||
| Parameter | Description | Recommended |
|
||||
|-----------|-------------|-------------|
|
||||
| `enable_lm` | Apply LoRA to LM (language model) | `true` |
|
||||
| `enable_dit` | Apply LoRA to DiT (diffusion model) | `true` (required for voice cloning) |
|
||||
| `r` | LoRA rank (higher = more capacity) | 16-64 |
|
||||
| `alpha` | Scaling factor, `scaling = alpha / r` | Usually `r/2` or `r` |
|
||||
| `target_modules_*` | Layer names to add LoRA | attention layers |
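To make the `r`/`alpha` interaction concrete, here is a generic sketch of a LoRA-augmented linear layer; it illustrates the technique and is not the repository's actual implementation:

```python
# Generic LoRA illustration (not VoxCPM's internal code): a frozen linear layer
# plus a trainable low-rank update, scaled by alpha / r.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 32, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)   # pretrained weight stays frozen
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)        # low-rank update starts as a no-op
        self.scaling = alpha / r                  # e.g. r=32, alpha=16 -> 0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))
```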
|
||||
|
||||
### Training
|
||||
|
||||
```bash
|
||||
# Single GPU
|
||||
python scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml
|
||||
|
||||
# Multi-GPU
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
|
||||
scripts/train_voxcpm_finetune.py --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml
|
||||
```
|
||||
|
||||
### Checkpoint Structure
|
||||
|
||||
LoRA training saves only LoRA parameters:
|
||||
|
||||
```
|
||||
checkpoints/finetune_lora/
|
||||
└── step_0002000/
|
||||
├── lora_weights.safetensors # Only lora_A, lora_B parameters
|
||||
├── optimizer.pth
|
||||
└── scheduler.pth
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Inference
|
||||
|
||||
### Full Fine-tuning Inference
|
||||
|
||||
The checkpoint directory is a complete model; load it directly:
|
||||
|
||||
```bash
|
||||
python scripts/test_voxcpm_ft_infer.py \
|
||||
--ckpt_dir /path/to/checkpoints/finetune_all/step_0002000 \
|
||||
--text "Hello, this is the fine-tuned model." \
|
||||
--output output.wav
|
||||
```
|
||||
|
||||
With voice cloning:
|
||||
|
||||
```bash
|
||||
python scripts/test_voxcpm_ft_infer.py \
|
||||
--ckpt_dir /path/to/checkpoints/finetune_all/step_0002000 \
|
||||
--text "This is voice cloning result." \
|
||||
--prompt_audio /path/to/reference.wav \
|
||||
--prompt_text "Reference audio transcript" \
|
||||
--output cloned_output.wav
|
||||
```
|
||||
|
||||
### LoRA Inference
|
||||
|
||||
LoRA inference requires the training config (for LoRA structure) and LoRA checkpoint:
|
||||
|
||||
```bash
|
||||
python scripts/test_voxcpm_lora_infer.py \
|
||||
--config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
|
||||
--lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
|
||||
--text "Hello, this is LoRA fine-tuned result." \
|
||||
--output lora_output.wav
|
||||
```
|
||||
|
||||
With voice cloning:
|
||||
|
||||
```bash
|
||||
python scripts/test_voxcpm_lora_infer.py \
|
||||
--config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
|
||||
--lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
|
||||
--text "This is voice cloning with LoRA." \
|
||||
--prompt_audio /path/to/reference.wav \
|
||||
--prompt_text "Reference audio transcript" \
|
||||
--output cloned_output.wav
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## LoRA Hot-swapping
|
||||
|
||||
LoRA supports dynamic loading, unloading, and switching at inference time without reloading the entire model.
|
||||
|
||||
### API Reference
|
||||
|
||||
```python
|
||||
from voxcpm.core import VoxCPM
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
|
||||
# 1. Load model with LoRA structure and weights
|
||||
lora_cfg = LoRAConfig(
|
||||
enable_lm=True,
|
||||
enable_dit=True,
|
||||
r=32,
|
||||
alpha=16,
|
||||
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||
)
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id="openbmb/VoxCPM1.5", # or local path
|
||||
load_denoiser=False, # Optional: disable denoiser for faster loading
|
||||
optimize=True, # Enable torch.compile acceleration
|
||||
lora_config=lora_cfg,
|
||||
lora_weights_path="/path/to/lora_checkpoint",
|
||||
)
|
||||
|
||||
# 2. Generate audio
|
||||
audio = model.generate(
|
||||
text="Hello, this is LoRA fine-tuned result.",
|
||||
prompt_wav_path="/path/to/reference.wav", # Optional: for voice cloning
|
||||
prompt_text="Reference audio transcript", # Optional: for voice cloning
|
||||
)
|
||||
|
||||
# 3. Disable LoRA (use base model only)
|
||||
model.set_lora_enabled(False)
|
||||
|
||||
# 4. Re-enable LoRA
|
||||
model.set_lora_enabled(True)
|
||||
|
||||
# 5. Unload LoRA (reset weights to zero)
|
||||
model.unload_lora()
|
||||
|
||||
# 6. Hot-swap to another LoRA
|
||||
loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
|
||||
print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
|
||||
|
||||
# 7. Get current LoRA weights
|
||||
lora_state = model.get_lora_state_dict()
|
||||
```
|
||||
|
||||
### Simplified Usage (Auto LoRA Config)
|
||||
|
||||
If you only have LoRA weights and don't need custom config, just provide the path:
|
||||
|
||||
```python
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
# Auto-create default LoRAConfig when only lora_weights_path is provided
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id="openbmb/VoxCPM1.5",
|
||||
lora_weights_path="/path/to/lora_checkpoint", # Will auto-create LoRAConfig
|
||||
)
|
||||
```
|
||||
|
||||
### Method Reference
|
||||
|
||||
| Method | Description | torch.compile Compatible |
|
||||
|--------|-------------|--------------------------|
|
||||
| `load_lora(path)` | Load LoRA weights from file | ✅ |
|
||||
| `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
|
||||
| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
|
||||
| `get_lora_state_dict()` | Get current LoRA weights | ✅ |
|
||||
| `lora_enabled` | Property: check if LoRA is configured | ✅ |
|
||||
|
||||
---
|
||||
|
||||
## FAQ
|
||||
|
||||
### 1. Out of Memory (OOM)
|
||||
|
||||
- Increase `grad_accum_steps` (gradient accumulation)
|
||||
- Decrease `batch_size`
|
||||
- Use LoRA fine-tuning instead of full fine-tuning
|
||||
- Decrease `max_batch_tokens` to filter long samples
|
||||
|
||||
### 2. Poor LoRA Performance
|
||||
|
||||
- Increase `r` (LoRA rank)
|
||||
- Adjust `alpha` (try `alpha = r/2` or `alpha = r`)
|
||||
- Increase training steps
|
||||
- Add more target modules
|
||||
|
||||
### 3. Training Not Converging
|
||||
|
||||
- Decrease `learning_rate`
|
||||
- Increase `warmup_steps`
|
||||
- Check data quality
|
||||
|
||||
### 4. LoRA Not Taking Effect at Inference
|
||||
|
||||
- Ensure inference config matches training config LoRA parameters
|
||||
- Check `load_lora()` return value - `skipped_keys` should be empty
|
||||
- Verify `set_lora_enabled(True)` is called
|
||||
|
||||
### 5. Checkpoint Loading Errors
|
||||
|
||||
- Full fine-tuning: checkpoint directory should contain `model.safetensors` (or `pytorch_model.bin`), `config.json`, `audiovae.pth`
|
||||
- LoRA: checkpoint directory should contain `lora_weights.safetensors` (or `lora_weights.ckpt`)
|
||||
docs/performance.md (new file, 46 lines)
@@ -0,0 +1,46 @@
|
||||
# 📊 Performance Highlights
|
||||
|
||||
VoxCPM achieves competitive results on public zero-shot TTS benchmarks.
|
||||
|
||||
## Seed-TTS-eval Benchmark
|
||||
|
||||
| Model | Parameters | Open-Source | test-EN | | test-ZH | | test-Hard | |
|
||||
|------|------|------|:------------:|:--:|:------------:|:--:|:-------------:|:--:|
|
||||
| | | | WER/%⬇ | SIM/%⬆| CER/%⬇| SIM/%⬆ | CER/%⬇ | SIM/%⬆ |
|
||||
| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
|
||||
| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
|
||||
| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
|
||||
| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
|
||||
| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
|
||||
| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
|
||||
| F5-TTS | 0.3B | ✅ | 2.00 | 67.0 | 1.53 | 76.0 | 8.67 | 71.3 |
|
||||
| MaskGCT | 1B | ✅ | 2.62 | 71.7 | 2.27 | 77.4 | - | - |
|
||||
| CosyVoice | 0.3B | ✅ | 4.29 | 60.9 | 3.63 | 72.3 | 11.75 | 70.9 |
|
||||
| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | **6.83** | 72.4 |
|
||||
| SparkTTS | 0.5B | ✅ | 3.14 | 57.3 | 1.54 | 66.0 | - | - |
|
||||
| FireRedTTS | 0.5B | ✅ | 3.82 | 46.0 | 1.51 | 63.5 | 17.45 | 62.1 |
|
||||
| FireRedTTS-2 | 1.5B | ✅ | 1.95 | 66.5 | 1.14 | 73.6 | - | - |
|
||||
| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | **74.7** |
|
||||
| OpenAudio-s1-mini | 0.5B | ✅ | 1.94 | 55.0 | 1.18 | 68.5 | 23.37 | 64.3 |
|
||||
| IndexTTS2 | 1.5B | ✅ | 2.23 | 70.6 | 1.03 | 76.5 | 7.12 | 75.5 |
|
||||
| VibeVoice | 1.5B | ✅ | 3.04 | 68.9 | 1.16 | 74.4 | - | - |
|
||||
| HiggsAudio-v2 | 3B | ✅ | 2.44 | 67.7 | 1.50 | 74.0 | 55.07 | 65.6 |
|
||||
| **VoxCPM** | 0.5B | ✅ | **1.85** | **72.9** | **0.93** | **77.2** | 8.87 | 73.0 |
|
||||
|
||||
|
||||
## CV3-eval Benchmark
|
||||
|
||||
| Model | zh | en | hard-zh | | | hard-en | | |
|
||||
|-------|:--:|:--:|:-------:|:--:|:--:|:-------:|:--:|:--:|
|
||||
| | CER/%⬇ | WER/%⬇ | CER/%⬇ | SIM/%⬆ | DNSMOS⬆ | WER/%⬇ | SIM/%⬆ | DNSMOS⬆ |
|
||||
| F5-TTS | 5.47 | 8.90 | - | - | - | - | - | - |
|
||||
| SparkTTS | 5.15 | 11.0 | - | - | - | - | - | - |
|
||||
| GPT-SoVits | 7.34 | 12.5 | - | - | - | - | - | - |
|
||||
| CosyVoice2 | 4.08 | 6.32 | 12.58 | 72.6 | 3.81 | 11.96 | 66.7 | 3.95 |
|
||||
| OpenAudio-s1-mini | 4.00 | 5.54 | 18.1 | 58.2 | 3.77 | 12.4 | 55.7 | 3.89 |
|
||||
| IndexTTS2 | 3.58 | 4.45 | 12.8 | 74.6 | 3.65 | 8.78 | 74.5 | 3.80 |
|
||||
| HiggsAudio-v2 | 9.54 | 7.89 | 41.0 | 60.2 | 3.39 | 10.3 | 61.8 | 3.68 |
|
||||
| CosyVoice3-0.5B | 3.89 | 5.24 | 14.15 | 78.6 | 3.75 | 9.04 | 75.9 | 3.92 |
|
||||
| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 78.5 | 3.79 | 10.55 | 76.1 | 3.95 |
|
||||
| **VoxCPM** | **3.40** | **4.04** | 12.9 | 66.1 | 3.59 | **7.89** | 64.3 | 3.74 |
|
||||
|
||||
docs/release_note.md (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
# VoxCPM1.5 Release Notes
|
||||
|
||||
**Release Date:** December 5, 2025
|
||||
|
||||
## 🎉 Overview
|
||||
|
||||
|
||||
We're thrilled to introduce a major upgrade that improves the audio quality and efficiency of VoxCPM while maintaining its core capabilities of context-aware speech generation and zero-shot voice cloning.
|
||||
|
||||
| Feature | VoxCPM | VoxCPM1.5 |
|
||||
|---------|------------|------------|
|
||||
| **Audio VAE Sampling Rate** | 16kHz | 44.1kHz |
|
||||
| **LM Token Rate** | 12.5Hz | 6.25Hz |
|
||||
| **Patch Size** | 2 | 4 |
|
||||
| **SFT Support** | ✅ | ✅ |
|
||||
| **LoRA Support** | ✅ | ✅ |
|
||||
|
||||
## 🎵 Model Updates
|
||||
|
||||
### 🔊 AudioVAE Sampling Rate: 16kHz → 44.1kHz
|
||||
|
||||
The AudioVAE now supports 44.1kHz sampling rate, which allows the model to:
|
||||
- 🎯 Clone voices more faithfully, preserving more high-frequency detail and generating higher-quality voice output
|
||||
|
||||
|
||||
*Note: This upgrade enables higher quality generation when using high-quality reference audio, but does not guarantee that all generated audio will be high-fidelity. The output quality depends on the **prompt speech** quality.*
|
||||
|
||||
### ⚡ Token Rate: 12.5Hz → 6.25Hz
|
||||
|
||||
We reduced the token rate in the LM backbone from 12.5Hz to 6.25Hz (LocEnc&LocDiT patch size increased from 2 to 4) while maintaining similar performance on evaluation benchmarks. This change:
|
||||
- 💨 Reduces computational requirements for generating the same length of audio (e.g., a 10-second clip now corresponds to roughly 62 LM tokens instead of 125)
|
||||
- 📈 Provides a foundation for longer audio generation
|
||||
- 🏗️ Paves the way for training larger models in the future
|
||||
|
||||
|
||||
## 🔧 Fine-tuning Support
|
||||
|
||||
VoxCPM now supports both full fine-tuning and LoRA fine-tuning; please see the [Fine-tuning Guide](finetune.md) for detailed instructions.
|
||||
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- Updated README with version comparison
|
||||
- Added comprehensive fine-tuning guide
|
||||
- Improved code comments and documentation
|
||||
|
||||
|
||||
## 🙏 Our Thanks to You
|
||||
This release wouldn’t be possible without the incredible feedback, testing, and contributions from our open-source community. Thank you for helping shape VoxCPM1.5!
|
||||
|
||||
|
||||
## 📞 Let's Build Together
|
||||
Questions, ideas, or want to contribute?
|
||||
|
||||
- 🐛 Report an issue: [GitHub Issues on OpenBMB/VoxCPM](https://github.com/OpenBMB/VoxCPM/issues)
|
||||
|
||||
- 📖 Dig into the docs: Check the [docs/](../docs/) folder for guides and API details
|
||||
|
||||
Enjoy the richer sound and powerful new features of VoxCPM1.5 🎉
|
||||
|
||||
We can't wait to hear what you create next! 🥂
|
||||
|
||||
## 🚀 What We're Working On
|
||||
|
||||
We're continuously improving VoxCPM and working on exciting new features:
|
||||
|
||||
- 🌍 **Multilingual TTS Support**: We are actively developing support for languages beyond Chinese and English.
|
||||
- 🎯 **Controllable Expressive Speech Generation**: We are researching controllable speech generation that allows fine-grained control over speech attributes (emotion, timbre, prosody, etc.) through natural language instructions.
|
||||
- 🎵 **Universal Audio Generation Foundation**: We also hope to explore VoxCPM as a unified audio generation foundation model capable of joint generation of speech, music, and sound effects. However, this is a longer-term vision.
|
||||
|
||||
**📅 Next Release**: We plan to release the next version in Q1 2026, which will include significant improvements and new features. Stay tuned for updates! We're committed to making VoxCPM even more powerful and versatile.
|
||||
|
||||
## ❓ Frequently Asked Questions (FAQ)
|
||||
|
||||
### Q: Does VoxCPM support fine-tuning for personalized voice customization?
|
||||
|
||||
**A:** Yes! VoxCPM now supports both full fine-tuning (SFT) and efficient LoRA fine-tuning. You can train personalized voice models on your own data. Please refer to the [Fine-tuning Guide](finetune.md) for detailed instructions and examples.
|
||||
|
||||
### Q: Is 16kHz audio quality sufficient for my use case?
|
||||
|
||||
**A:** We have upgraded the AudioVAE to support 44.1kHz sampling rate in VoxCPM1.5, which provides higher quality audio output with better preservation of high-frequency details. This upgrade enables better voice cloning quality and more natural speech synthesis when using high-quality reference audio.
|
||||
|
||||
### Q: Has the stability issue been resolved?
|
||||
|
||||
**A:** We have made stability optimizations in VoxCPM1.5, including improvements to the training data and model architecture. Based on community feedback, we collected some stability issues such as:
|
||||
- Increased noise and reverberation
|
||||
- Audio artifacts (e.g., howling/squealing)
|
||||
- Unstable speaking rate (speeding up)
|
||||
- Volume fluctuations (increases or decreases)
|
||||
- Noise artifacts at the beginning and end of audio
|
||||
- Synthesis issues with very short texts (e.g., "hello")
|
||||
|
||||
While we have made improvements to these issues, they have not been completely resolved and may still occasionally occur, especially with very long or highly expressive inputs. We continue to work on further stability improvements in future versions.
|
||||
|
||||
### Q: Does VoxCPM plan to support multilingual TTS?
|
||||
|
||||
**A:** Currently, VoxCPM is primarily trained on Chinese and English data. We are actively researching and developing multilingual TTS support for more languages beyond Chinese and English. Please let us know what languages you'd like to see supported!
|
||||
|
||||
### Q: Does VoxCPM plan to support controllable generation (emotion, style, fine-grained control)?
|
||||
|
||||
**A:** Currently, VoxCPM only supports zero-shot voice cloning and context-aware speech generation. Direct control over specific speech attributes (emotion, style, fine-grained prosody) is limited. However, we are actively researching instruction-controllable expressive speech generation with fine-grained control capabilities, working towards a human instruction-to-speech generation model!
|
||||
|
||||
### Q: Does VoxCPM support different hardware chips (e.g., Ascend 910B, XPU, NPU)?
|
||||
|
||||
**A:** Currently, we have not yet adapted VoxCPM for different hardware chips. Our main focus remains on developing new model capabilities and improving stability. We encourage you to check if community developers have done similar work, and we warmly welcome everyone to contribute and promote such adaptations together!
|
||||
|
||||
These features are under active development, and we look forward to sharing updates in future releases!
|
||||
|
||||
|
||||
docs/usage_guide.md (new file, 53 lines)
@@ -0,0 +1,53 @@
|
||||
# 👩🍳 A Voice Chef's Guide
|
||||
|
||||
Welcome to the VoxCPM kitchen! Follow this recipe to cook up perfect generated speech. Let's begin.
|
||||
|
||||
---
|
||||
|
||||
## 🥚 Step 1: Prepare Your Base Ingredients (Content)
|
||||
|
||||
First, choose how you'd like to input your text:
|
||||
|
||||
### 1. Regular Text (Classic Mode)
|
||||
- ✅ Keep "Text Normalization" ON. Type naturally (e.g., "Hello, world! 123"). The system will automatically process numbers, abbreviations, and punctuation using WeTextProcessing library.
|
||||
|
||||
### 2. Phoneme Input (Native Mode)
|
||||
- ❌ Turn "Text Normalization" OFF. Enter phoneme text like `{HH AH0 L OW1}` (EN) or `{ni3}{hao3}` (ZH) for precise pronunciation control. In this mode, VoxCPM also supports native understanding of other complex non-normalized text—try it out!
|
||||
- **Phoneme Conversion**: For Chinese, phonemes are converted using pinyin; for English, using CMUDict. Please refer to the relevant documentation for more details (a small conversion sketch follows below).
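A minimal sketch of the Chinese side of this conversion is shown below; it assumes the third-party `pypinyin` package, which is an illustration rather than a stated VoxCPM dependency:

```python
# Illustrative only: convert Chinese text to the {ni3}{hao3} phoneme format
# using the third-party pypinyin package (an assumption, not a stated dependency).
from pypinyin import Style, lazy_pinyin

def to_phoneme_text(text: str) -> str:
    syllables = lazy_pinyin(text, style=Style.TONE3)  # e.g. ["ni3", "hao3"]
    return "".join("{" + s + "}" for s in syllables)

print(to_phoneme_text("你好"))  # -> {ni3}{hao3}
```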
|
||||
|
||||
---
|
||||
|
||||
## 🍳 Step 2: Choose Your Flavor Profile (Voice Style)
|
||||
|
||||
This is the secret sauce that gives your audio its unique sound.
|
||||
|
||||
### 1. Cooking with a Prompt Speech (Following a Famous Recipe)
|
||||
- A prompt speech provides the desired acoustic characteristics for VoxCPM. The speaker's timbre, speaking style, and even the background sounds and ambiance will be replicated.
|
||||
- **For a Clean, Studio-Quality Voice:**
|
||||
- ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone.
|
||||
|
||||
### 2. Cooking au Naturel (Letting the Model Improvise)
|
||||
- If no reference is provided, VoxCPM becomes a creative chef! It will infer a fitting speaking style based on the text itself, thanks to the text-smartness of its foundation model, MiniCPM-4.
|
||||
- **Pro Tip**: Challenge VoxCPM with any text—poetry, song lyrics, dramatic monologues—it may deliver some interesting results!
|
||||
|
||||
---
|
||||
|
||||
## 🧂 Step 3: The Final Seasoning (Fine-Tuning Your Results)
|
||||
|
||||
You're ready to serve! But for master chefs who want to tweak the flavor, here are two key spices.
|
||||
|
||||
### CFG Value (How Closely to Follow the Recipe)
|
||||
- **Default**: A great starting point.
|
||||
- **Voice sounds strained or weird?** Lower this value. It tells the model to be more relaxed and improvisational, great for expressive prompts.
|
||||
- **Need maximum clarity and adherence to the text?** Raise it slightly to keep the model on a tighter leash.
|
||||
- **Short sentences?** Consider increasing the CFG value for better clarity and adherence.
|
||||
- **Long texts?** Consider lowering the CFG value to improve stability and naturalness over extended passages.
|
||||
|
||||
### Inference Timesteps (Simmering Time: Quality vs. Speed)
|
||||
- **Need a quick snack?** Use a lower number. Perfect for fast drafts and experiments.
|
||||
- **Cooking a gourmet meal?** Use a higher number. This lets the model "simmer" longer, refining the audio for superior detail and naturalness.
|
||||
|
||||
---
|
||||
|
||||
Happy creating! 🎉 Start with the default settings and tweak from there to suit your project. The kitchen is yours!
|
||||
|
||||
examples/train_data_example.jsonl (new file, 6 lines)
@@ -0,0 +1,6 @@
|
||||
{"audio": "examples/example.wav", "text": "This is an example audio transcript for training."}
|
||||
{"audio": "/absolute/path/to/audio1.wav", "text": "You can use absolute paths for audio files."}
|
||||
{"audio": "relative/path/to/audio2.wav", "text": "Or relative paths from the working directory."}
|
||||
{"audio": "data/audio3.wav", "text": "Each line is a JSON object with audio path and text.", "duration": 3.5}
|
||||
{"audio": "data/audio4.wav", "text": "Optional: add duration field to skip audio loading during filtering.", "duration": 2.8}
|
||||
{"audio": "data/audio5.wav", "text": "Optional: add dataset_id for multi-dataset training.", "dataset_id": 1}
|
||||
pyproject.toml
@@ -20,23 +20,21 @@ classifiers = [
|
||||
"Intended Audience :: Developers",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"torch>=2.5.0",
|
||||
"torchaudio>=2.5.0",
|
||||
"transformers>=4.36.2",
|
||||
"einops",
|
||||
"gradio",
|
||||
"gradio<6",
|
||||
"inflect",
|
||||
"addict",
|
||||
"WeTextProcessing",
|
||||
"wetext",
|
||||
"modelscope>=1.22.0",
|
||||
"datasets>=2,<4",
|
||||
"datasets>=3,<4",
|
||||
"huggingface-hub",
|
||||
"pydantic",
|
||||
"tqdm",
|
||||
@@ -44,7 +42,10 @@ dependencies = [
|
||||
"sortedcontainers",
|
||||
"soundfile",
|
||||
"funasr",
|
||||
"spaces"
|
||||
"spaces",
|
||||
"argbind",
|
||||
"safetensors"
|
||||
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@@ -78,7 +79,7 @@ version_scheme = "post-release"
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ['py38']
|
||||
target-version = ['py310']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
/(
|
||||
|
||||
scripts/test_voxcpm_ft_infer.py (new file, 131 lines)
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Full finetune inference script (no LoRA).
|
||||
|
||||
Checkpoint directory contains complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.),
|
||||
can be loaded directly via VoxCPM.
|
||||
|
||||
Usage:
|
||||
|
||||
python scripts/test_voxcpm_ft_infer.py \
|
||||
--ckpt_dir /path/to/checkpoints/step_0001000 \
|
||||
--text "Hello, I am the finetuned VoxCPM." \
|
||||
--output ft_test.wav
|
||||
|
||||
With voice cloning:
|
||||
|
||||
python scripts/test_voxcpm_ft_infer.py \
|
||||
--ckpt_dir /path/to/checkpoints/step_0001000 \
|
||||
--text "Hello, this is voice cloning result." \
|
||||
--prompt_audio path/to/ref.wav \
|
||||
--prompt_text "Reference audio transcript" \
|
||||
--output ft_clone.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("VoxCPM full-finetune inference test (no LoRA)")
|
||||
parser.add_argument(
|
||||
"--ckpt_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Checkpoint directory (contains pytorch_model.bin, config.json, audiovae.pth, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Target text to synthesize",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt_audio",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional: reference audio path for voice cloning",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt_text",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional: transcript of reference audio",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="ft_test.wav",
|
||||
help="Output wav file path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cfg_value",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="CFG scale (default: 2.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--inference_timesteps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Diffusion inference steps (default: 10)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_len",
|
||||
type=int,
|
||||
default=600,
|
||||
help="Max generation steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Enable text normalization",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# Load model from checkpoint directory (no denoiser)
|
||||
print(f"[FT Inference] Loading model: {args.ckpt_dir}")
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id=args.ckpt_dir,
|
||||
load_denoiser=False,
|
||||
optimize=True,
|
||||
)
|
||||
|
||||
# Run inference
|
||||
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
|
||||
prompt_text = args.prompt_text if args.prompt_text else None
|
||||
|
||||
print(f"[FT Inference] Synthesizing: text='{args.text}'")
|
||||
if prompt_wav_path:
|
||||
print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
|
||||
print(f"[FT Inference] Reference text: {prompt_text}")
|
||||
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
|
||||
# Save audio
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
|
||||
|
||||
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
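The script above relies on `VoxCPM.from_pretrained` accepting a local checkpoint directory as `hf_model_id`; a rough Python equivalent of the same flow is sketched below (checkpoint and output paths are placeholders):

```python
import soundfile as sf
from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(
    hf_model_id="/path/to/checkpoints/step_0001000",  # full-finetune checkpoint dir
    load_denoiser=False,
)
audio = model.generate(text="Hello, I am the finetuned VoxCPM.", denoise=False)
sf.write("ft_test.wav", audio, model.tts_model.sample_rate)
```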
225
scripts/test_voxcpm_lora_infer.py
Normal file
@@ -0,0 +1,225 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LoRA inference test script.
|
||||
|
||||
Usage:
|
||||
|
||||
python scripts/test_voxcpm_lora_infer.py \
|
||||
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
|
||||
--lora_ckpt checkpoints/step_0002000 \
|
||||
--text "Hello, this is LoRA finetuned result." \
|
||||
--output lora_test.wav
|
||||
|
||||
With voice cloning:
|
||||
|
||||
python scripts/test_voxcpm_lora_infer.py \
|
||||
--config_path conf/voxcpm/voxcpm_finetune_test.yaml \
|
||||
--lora_ckpt checkpoints/step_0002000 \
|
||||
--text "This is voice cloning result." \
|
||||
--prompt_audio path/to/ref.wav \
|
||||
--prompt_text "Reference audio transcript" \
|
||||
--output lora_clone.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import soundfile as sf
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
from voxcpm.training.config import load_yaml_config
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
|
||||
parser.add_argument(
|
||||
"--config_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Training YAML config path (contains pretrained_path and lora config)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_ckpt",
|
||||
type=str,
|
||||
required=True,
|
||||
help="LoRA checkpoint directory (contains lora_weights.ckpt with lora_A/lora_B only)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Target text to synthesize",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt_audio",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional: reference audio path for voice cloning",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt_text",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional: transcript of reference audio",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="lora_test.wav",
|
||||
help="Output wav file path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cfg_value",
|
||||
type=float,
|
||||
default=2.0,
|
||||
help="CFG scale (default: 2.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--inference_timesteps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Diffusion inference steps (default: 10)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_len",
|
||||
type=int,
|
||||
default=600,
|
||||
help="Max generation steps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--normalize",
|
||||
action="store_true",
|
||||
help="Enable text normalization",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
# 1. Load YAML config
|
||||
cfg = load_yaml_config(args.config_path)
|
||||
pretrained_path = cfg["pretrained_path"]
|
||||
lora_cfg_dict = cfg.get("lora", {}) or {}
|
||||
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
|
||||
|
||||
# 2. Check LoRA checkpoint
|
||||
ckpt_dir = args.lora_ckpt
|
||||
if not Path(ckpt_dir).exists():
|
||||
raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
|
||||
|
||||
# 3. Load model with LoRA (no denoiser)
|
||||
print(f"[1/2] Loading model with LoRA: {pretrained_path}")
|
||||
print(f" LoRA weights: {ckpt_dir}")
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id=pretrained_path,
|
||||
load_denoiser=False,
|
||||
optimize=True,
|
||||
lora_config=lora_cfg,
|
||||
lora_weights_path=ckpt_dir,
|
||||
)
|
||||
|
||||
# 4. Synthesize audio
|
||||
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
|
||||
prompt_text = args.prompt_text if args.prompt_text else None
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n[2/2] Starting synthesis tests...")
|
||||
|
||||
# === Test 1: With LoRA ===
|
||||
print(f"\n [Test 1] Synthesize with LoRA...")
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
||||
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
||||
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
|
||||
model.set_lora_enabled(False)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
||||
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
# === Test 3: Re-enable LoRA ===
|
||||
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
|
||||
model.set_lora_enabled(True)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
||||
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
# === Test 4: Unload LoRA (reset_lora_weights) ===
|
||||
print(f"\n [Test 4] Unload LoRA (unload_lora)...")
|
||||
model.unload_lora()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
||||
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
# === Test 5: Hot-reload LoRA (load_lora) ===
|
||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
|
||||
loaded, skipped = model.load_lora(str(ckpt_dir))
|
||||
print(f" Reloaded {len(loaded)} parameters")
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
max_length=args.max_len,
|
||||
normalize=args.normalize,
|
||||
denoise=False,
|
||||
)
|
||||
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
||||
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
|
||||
|
||||
print(f"\n[Done] All tests completed!")
|
||||
print(f" - with_lora: {lora_output}")
|
||||
print(f" - lora_disabled: {disabled_output}")
|
||||
print(f" - lora_reenabled: {reenabled_output}")
|
||||
print(f" - lora_reset: {reset_output}")
|
||||
print(f" - lora_reloaded: {reload_output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
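Likewise, the LoRA path exercised by this script can presumably be driven directly through `VoxCPM.from_pretrained` with the `lora_config` / `lora_weights_path` arguments introduced later in this diff; a brief sketch (paths and hyperparameters are placeholders):

```python
import soundfile as sf
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig

model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    load_denoiser=False,
    lora_config=LoRAConfig(enable_lm=True, enable_dit=True, r=32, alpha=16),
    lora_weights_path="checkpoints/step_0002000",  # dir containing lora_weights.ckpt
)
audio = model.generate(text="Hello, this is LoRA finetuned result.", denoise=False)
sf.write("lora_test.wav", audio, model.tts_model.sample_rate)
```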
362
scripts/train_voxcpm_finetune.py
Normal file
@@ -0,0 +1,362 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root / "src"))
|
||||
|
||||
import contextlib
|
||||
from typing import Dict, Optional
|
||||
|
||||
import argbind
|
||||
import torch
|
||||
from tensorboardX import SummaryWriter
|
||||
from torch.optim import AdamW
|
||||
from transformers import get_cosine_schedule_with_warmup
|
||||
|
||||
try:
|
||||
from safetensors.torch import save_file
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
print("Warning: safetensors not available, will use pytorch format")
|
||||
|
||||
from voxcpm.model import VoxCPMModel
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
from voxcpm.training import (
|
||||
Accelerator,
|
||||
BatchProcessor,
|
||||
TrainingTracker,
|
||||
build_dataloader,
|
||||
load_audio_text_datasets,
|
||||
)
|
||||
|
||||
|
||||
@argbind.bind(without_prefix=True)
|
||||
def train(
|
||||
pretrained_path: str,
|
||||
train_manifest: str,
|
||||
val_manifest: str = "",
|
||||
sample_rate: int = 16_000,
|
||||
batch_size: int = 1,
|
||||
grad_accum_steps: int = 1,
|
||||
num_workers: int = 2,
|
||||
num_iters: int = 100_000,
|
||||
log_interval: int = 100,
|
||||
valid_interval: int = 1_000,
|
||||
save_interval: int = 10_000,
|
||||
learning_rate: float = 1e-4,
|
||||
weight_decay: float = 1e-2,
|
||||
warmup_steps: int = 1_000,
|
||||
max_steps: int = 100_000,
|
||||
max_batch_tokens: int = 0,
|
||||
save_path: str = "checkpoints",
|
||||
tensorboard: str = "",
|
||||
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||
lora: dict = None,
|
||||
config_path: str = "",
|
||||
):
|
||||
_ = config_path
|
||||
accelerator = Accelerator(amp=True)
|
||||
|
||||
save_dir = Path(save_path)
|
||||
tb_dir = Path(tensorboard) if tensorboard else save_dir / "logs"
|
||||
|
||||
# Only create directories on rank 0 to avoid race conditions
|
||||
if accelerator.rank == 0:
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
tb_dir.mkdir(parents=True, exist_ok=True)
|
||||
accelerator.barrier() # Wait for directory creation
|
||||
|
||||
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
|
||||
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
|
||||
|
||||
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
|
||||
tokenizer = base_model.text_tokenizer
|
||||
|
||||
train_ds, val_ds = load_audio_text_datasets(
|
||||
train_manifest=train_manifest,
|
||||
val_manifest=val_manifest,
|
||||
sample_rate=sample_rate,
|
||||
)
|
||||
|
||||
def tokenize(batch):
|
||||
text_list = batch["text"]
|
||||
text_ids = [tokenizer(text) for text in text_list]
|
||||
return {"text_ids": text_ids}
|
||||
|
||||
train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
|
||||
if val_ds is not None:
|
||||
val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])
|
||||
|
||||
dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
|
||||
num_train_samples = len(train_ds)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Optional: filter samples by estimated token count to avoid OOM
|
||||
# Enabled when max_batch_tokens > 0:
|
||||
# max_sample_len = max_batch_tokens // batch_size
|
||||
# Samples exceeding this length will be dropped
|
||||
# ------------------------------------------------------------------ #
|
||||
if max_batch_tokens and max_batch_tokens > 0:
|
||||
from voxcpm.training.data import compute_sample_lengths
|
||||
|
||||
audio_vae_fps = base_model.audio_vae.sample_rate / base_model.audio_vae.hop_length
|
||||
est_lengths = compute_sample_lengths(
|
||||
train_ds,
|
||||
audio_vae_fps=audio_vae_fps,
|
||||
patch_size=base_model.config.patch_size,
|
||||
)
|
||||
max_sample_len = max_batch_tokens // batch_size if batch_size > 0 else max(est_lengths)
|
||||
keep_indices = [i for i, L in enumerate(est_lengths) if L <= max_sample_len]
|
||||
|
||||
if len(keep_indices) < len(train_ds) and accelerator.rank == 0:
|
||||
tracker.print(
|
||||
f"Filtering {len(train_ds) - len(keep_indices)} / {len(train_ds)} "
|
||||
f"training samples longer than {max_sample_len} tokens "
|
||||
f"(max_batch_tokens={max_batch_tokens})."
|
||||
)
|
||||
train_ds = train_ds.select(keep_indices)
|
||||
|
||||
train_loader = build_dataloader(
|
||||
train_ds,
|
||||
accelerator=accelerator,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
drop_last=True,
|
||||
)
|
||||
val_loader = (
|
||||
build_dataloader(
|
||||
val_ds,
|
||||
accelerator=accelerator,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
drop_last=False,
|
||||
)
|
||||
if val_ds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
batch_processor = BatchProcessor(
|
||||
config=base_model.config,
|
||||
audio_vae=base_model.audio_vae,
|
||||
dataset_cnt=dataset_cnt,
|
||||
device=accelerator.device,
|
||||
)
|
||||
del base_model.audio_vae
|
||||
model = accelerator.prepare_model(base_model)
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
unwrapped_model.train()
|
||||
|
||||
|
||||
# Only print param info on rank 0 to avoid cluttered output
|
||||
if accelerator.rank == 0:
|
||||
for name, param in model.named_parameters():
|
||||
print(name, param.requires_grad)
|
||||
|
||||
optimizer = AdamW(
|
||||
(p for p in model.parameters() if p.requires_grad),
|
||||
lr=learning_rate,
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
|
||||
# Cosine + warmup scheduler from transformers:
|
||||
# - num_warmup_steps: warmup steps
|
||||
# - num_training_steps: total training steps (outer step count)
|
||||
total_training_steps = max_steps if max_steps > 0 else num_iters
|
||||
scheduler = get_cosine_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=warmup_steps,
|
||||
num_training_steps=total_training_steps,
|
||||
)
|
||||
|
||||
# Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
|
||||
grad_accum_steps = max(int(grad_accum_steps), 1)
|
||||
data_epoch = 0
|
||||
train_iter = iter(train_loader)
|
||||
|
||||
def get_next_batch():
|
||||
"""Get next batch, handles epoch boundary and DistributedSampler."""
|
||||
nonlocal train_iter, data_epoch
|
||||
try:
|
||||
return next(train_iter)
|
||||
except StopIteration:
|
||||
data_epoch += 1
|
||||
# Key: set DistributedSampler epoch to ensure different data order each epoch
|
||||
sampler = getattr(train_loader, 'sampler', None)
|
||||
if hasattr(sampler, 'set_epoch'):
|
||||
sampler.set_epoch(data_epoch)
|
||||
train_iter = iter(train_loader)
|
||||
return next(train_iter)
|
||||
|
||||
with tracker.live():
|
||||
for step in range(num_iters):
|
||||
tracker.step = step
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Gradient accumulation: accumulate gradients over micro-batches before optimizer step
|
||||
loss_dict = {}
|
||||
for micro_step in range(grad_accum_steps):
|
||||
batch = get_next_batch()
|
||||
processed = batch_processor(batch)
|
||||
|
||||
# Only sync gradients on the last micro-batch
|
||||
# Use no_sync() for intermediate steps to reduce communication overhead
|
||||
is_last_micro_step = (micro_step == grad_accum_steps - 1)
|
||||
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
|
||||
|
||||
with sync_context:
|
||||
with accelerator.autocast(dtype=torch.bfloat16):
|
||||
outputs = model(
|
||||
processed["text_tokens"],
|
||||
processed["text_mask"],
|
||||
processed["audio_feats"],
|
||||
processed["audio_mask"],
|
||||
processed["loss_mask"],
|
||||
processed["position_ids"],
|
||||
processed["labels"],
|
||||
progress=step / max(1, num_iters),
|
||||
)
|
||||
|
||||
total_loss = 0.0
|
||||
for key, value in outputs.items():
|
||||
if key.startswith("loss/"):
|
||||
weight = lambdas.get(key, 1.0)
|
||||
loss_value = value * weight / grad_accum_steps
|
||||
total_loss = total_loss + loss_value
|
||||
# Record raw loss from last micro-batch for logging
|
||||
loss_dict[key] = value.detach()
|
||||
|
||||
# Accumulate gradients (normalized by grad_accum_steps)
|
||||
accelerator.backward(total_loss)
|
||||
|
||||
# After all micro-batches, do unscale / grad_norm / step
|
||||
scaler = getattr(accelerator, "scaler", None)
|
||||
if scaler is not None:
|
||||
scaler.unscale_(optimizer)
|
||||
# Use large max_norm to only compute grad_norm without actual clipping
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)
|
||||
|
||||
accelerator.step(optimizer)
|
||||
accelerator.update()
|
||||
scheduler.step()
|
||||
|
||||
if step % log_interval == 0:
|
||||
loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
|
||||
loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
|
||||
# Approximate epoch: seen samples / total samples (considering grad_accum and batch_size)
|
||||
epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
|
||||
loss_values["epoch"] = float(epoch)
|
||||
loss_values["grad_norm"] = float(grad_norm)
|
||||
tracker.log_metrics(loss_values, split="train")
|
||||
|
||||
if val_loader is not None and step % valid_interval == 0 and step != 0:
|
||||
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)
|
||||
|
||||
if step % save_interval == 0 and accelerator.rank == 0:
|
||||
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path)
|
||||
|
||||
if accelerator.rank == 0:
|
||||
save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path)
|
||||
if writer:
|
||||
writer.close()
|
||||
|
||||
|
||||
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
|
||||
model.eval()
|
||||
losses = []
|
||||
num_batches = 0
|
||||
max_val_batches = 10
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in val_loader:
|
||||
if num_batches >= max_val_batches:
|
||||
break
|
||||
processed = batch_processor(batch)
|
||||
with accelerator.autocast(dtype=torch.bfloat16):
|
||||
outputs = model(
|
||||
processed["text_tokens"],
|
||||
processed["text_mask"],
|
||||
processed["audio_feats"],
|
||||
processed["audio_mask"],
|
||||
processed["loss_mask"],
|
||||
processed["position_ids"],
|
||||
processed["labels"],
|
||||
progress=0.0,
|
||||
sample_generate=False,
|
||||
)
|
||||
total = 0.0
|
||||
for key, value in outputs.items():
|
||||
if key.startswith("loss/"):
|
||||
total += lambdas.get(key, 1.0) * value
|
||||
losses.append(total.detach())
|
||||
num_batches += 1
|
||||
|
||||
if losses:
|
||||
mean_loss = torch.stack(losses).mean()
|
||||
# All-reduce validation loss across processes for global average
|
||||
accelerator.all_reduce(mean_loss)
|
||||
tracker.log_metrics({"loss": mean_loss.item()}, split="val")
|
||||
model.train()
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None):
|
||||
"""
|
||||
Save checkpoint with different strategies for full finetune vs LoRA:
|
||||
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
||||
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
|
||||
"""
|
||||
import shutil
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
tag = "latest" if step == 0 else f"step_{step:07d}"
|
||||
folder = save_dir / tag
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
unwrapped = model.module if hasattr(model, "module") else model
|
||||
full_state = unwrapped.state_dict()
|
||||
lora_cfg = unwrapped.lora_config
|
||||
|
||||
if lora_cfg is not None:
|
||||
# LoRA finetune: save only lora_A/lora_B weights
|
||||
state_dict = {k: v for k, v in full_state.items() if "lora_" in k}
|
||||
if SAFETENSORS_AVAILABLE:
|
||||
save_file(state_dict, folder / "lora_weights.safetensors")
|
||||
else:
|
||||
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
|
||||
else:
|
||||
# Full finetune: save non-vae weights to model.safetensors
|
||||
state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")}
|
||||
if SAFETENSORS_AVAILABLE:
|
||||
save_file(state_dict, folder / "model.safetensors")
|
||||
else:
|
||||
torch.save({"state_dict": state_dict}, folder / "pytorch_model.bin")
|
||||
|
||||
# Copy config files from pretrained path
|
||||
if pretrained_path:
|
||||
pretrained_dir = Path(pretrained_path)
|
||||
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
|
||||
for fname in files_to_copy:
|
||||
src = pretrained_dir / fname
|
||||
if src.exists():
|
||||
shutil.copy2(src, folder / fname)
|
||||
|
||||
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
|
||||
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from voxcpm.training.config import load_yaml_config
|
||||
|
||||
args = argbind.parse_args()
|
||||
config_file = args.get("config_path")
|
||||
# If YAML config provided, use YAML args to call train
|
||||
if config_file:
|
||||
yaml_args = load_yaml_config(config_file)
|
||||
train(**yaml_args)
|
||||
else:
|
||||
# Otherwise use command line args (parsed by argbind)
|
||||
with argbind.scope(args):
|
||||
train()
|
||||
|
||||
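When a YAML config is supplied, the entry point simply calls `train(**yaml_args)`; a minimal configuration, written here as the equivalent Python dict (field names come from the `train()` signature above, values are illustrative only), could look like:

```python
yaml_args = {
    "pretrained_path": "/path/to/VoxCPM1.5",   # local snapshot of the base model
    "train_manifest": "data/train.jsonl",
    "val_manifest": "data/val.jsonl",
    "batch_size": 2,
    "grad_accum_steps": 4,
    "learning_rate": 1e-4,
    "num_iters": 20_000,
    "save_path": "checkpoints/voxcpm_lora",
    "lora": {"enable_lm": True, "enable_dit": True, "r": 32, "alpha": 16},
}
# train(**yaml_args)  # same call path the script takes for --config_path
```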
@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
|
||||
"ZIPENHANCER_MODEL_PATH", None
|
||||
)
|
||||
|
||||
# Build LoRA config if lora_path is provided
|
||||
lora_config = None
|
||||
lora_weights_path = getattr(args, "lora_path", None)
|
||||
if lora_weights_path:
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
lora_config = LoRAConfig(
|
||||
enable_lm=getattr(args, "lora_enable_lm", True),
|
||||
enable_dit=getattr(args, "lora_enable_dit", True),
|
||||
enable_proj=getattr(args, "lora_enable_proj", False),
|
||||
r=getattr(args, "lora_r", 32),
|
||||
alpha=getattr(args, "lora_alpha", 16),
|
||||
dropout=getattr(args, "lora_dropout", 0.0),
|
||||
)
|
||||
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
||||
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
|
||||
|
||||
# Load from local path if provided
|
||||
if getattr(args, "model_path", None):
|
||||
try:
|
||||
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
|
||||
voxcpm_model_path=args.model_path,
|
||||
zipenhancer_model_path=zipenhancer_path,
|
||||
enable_denoiser=not getattr(args, "no_denoiser", False),
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
print("Model loaded (local).")
|
||||
return model
|
||||
@@ -69,11 +87,13 @@ def load_model(args) -> VoxCPM:
|
||||
# Otherwise, try from_pretrained (Hub); exit on failure
|
||||
try:
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM-0.5B"),
|
||||
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM1.5"),
|
||||
load_denoiser=not getattr(args, "no_denoiser", False),
|
||||
zipenhancer_model_id=zipenhancer_path,
|
||||
cache_dir=getattr(args, "cache_dir", None),
|
||||
local_files_only=getattr(args, "local_files_only", False),
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
print("Model loaded (from_pretrained).")
|
||||
return model
|
||||
@@ -120,11 +140,11 @@ def cmd_clone(args):
|
||||
)
|
||||
|
||||
# Save audio
|
||||
sf.write(str(output_path), audio_array, 16000)
|
||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||
print(f"Saved audio to: {output_path}")
|
||||
|
||||
# Stats
|
||||
duration = len(audio_array) / 16000
|
||||
duration = len(audio_array) / model.tts_model.sample_rate
|
||||
print(f"Duration: {duration:.2f}s")
|
||||
|
||||
|
||||
@@ -152,11 +172,11 @@ def cmd_synthesize(args):
|
||||
)
|
||||
|
||||
# Save audio
|
||||
sf.write(str(output_path), audio_array, 16000)
|
||||
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
|
||||
print(f"Saved audio to: {output_path}")
|
||||
|
||||
# Stats
|
||||
duration = len(audio_array) / 16000
|
||||
duration = len(audio_array) / model.tts_model.sample_rate
|
||||
print(f"Duration: {duration:.2f}s")
|
||||
|
||||
|
||||
@@ -198,9 +218,9 @@ def cmd_batch(args):
|
||||
denoise=args.denoise and prompt_audio_path is not None
|
||||
)
|
||||
output_file = output_dir / f"output_{i:03d}.wav"
|
||||
sf.write(str(output_file), audio_array, 16000)
|
||||
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
|
||||
|
||||
duration = len(audio_array) / 16000
|
||||
duration = len(audio_array) / model.tts_model.sample_rate
|
||||
print(f" Saved: {output_file} ({duration:.2f}s)")
|
||||
success_count += 1
|
||||
|
||||
@@ -240,6 +260,7 @@ Examples:
|
||||
# Prompt audio (for voice cloning)
|
||||
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
|
||||
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
||||
parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio")
|
||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
|
||||
|
||||
# Generation parameters
|
||||
@@ -249,12 +270,21 @@ Examples:
|
||||
|
||||
# Model loading parameters
|
||||
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
|
||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM-0.5B", help="Hugging Face repo id (e.g., openbmb/VoxCPM-0.5B)")
|
||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (e.g., openbmb/VoxCPM1.5 or openbmb/VoxCPM-0.5B)")
|
||||
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
||||
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
|
||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
|
||||
|
||||
# LoRA parameters
|
||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
|
||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
|
||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
|
||||
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
|
||||
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
|
||||
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -279,6 +309,12 @@ def main():
|
||||
|
||||
# If prompt audio+text provided → voice cloning
|
||||
if args.prompt_audio or args.prompt_text:
|
||||
if not args.prompt_text and args.prompt_file:
|
||||
assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."
|
||||
|
||||
with open(args.prompt_file, 'r', encoding='utf-8') as f:
|
||||
args.prompt_text = f.read()
|
||||
|
||||
if not args.prompt_audio or not args.prompt_text:
|
||||
print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
||||
sys.exit(1)
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import numpy as np
|
||||
from typing import Generator, Optional
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel
|
||||
from .utils.text_normalize import TextNormalizer
|
||||
|
||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||
|
||||
class VoxCPM:
|
||||
def __init__(self,
|
||||
voxcpm_model_path : str,
|
||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser : bool = True,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
"""Initialize VoxCPM TTS pipeline.
|
||||
|
||||
@@ -22,10 +24,32 @@ class VoxCPM:
|
||||
zipenhancer_model_path: ModelScope acoustic noise suppression model
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||
"""
|
||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
|
||||
self.text_normalizer = TextNormalizer()
|
||||
|
||||
# If lora_weights_path is provided but no lora_config, create a default one
|
||||
if lora_weights_path is not None and lora_config is None:
|
||||
lora_config = LoRAConfig(
|
||||
enable_lm=True,
|
||||
enable_dit=True,
|
||||
enable_proj=False,
|
||||
)
|
||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
|
||||
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
|
||||
# Load LoRA weights if path is provided
|
||||
if lora_weights_path is not None:
|
||||
print(f"Loading LoRA weights from: {lora_weights_path}")
|
||||
loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
|
||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
|
||||
|
||||
self.text_normalizer = None
|
||||
if enable_denoiser and zipenhancer_model_path is not None:
|
||||
from .zipenhancer import ZipEnhancer
|
||||
self.denoiser = ZipEnhancer(zipenhancer_model_path)
|
||||
@@ -33,27 +57,41 @@ class VoxCPM:
|
||||
self.denoiser = None
|
||||
print("Warm up VoxCPMModel...")
|
||||
self.tts_model.generate(
|
||||
target_text="Hello, this is the first test sentence."
|
||||
)
|
||||
target_text="Hello, this is the first test sentence.",
|
||||
max_len=10,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
hf_model_id: str = "openbmb/VoxCPM-0.5B",
|
||||
hf_model_id: str = "openbmb/VoxCPM1.5",
|
||||
load_denoiser: bool = True,
|
||||
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
||||
|
||||
Args:
|
||||
hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path.
|
||||
load_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
zipenhancer_model_id: Denoiser model id or path for ModelScope
|
||||
acoustic noise suppression.
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created with
|
||||
enable_lm=True and enable_dit=True.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded
|
||||
after model initialization.
|
||||
Kwargs:
|
||||
Additional keyword arguments passed to the ``VoxCPM`` constructor.
|
||||
|
||||
Returns:
|
||||
VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
|
||||
@@ -82,21 +120,33 @@ class VoxCPM:
|
||||
voxcpm_model_path=local_path,
|
||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||
enable_denoiser=load_denoiser,
|
||||
optimize=optimize,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def generate(self,
|
||||
def generate(self, *args, **kwargs) -> np.ndarray:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
|
||||
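Since `generate_streaming` yields one waveform chunk per generation step, a caller can accumulate the chunks incrementally; a minimal sketch, assuming `model` is a `VoxCPM` instance constructed as in the earlier examples:

```python
import numpy as np

chunks = [chunk for chunk in model.generate_streaming(text="Streaming synthesis test.", denoise=False)]
audio = np.concatenate(chunks)  # each chunk is a 1-D float32 numpy array
```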
def _generate(self,
|
||||
text : str,
|
||||
prompt_wav_path : str = None,
|
||||
prompt_text : str = None,
|
||||
cfg_value : float = 2.0,
|
||||
inference_timesteps : int = 10,
|
||||
max_length : int = 4096,
|
||||
normalize : bool = True,
|
||||
denoise : bool = True,
|
||||
min_len : int = 2,
|
||||
max_len : int = 4096,
|
||||
normalize : bool = False,
|
||||
denoise : bool = False,
|
||||
retry_badcase : bool = True,
|
||||
retry_badcase_max_times : int = 3,
|
||||
retry_badcase_ratio_threshold : float = 6.0,
|
||||
):
|
||||
streaming: bool = False,
|
||||
) -> Generator[np.ndarray, None, None]:
|
||||
"""Synthesize speech for the given text and return a single waveform.
|
||||
|
||||
This method optionally builds and reuses a prompt cache. If an external
|
||||
@@ -111,20 +161,32 @@ class VoxCPM:
|
||||
prompt_text: Text content corresponding to the prompt audio.
|
||||
cfg_value: Guidance scale for the generation model.
|
||||
inference_timesteps: Number of inference steps.
|
||||
max_length: Maximum token length during generation.
|
||||
max_len: Maximum token length during generation.
|
||||
normalize: Whether to run text normalization before generation.
|
||||
denoise: Whether to denoise the prompt audio if a denoiser is
|
||||
available.
|
||||
retry_badcase: Whether to retry badcase.
|
||||
retry_badcase_max_times: Maximum number of times to retry badcase.
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
|
||||
streaming: Whether to return a generator of audio chunks.
|
||||
Returns:
|
||||
numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||
Yields audio chunks for each generation step if ``streaming=True``,
|
||||
otherwise yields a single array containing the final audio.
|
||||
"""
|
||||
texts = text.split("\n")
|
||||
texts = [t.strip() for t in texts if t.strip()]
|
||||
final_wav = []
|
||||
temp_prompt_wav_path = None
|
||||
if not text.strip() or not isinstance(text, str):
|
||||
raise ValueError("target text must be a non-empty string")
|
||||
|
||||
if prompt_wav_path is not None:
|
||||
if not os.path.exists(prompt_wav_path):
|
||||
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
|
||||
|
||||
if (prompt_wav_path is None) != (prompt_text is None):
|
||||
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
|
||||
|
||||
text = text.replace("\n", " ")
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
temp_prompt_wav_path = None
|
||||
|
||||
try:
|
||||
if prompt_wav_path is not None and prompt_text is not None:
|
||||
@@ -140,36 +202,79 @@ class VoxCPM:
|
||||
else:
|
||||
fixed_prompt_cache = None # will be built from the first inference
|
||||
|
||||
for sub_text in texts:
|
||||
if sub_text.strip() == "":
|
||||
continue
|
||||
print("sub_text:", sub_text)
|
||||
if normalize:
|
||||
sub_text = self.text_normalizer.normalize(sub_text)
|
||||
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
|
||||
target_text=sub_text,
|
||||
prompt_cache=fixed_prompt_cache,
|
||||
min_len=2,
|
||||
max_len=max_length,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
)
|
||||
if fixed_prompt_cache is None:
|
||||
fixed_prompt_cache = self.tts_model.merge_prompt_cache(
|
||||
original_cache=None,
|
||||
new_text_token=target_text_token,
|
||||
new_audio_feat=generated_audio_feat
|
||||
)
|
||||
final_wav.append(wav)
|
||||
if normalize:
|
||||
if self.text_normalizer is None:
|
||||
from .utils.text_normalize import TextNormalizer
|
||||
self.text_normalizer = TextNormalizer()
|
||||
text = self.text_normalizer.normalize(text)
|
||||
|
||||
generate_result = self.tts_model._generate_with_prompt_cache(
|
||||
target_text=text,
|
||||
prompt_cache=fixed_prompt_cache,
|
||||
min_len=min_len,
|
||||
max_len=max_len,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
|
||||
for wav, _, _ in generate_result:
|
||||
yield wav.squeeze(0).cpu().numpy()
|
||||
|
||||
finally:
|
||||
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
|
||||
try:
|
||||
os.unlink(temp_prompt_wav_path)
|
||||
except OSError:
|
||||
pass
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# LoRA Interface (delegated to VoxCPMModel)
|
||||
# ------------------------------------------------------------------ #
|
||||
def load_lora(self, lora_weights_path: str) -> tuple:
|
||||
"""Load LoRA weights from a checkpoint file.
|
||||
|
||||
Args:
|
||||
lora_weights_path: Path to LoRA weights (.pth file or directory
|
||||
containing lora_weights.ckpt).
|
||||
|
||||
Returns:
|
||||
tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model was not initialized with LoRA config.
|
||||
"""
|
||||
if self.tts_model.lora_config is None:
|
||||
raise RuntimeError(
|
||||
"Cannot load LoRA weights: model was not initialized with LoRA config. "
|
||||
"Please reinitialize with lora_config or lora_weights_path parameter."
|
||||
)
|
||||
return self.tts_model.load_lora_weights(lora_weights_path)
|
||||
|
||||
def unload_lora(self):
|
||||
"""Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
|
||||
self.tts_model.reset_lora_weights()
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
|
||||
"""Enable or disable LoRA layers without unloading weights.
|
||||
|
||||
Args:
|
||||
enabled: If True, LoRA layers are active; if False, only base model is used.
|
||||
"""
|
||||
self.tts_model.set_lora_enabled(enabled)
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
|
||||
"""Get current LoRA parameters state dict.
|
||||
|
||||
Returns:
|
||||
dict: State dict containing all LoRA parameters (lora_A, lora_B).
|
||||
"""
|
||||
return self.tts_model.get_lora_state_dict()
|
||||
|
||||
@property
|
||||
def lora_enabled(self) -> bool:
|
||||
"""Check if LoRA is currently configured."""
|
||||
return self.tts_model.lora_config is not None
|
||||
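The interface above makes it possible to A/B the adapted and base voices at runtime; a short sketch, assuming `model` was initialized with a LoRA config as in the earlier example:

```python
model.set_lora_enabled(False)            # temporarily fall back to the base model
base_audio = model.generate(text="Base voice.", denoise=False)

model.set_lora_enabled(True)             # re-activate the adapters
lora_audio = model.generate(text="Adapted voice.", denoise=False)

loaded, skipped = model.load_lora("checkpoints/step_0002000")  # hot-swap another adapter
```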
@@ -19,18 +19,27 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Tuple, Union, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import warnings
|
||||
from einops import rearrange
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from ..modules.audiovae import AudioVAE, AudioVAEConfig
|
||||
from ..modules.layers import ScalarQuantizationLayer
|
||||
from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
@@ -65,10 +74,31 @@ class VoxCPMConfig(BaseModel):
|
||||
|
||||
encoder_config: VoxCPMEncoderConfig
|
||||
dit_config: VoxCPMDitConfig
|
||||
audio_vae_config: Optional[AudioVAEConfig] = None
|
||||
|
||||
max_length: int = 4096
|
||||
device: str = "cuda"
|
||||
dtype: str = "bfloat16"
|
||||
dit_mean_mode: bool = False
|
||||
|
||||
|
||||
class LoRAConfig(BaseModel):
|
||||
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
|
||||
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
|
||||
enable_proj: bool = False # Apply LoRA to projection Linear layers
|
||||
|
||||
r: int = 8
|
||||
alpha: int = 16
|
||||
dropout: float = 0.0
|
||||
|
||||
# Target linear layer names for LM & DiT (matched by attribute name)
|
||||
target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
# Projection layer attribute names to find on VoxCPMModel
|
||||
target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
|
||||
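For illustration, a DiT-only adapter with a smaller rank could be configured as follows (values are arbitrary; the field names are those defined by the class above):

```python
from voxcpm.model.voxcpm import LoRAConfig

dit_only = LoRAConfig(
    enable_lm=False,
    enable_dit=True,                      # adapt only the local DiT estimator
    r=16,
    alpha=32,
    dropout=0.05,
    target_modules_dit=["q_proj", "v_proj"],
)
```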
|
||||
|
||||
VoxCPMConfig.model_rebuild()
|
||||
|
||||
|
||||
class VoxCPMModel(nn.Module):
|
||||
@@ -77,18 +107,24 @@ class VoxCPMModel(nn.Module):
|
||||
config: VoxCPMConfig,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAE,
|
||||
lora_config: LoRAConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
if not torch.cuda.is_available():
|
||||
self.device = "cpu"
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
|
||||
|
||||
# Text-Semantic LM
|
||||
self.base_lm = MiniCPMModel(config.lm_config)
|
||||
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
|
||||
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
|
||||
|
||||
self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
|
||||
self.audio_start_token = 101
|
||||
@@ -99,7 +135,7 @@ class VoxCPMModel(nn.Module):
|
||||
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
|
||||
residual_lm_config.vocab_size = 0
|
||||
self.residual_lm = MiniCPMModel(residual_lm_config)
|
||||
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
|
||||
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
|
||||
|
||||
# Local Encoder
|
||||
encoder_config = config.lm_config.model_copy(deep=True)
|
||||
@@ -123,6 +159,7 @@ class VoxCPMModel(nn.Module):
|
||||
in_channels=config.feat_dim,
|
||||
cfm_params=config.dit_config.cfm_config,
|
||||
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
|
||||
mean_mode=config.dit_mean_mode,
|
||||
)
|
||||
|
||||
# Projection layers
|
||||
@@ -131,7 +168,7 @@ class VoxCPMModel(nn.Module):
|
||||
config.lm_config.hidden_size,
|
||||
config.scalar_quantization_latent_dim,
|
||||
config.scalar_quantization_scale
|
||||
)
|
||||
)
|
||||
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
|
||||
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||
self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||
@@ -140,29 +177,168 @@ class VoxCPMModel(nn.Module):
|
||||
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
|
||||
self.stop_actn = nn.SiLU()
|
||||
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
|
||||
self.stop_loss = nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# Audio VAE
|
||||
self.audio_vae = audio_vae
|
||||
self.chunk_size = audio_vae.chunk_size
|
||||
self.sample_rate = audio_vae.sample_rate
|
||||
|
||||
|
||||
def optimize(self):
|
||||
if self.device == "cuda":
|
||||
if self.lora_config is not None:
|
||||
self._apply_lora()
|
||||
|
||||
def _apply_lora(self):
|
||||
"""注入 LoRA 到 LM / DiT / 投影层"""
|
||||
cfg = self.lora_config
|
||||
lora_kwargs = dict(r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)
|
||||
|
||||
# LM: base_lm + residual_lm
|
||||
if cfg.enable_lm:
|
||||
for lm in [self.base_lm, self.residual_lm]:
|
||||
apply_lora_to_named_linear_modules(
|
||||
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
|
||||
)
|
||||
|
||||
# DiT: feat_decoder.estimator
|
||||
if cfg.enable_dit:
|
||||
apply_lora_to_named_linear_modules(
|
||||
self.feat_decoder.estimator, target_submodule_names=cfg.target_modules_dit, **lora_kwargs
|
||||
)
|
||||
|
||||
# Projection layers
|
||||
if cfg.enable_proj:
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
for attr_name in cfg.target_proj_modules:
|
||||
module = getattr(self, attr_name, None)
|
||||
if isinstance(module, nn.Linear):
|
||||
setattr(self, attr_name, LoRALinear(base=module, **lora_kwargs))
|
||||
|
||||
def optimize(self, disable: bool = False):
|
||||
if disable:
|
||||
return self
|
||||
try:
|
||||
if self.device != "cuda":
|
||||
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
||||
try:
|
||||
import triton
|
||||
except ImportError:
|
||||
raise ValueError("triton is not installed")
|
||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||
else:
|
||||
self.base_lm.forward_step = self.base_lm.forward_step
|
||||
self.residual_lm.forward_step = self.residual_lm.forward_step
|
||||
self.feat_encoder_step = self.feat_encoder
|
||||
self.feat_decoder.estimator = self.feat_decoder.estimator
|
||||
except Exception as e:
|
||||
print(f"Warning: torch.compile disabled - {e}")
|
||||
return self
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_tokens: torch.Tensor,
|
||||
text_mask: torch.Tensor,
|
||||
audio_feats: torch.Tensor,
|
||||
audio_mask: torch.Tensor,
|
||||
loss_mask: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
*,
|
||||
progress: float = 0.0,
|
||||
sample_generate: bool = False,
|
||||
):
|
||||
del position_ids # not used yet
|
||||
|
||||
text_tokens = text_tokens.to(self.device, dtype=torch.long)
|
||||
text_mask = text_mask.to(self.device, dtype=self._dtype())
|
||||
audio_feats = audio_feats.to(self.device, dtype=self._dtype())
|
||||
audio_mask = audio_mask.to(self.device, dtype=self._dtype())
|
||||
loss_mask = loss_mask.to(self.device, dtype=self._dtype())
|
||||
labels = labels.to(self.device, dtype=torch.long)
|
||||
|
||||
B, T, P, D = audio_feats.shape
|
||||
feat_embed = self.feat_encoder(audio_feats)
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
|
||||
if not getattr(self.config.lm_config, "use_mup", False):
|
||||
scale_emb = 1.0
|
||||
text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
|
||||
combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
|
||||
|
||||
enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
|
||||
enc_outputs = enc_outputs.to(self._dtype())
|
||||
enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
|
||||
lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)
|
||||
|
||||
residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
|
||||
residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
|
||||
residual_outputs = residual_outputs.to(self._dtype())
|
||||
residual_hidden = torch.cat(
|
||||
(torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
|
||||
dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")
|
||||
|
||||
# Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
|
||||
target_dtype = self._dtype()
|
||||
|
||||
feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
|
||||
feat_cond = torch.cat(
|
||||
(torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
|
||||
dim=1,
|
||||
)
|
||||
feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")
|
||||
|
||||
loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
|
||||
loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)
|
||||
|
||||
diff_loss = self.feat_decoder.compute_loss(
|
||||
feat_gt.transpose(1, 2).contiguous(),
|
||||
dit_hidden,
|
||||
cond=feat_cond.transpose(1, 2).contiguous(),
|
||||
tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
|
||||
stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
|
||||
denom = torch.clamp(loss_mask.sum(), min=1.0)
|
||||
stop_loss = (stop_losses * loss_mask).sum() / denom
|
||||
|
||||
feat_pred = None
|
||||
if sample_generate:
|
||||
feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
|
||||
feat_pred_seq = self.feat_decoder(
|
||||
mu=dit_hidden,
|
||||
patch_size=self.patch_size,
|
||||
cond=feat_cond_for_sample,
|
||||
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||
else 10,
|
||||
)
|
||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
return {
|
||||
"loss/diff": diff_loss,
|
||||
"loss/stop": stop_loss,
|
||||
"feat_gt": feat_gt_tensor,
|
||||
"feat_pred": feat_pred,
|
||||
}
|
||||
|
||||
def _dtype(self):
|
||||
return get_dtype(self.config.dtype)
|
||||
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
|
||||
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
|
||||
return self._generate(*args, streaming=True, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
def _generate(
|
||||
self,
|
||||
target_text: str,
|
||||
prompt_text: str = "",
|
||||
@@ -174,7 +350,11 @@ class VoxCPMModel(nn.Module):
|
||||
retry_badcase: bool = False,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
|
||||
):
|
||||
streaming: bool = False,
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
if retry_badcase and streaming:
|
||||
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
|
||||
retry_badcase = False
|
||||
if len(prompt_wav_path) == 0:
|
||||
text = target_text
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
@@ -213,25 +393,25 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
# Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
|
||||
padding_size = patch_len - audio.size(1) % patch_len
|
||||
audio = torch.nn.functional.pad(audio, (padding_size, 0))
|
||||
|
||||
# (B, D, T)
|
||||
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
|
||||
|
||||
audio_feat = audio_feat.view(
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
audio_length = audio_feat.size(0)
|
||||
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
@@ -250,33 +430,46 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
text_token = text_token.unsqueeze(0).to(self.device)
|
||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
|
||||
retry_badcase_times = 0
|
||||
while retry_badcase_times < retry_badcase_max_times:
|
||||
latent_pred, pred_audio_feat = self.inference(
|
||||
inference_result = self._inference(
|
||||
text_token,
|
||||
text_mask,
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
if streaming:
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
for latent_pred, _ in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
break
|
||||
return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
yield decode_audio
|
||||
|
||||
@torch.inference_mode()
|
||||
def build_prompt_cache(
|
||||
@@ -292,13 +485,11 @@ class VoxCPMModel(nn.Module):
|
||||
prompt_wav_path: prompt audio path (required)
|
||||
|
||||
Returns:
|
||||
prompt_cache: dict with text tokens and audio features
|
||||
prompt_cache: dict with prompt_text (raw text) and audio features.
|
||||
Text tokenization will be done during generation for consistency.
|
||||
"""
|
||||
if not prompt_text or not prompt_wav_path:
|
||||
raise ValueError("prompt_text and prompt_wav_path are required")
|
||||
|
||||
# build text tokens
|
||||
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
|
||||
|
||||
# load audio
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
@@ -311,20 +502,21 @@ class VoxCPMModel(nn.Module):
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
# Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
|
||||
padding_size = patch_len - audio.size(1) % patch_len
|
||||
audio = torch.nn.functional.pad(audio, (padding_size, 0))
|
||||
|
||||
# extract audio features
|
||||
audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
|
||||
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
|
||||
|
||||
audio_feat = audio_feat.view(
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0) # (D, T, P)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
# build prompt cache
|
||||
# build prompt cache - only save raw text and audio features
|
||||
prompt_cache = {
|
||||
"text_token": text_token,
|
||||
"prompt_text": prompt_text,
|
||||
"audio_feat": audio_feat,
|
||||
}
|
||||
|
||||
@@ -334,7 +526,7 @@ class VoxCPMModel(nn.Module):
|
||||
def merge_prompt_cache(
|
||||
self,
|
||||
original_cache: dict,
|
||||
new_text_token: torch.Tensor,
|
||||
new_text: str,
|
||||
new_audio_feat: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
@@ -342,32 +534,44 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
Args:
|
||||
original_cache: original prompt cache
|
||||
new_text_token: newly generated text tokens
|
||||
new_text: newly generated text
|
||||
new_audio_feat: newly generated audio features
|
||||
|
||||
Returns:
|
||||
merged_cache: merged cache
|
||||
merged_cache: merged cache with prompt_text and audio_feat
|
||||
"""
|
||||
if original_cache is None:
|
||||
return {
|
||||
"text_token": new_text_token,
|
||||
"prompt_text": new_text,
|
||||
"audio_feat": new_audio_feat,
|
||||
}
|
||||
original_text_token = original_cache["text_token"]
|
||||
original_prompt_text = original_cache["prompt_text"]
|
||||
original_audio_feat = original_cache["audio_feat"]
|
||||
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
|
||||
# Merge text by concatenation
|
||||
merged_prompt_text = original_prompt_text + new_text
|
||||
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
|
||||
|
||||
# build new cache
|
||||
merged_cache = {
|
||||
"text_token": merged_text_token,
|
||||
"prompt_text": merged_prompt_text,
|
||||
"audio_feat": merged_audio_feat,
|
||||
}
|
||||
|
||||
return merged_cache
|
||||
|
||||
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
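# Illustrative usage sketch (example only, not part of this file). Assumes `model` is a
# loaded VoxCPMModel; the prompt text/audio pair and the target sentence are placeholders.
cache = model.build_prompt_cache(
    prompt_text="This is the reference utterance.",
    prompt_wav_path="prompt.wav",
)
wav, new_text_token, new_audio_feat = model.generate_with_prompt_cache(
    target_text="First sentence cloned with the cached prompt.",
    prompt_cache=cache,
)
# Optionally fold the newly generated segment back into the cache for long-form continuity.
cache = model.merge_prompt_cache(
    cache,
    new_text="First sentence cloned with the cached prompt.",
    new_audio_feat=new_audio_feat,
)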
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_with_prompt_cache(
|
||||
def _generate_with_prompt_cache(
|
||||
self,
|
||||
target_text: str,
|
||||
prompt_cache: dict,
|
||||
@@ -378,7 +582,8 @@ class VoxCPMModel(nn.Module):
|
||||
retry_badcase: bool = False,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,
|
||||
):
|
||||
streaming: bool = False,
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
||||
"""
|
||||
Generate audio using pre-built prompt cache.
|
||||
|
||||
@@ -392,20 +597,27 @@ class VoxCPMModel(nn.Module):
|
||||
retry_badcase: Whether to retry on bad cases
|
||||
retry_badcase_max_times: Maximum retry attempts
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
|
||||
streaming: Whether to return a generator of audio chunks
|
||||
|
||||
Returns:
|
||||
tuple: (decoded audio tensor, new text tokens, new audio features)
|
||||
Generator of Tuple containing:
|
||||
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
|
||||
- Tensor of new text tokens
|
||||
- New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
|
||||
"""
|
||||
if retry_badcase and streaming:
|
||||
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
|
||||
retry_badcase = False
|
||||
# get prompt from cache
|
||||
if prompt_cache is None:
|
||||
prompt_text_token = torch.empty(0, dtype=torch.int32)
|
||||
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
|
||||
text = target_text
|
||||
else:
|
||||
prompt_text_token = prompt_cache["text_token"]
|
||||
prompt_audio_feat = prompt_cache["audio_feat"]
|
||||
# build target text tokens
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
|
||||
prompt_text = prompt_cache["prompt_text"]
|
||||
text = prompt_text + target_text
|
||||
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
text_token = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
@@ -417,6 +629,8 @@ class VoxCPMModel(nn.Module):
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
|
||||
audio_length = prompt_audio_feat.size(0)
|
||||
text_length = text_token.shape[0]
|
||||
@@ -433,42 +647,63 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
text_token = text_token.unsqueeze(0).to(self.device)
|
||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
# run inference
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
retry_badcase_times = 0
|
||||
while retry_badcase_times < retry_badcase_max_times:
|
||||
latent_pred, pred_audio_feat = self.inference(
|
||||
inference_result = self._inference(
|
||||
text_token,
|
||||
text_mask,
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
if streaming:
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
for latent_pred, pred_audio_feat in inference_result:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
||||
yield (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
break
|
||||
else:
|
||||
latent_pred, pred_audio_feat = next(inference_result)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
|
||||
return (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
|
||||
yield (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
|
||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return next(self._inference(*args, streaming=False, **kwargs))
|
||||
|
||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._inference(*args, streaming=True, **kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
def _inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_mask: torch.Tensor,
|
||||
@@ -478,7 +713,9 @@ class VoxCPMModel(nn.Module):
|
||||
max_len: int = 2000,
|
||||
inference_timesteps: int = 10,
|
||||
cfg_value: float = 2.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
streaming: bool = False,
|
||||
streaming_prefix_len: int = 3,
|
||||
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
||||
"""Core inference method for audio generation.
|
||||
|
||||
This is the main inference loop that generates audio features
|
||||
@@ -493,11 +730,12 @@ class VoxCPMModel(nn.Module):
|
||||
max_len: Maximum generation length
|
||||
inference_timesteps: Number of diffusion steps
|
||||
cfg_value: Classifier-free guidance value
|
||||
streaming: Whether to yield each step latent feature or just the final result
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- Predicted latent features
|
||||
- Predicted audio feature sequence
|
||||
Generator of Tuple containing:
|
||||
- Predicted latent feature at the current step if ``streaming=True``, else final latent features
|
||||
- Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
|
||||
"""
|
||||
B, T, P, D = feat.shape
|
||||
|
||||
@@ -549,11 +787,18 @@ class VoxCPMModel(nn.Module):
|
||||
1, 2
|
||||
) # [b, p, d]
|
||||
|
||||
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.enc_to_lm_proj(curr_embed)
|
||||
|
||||
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
|
||||
prefix_feat_cond = pred_feat
|
||||
|
||||
if streaming:
|
||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
yield feat_pred, pred_feat_seq
|
||||
|
||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||
if i > min_len and stop_flag == 1:
|
||||
@@ -569,37 +814,140 @@ class VoxCPMModel(nn.Module):
|
||||
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
|
||||
).clone()
|
||||
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
feat_pred = feat_pred[..., 1:-1] # trick: remove the first and last token
|
||||
return feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
if not streaming:
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str):
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
|
||||
audio_vae = AudioVAE()
|
||||
audio_vae_config = getattr(config, 'audio_vae_config', None)
|
||||
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
||||
vae_state_dict = torch.load(
|
||||
os.path.join(path, "audiovae.pth"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
|
||||
model = cls(config, tokenizer, audio_vae)
|
||||
lm_dtype = get_dtype(config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
else: # training mode
|
||||
for name, param in model.named_parameters():
|
||||
if "audio_vae" in name: # freeze VAE weights
|
||||
param.requires_grad = False
|
||||
continue
|
||||
if lora_config is not None:
|
||||
if "lora" not in name: # freeze non-LoRA weights
|
||||
param.requires_grad = False
|
||||
model.audio_vae = model.audio_vae.to(torch.float32)
|
||||
|
||||
model_state_dict = torch.load(
|
||||
os.path.join(path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
|
||||
|
||||
# Try to load from safetensors first, fallback to pytorch_model.bin
|
||||
safetensors_path = os.path.join(path, "model.safetensors")
|
||||
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
||||
|
||||
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||
print(f"Loading model from safetensors: {safetensors_path}")
|
||||
model_state_dict = load_file(safetensors_path)
|
||||
elif os.path.exists(pytorch_model_path):
|
||||
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
|
||||
checkpoint = torch.load(
|
||||
pytorch_model_path,
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
model_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
|
||||
)
|
||||
|
||||
for kw, val in vae_state_dict.items():
|
||||
model_state_dict[f"audio_vae.{kw}"] = val
|
||||
model.load_state_dict(model_state_dict, strict=True)
|
||||
return model.to(model.device).eval().optimize()
|
||||
|
||||
# LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
|
||||
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
|
||||
model.load_state_dict(model_state_dict, strict=False)
|
||||
if training:
|
||||
return model
|
||||
return model.to(model.device).eval().optimize(disable=not optimize)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# LoRA Weight Management
|
||||
# ------------------------------------------------------------------ #
|
||||
def _iter_lora_modules(self):
|
||||
"""Iterate over all LoRA modules."""
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
for module in self.modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
yield module
|
||||
|
||||
def load_lora_weights(self, lora_path: str, device: str = None):
|
||||
"""
|
||||
Load LoRA weights from file, supports calling after torch.compile.
|
||||
Uses named_parameters() to handle compile's _orig_mod wrapper.
|
||||
Supports both safetensors and pytorch formats.
|
||||
|
||||
Args:
|
||||
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
|
||||
device: Target device, defaults to model's current device
|
||||
Returns:
|
||||
tuple: (loaded_keys, skipped_keys)
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
device = device or self.device
|
||||
lora_path = Path(lora_path)
|
||||
|
||||
# Try safetensors first, then fallback to .ckpt
|
||||
if lora_path.is_dir():
|
||||
safetensors_file = lora_path / "lora_weights.safetensors"
|
||||
ckpt_file = lora_path / "lora_weights.ckpt"
|
||||
else:
|
||||
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
|
||||
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
|
||||
|
||||
# Load from safetensors if available
|
||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||
state_dict = load_file(str(safetensors_file), device=device)
|
||||
elif ckpt_file and ckpt_file.exists():
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
|
||||
)
|
||||
|
||||
# Build param mapping (handle torch.compile's _orig_mod prefix)
|
||||
model_params = dict(self.named_parameters())
|
||||
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
|
||||
|
||||
loaded_keys, skipped_keys = [], []
|
||||
for key, value in state_dict.items():
|
||||
target_key = key if key in model_params else key_mapping.get(key)
|
||||
if target_key:
|
||||
model_params[target_key].data.copy_(value.to(device))
|
||||
loaded_keys.append(key)
|
||||
else:
|
||||
skipped_keys.append(key)
|
||||
|
||||
return loaded_keys, skipped_keys
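# Illustrative usage sketch (example only, not part of this file). Assumes VoxCPMModel has
# been imported, `lora_cfg` is a LoRAConfig matching the fine-tuned checkpoint, and the
# paths are placeholders.
model = VoxCPMModel.from_local("path/to/VoxCPM1.5", lora_config=lora_cfg)
loaded, skipped = model.load_lora_weights("path/to/lora_checkpoint")
print(f"loaded {len(loaded)} LoRA tensors, skipped {len(skipped)}")

model.set_lora_enabled(False)  # temporarily fall back to the base weights
model.set_lora_enabled(True)   # re-enable the adapter
model.reset_lora_weights()     # reinitialize lora_A/lora_B, effectively unloading the adapter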
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
|
||||
"""Enable/disable all LoRA layers."""
|
||||
for module in self._iter_lora_modules():
|
||||
module.set_enabled(enabled)
|
||||
|
||||
def reset_lora_weights(self):
|
||||
"""Reset all LoRA weights (A: kaiming, B: zeros), effectively unloading LoRA."""
|
||||
for module in self._iter_lora_modules():
|
||||
module.reset_lora_parameters()
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
|
||||
"""Get all LoRA parameters (lora_A/lora_B)."""
|
||||
return {name: param.data.clone()
|
||||
for name, param in self.named_parameters()
|
||||
if "lora_" in name}
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .audio_vae import AudioVAE
|
||||
from .audio_vae import AudioVAE, AudioVAEConfig
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
@@ -266,6 +267,17 @@ class CausalDecoder(nn.Module):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class AudioVAEConfig(BaseModel):
|
||||
encoder_dim: int = 128
|
||||
encoder_rates: List[int] = [2, 5, 8, 8]
|
||||
latent_dim: int = 64
|
||||
decoder_dim: int = 1536
|
||||
decoder_rates: List[int] = [8, 8, 5, 2]
|
||||
depthwise: bool = True
|
||||
sample_rate: int = 16000
|
||||
use_noise_block: bool = False
|
||||
|
||||
|
||||
class AudioVAE(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
@@ -273,17 +285,23 @@ class AudioVAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int = 128,
|
||||
encoder_rates: List[int] = [2, 5, 8, 8],
|
||||
latent_dim: int = 64,
|
||||
decoder_dim: int = 1536,
|
||||
decoder_rates: List[int] = [8, 8, 5, 2],
|
||||
depthwise: bool = True,
|
||||
sample_rate: int = 16000,
|
||||
use_noise_block: bool = False,
|
||||
config: Optional[AudioVAEConfig] = None,
|
||||
):
|
||||
# 如果没有传入config,使用默认配置
|
||||
if config is None:
|
||||
config = AudioVAEConfig()
|
||||
|
||||
super().__init__()
|
||||
|
||||
encoder_dim = config.encoder_dim
|
||||
encoder_rates = config.encoder_rates
|
||||
latent_dim = config.latent_dim
|
||||
decoder_dim = config.decoder_dim
|
||||
decoder_rates = config.decoder_rates
|
||||
depthwise = config.depthwise
|
||||
sample_rate = config.sample_rate
|
||||
use_noise_block = config.use_noise_block
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
|
||||
133
src/voxcpm/modules/layers/lora.py
Normal file
@@ -0,0 +1,133 @@
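# Illustrative usage sketch for the LoRA helpers defined in this file (example only, not
# part of the diff). The r/alpha/dropout values and the target projection names are
# placeholders.
import torch.nn as nn
from voxcpm.modules.layers.lora import LoRALinear, apply_lora_to_named_linear_modules

toy = nn.ModuleDict({"q_proj": nn.Linear(16, 16), "v_proj": nn.Linear(16, 16)})
apply_lora_to_named_linear_modules(
    toy, target_submodule_names=["q_proj", "v_proj"], r=8, alpha=16.0, dropout=0.0
)
assert isinstance(toy["q_proj"], LoRALinear)
# Only the low-rank matrices carry gradients during LoRA fine-tuning.
lora_param_names = [n for n, _ in toy.named_parameters() if "lora_" in n]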
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
"""
|
||||
LoRA linear layer: holds weight/bias directly, keeping the same state_dict key layout as nn.Linear.
|
||||
|
||||
state_dict 结构:
|
||||
- weight: 原始权重(与 nn.Linear 一致)
|
||||
- bias: 原始偏置(与 nn.Linear 一致)
|
||||
- lora_A: LoRA 低秩矩阵 A
|
||||
- lora_B: LoRA 低秩矩阵 B
|
||||
|
||||
这样设计的好处:加载预训练权重时无需做 key 转换。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base: nn.Linear,
|
||||
r: int,
|
||||
alpha: float = 1.0,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."
|
||||
|
||||
self.in_features = base.in_features
|
||||
self.out_features = base.out_features
|
||||
self.r = r
|
||||
self.alpha = alpha
|
||||
self._base_scaling = alpha / r if r > 0 else 0.0
|
||||
|
||||
# Store scaling in a buffer so that changing its value does not trigger torch.compile recompilation
|
||||
# persistent=False keeps it out of the state_dict, avoiding a missing-key error on load
|
||||
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
|
||||
|
||||
# Hold weight and bias directly (transferred from the original Linear)
|
||||
self.weight = base.weight
|
||||
self.bias = base.bias  # may be None
|
||||
|
||||
# LoRA parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
|
||||
self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
else:
|
||||
self.register_parameter("lora_A", None)
|
||||
self.register_parameter("lora_B", None)
|
||||
|
||||
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Base Linear computation
|
||||
result = F.linear(x, self.weight, self.bias)
|
||||
if self.r <= 0 or self.lora_A is None:
|
||||
return result
|
||||
# LoRA: result + dropout(x @ A^T @ B^T) * scaling
|
||||
lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
|
||||
return result + self.dropout(lora_out) * self.scaling
|
||||
|
||||
def reset_lora_parameters(self):
|
||||
"""重置 LoRA 参数到初始状态"""
|
||||
if self.r > 0 and self.lora_A is not None:
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def set_enabled(self, enabled: bool):
|
||||
"""启用/禁用 LoRA(通过 scaling 控制,兼容 torch.compile)"""
|
||||
# Modify the buffer value in place with fill_ so no recompilation is triggered
|
||||
self.scaling.fill_(self._base_scaling if enabled else 0.0)
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.scaling.item() != 0.0
|
||||
|
||||
|
||||
def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
|
||||
"""
|
||||
Given a full name like 'layers.0.self_attn.q_proj', return the parent module (i.e., the module one level above q_proj).
|
||||
"""
|
||||
parts = name.split(".")
|
||||
if len(parts) == 1:
|
||||
return root
|
||||
parent = root
|
||||
for p in parts[:-1]:
|
||||
if not hasattr(parent, p):
|
||||
return None
|
||||
parent = getattr(parent, p)
|
||||
return parent
|
||||
|
||||
|
||||
def apply_lora_to_named_linear_modules(
|
||||
root: nn.Module,
|
||||
*,
|
||||
target_submodule_names: list[str],
|
||||
r: int,
|
||||
alpha: float,
|
||||
dropout: float,
|
||||
) -> None:
|
||||
"""
|
||||
Inject LoRA into the Linear layers of the given module (and its submodules) whose names end with one of target_submodule_names.
|
||||
|
||||
例如 target_submodule_names=["q_proj", "v_proj"] 时,
|
||||
会在所有名为 *.q_proj / *.v_proj 的 nn.Linear 上替换为 LoRALinear。
|
||||
"""
|
||||
for full_name, module in list(root.named_modules()):
|
||||
if not isinstance(module, nn.Linear):
|
||||
continue
|
||||
short_name = full_name.split(".")[-1]
|
||||
if short_name not in target_submodule_names:
|
||||
continue
|
||||
|
||||
parent = _get_parent_module(root, full_name)
|
||||
if parent is None:
|
||||
continue
|
||||
|
||||
# Replace the original Linear with a LoRALinear
|
||||
lora_layer = LoRALinear(
|
||||
base=module,
|
||||
r=r,
|
||||
alpha=alpha,
|
||||
dropout=dropout,
|
||||
)
|
||||
setattr(parent, short_name, lora_layer)
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +1,29 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from typing import List
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
from torch.func import jvp
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
|
||||
|
||||
class CfmConfig(BaseModel):
|
||||
sigma_min: float = 1e-06
|
||||
sigma_min: float = 1e-6
|
||||
solver: str = "euler"
|
||||
t_scheduler: str = "log-norm"
|
||||
training_cfg_rate: float = 0.1
|
||||
inference_cfg_rate: float = 1.0
|
||||
reg_loss_type: str = "l1"
|
||||
ratio_r_neq_t_range: Tuple[float, float] = (0.25, 0.75)
|
||||
noise_cond_prob_range: Tuple[float, float] = (0.0, 0.0)
|
||||
noise_cond_scale: float = 0.0
|
||||
|
||||
|
||||
class UnifiedCFM(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
in_channels: int,
|
||||
cfm_params: CfmConfig,
|
||||
estimator: VoxCPMLocDiT,
|
||||
mean_mode: bool = False,
|
||||
@@ -23,12 +32,21 @@ class UnifiedCFM(torch.nn.Module):
|
||||
self.solver = cfm_params.solver
|
||||
self.sigma_min = cfm_params.sigma_min
|
||||
self.t_scheduler = cfm_params.t_scheduler
|
||||
self.training_cfg_rate = cfm_params.training_cfg_rate
|
||||
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
||||
self.reg_loss_type = cfm_params.reg_loss_type
|
||||
self.ratio_r_neq_t_range = cfm_params.ratio_r_neq_t_range
|
||||
self.noise_cond_prob_range = cfm_params.noise_cond_prob_range
|
||||
self.noise_cond_scale = cfm_params.noise_cond_scale
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mean_mode = mean_mode
|
||||
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Inference
|
||||
# ------------------------------------------------------------------ #
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
@@ -41,33 +59,25 @@ class UnifiedCFM(torch.nn.Module):
|
||||
sway_sampling_coef: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats)
|
||||
n_timesteps (int): number of diffusion steps
|
||||
cond: Not used but kept for future purposes
|
||||
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
b, c = mu.shape
|
||||
b, _ = mu.shape
|
||||
t = patch_size
|
||||
z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
|
||||
|
||||
t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
# Sway sampling strategy
|
||||
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
|
||||
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)
|
||||
return self.solve_euler(
|
||||
x=z,
|
||||
t_span=t_span,
|
||||
mu=mu,
|
||||
cond=cond,
|
||||
cfg_value=cfg_value,
|
||||
use_cfg_zero_star=use_cfg_zero_star,
|
||||
)
|
||||
|
||||
def optimized_scale(self, positive_flat, negative_flat):
|
||||
def optimized_scale(self, positive_flat: torch.Tensor, negative_flat: torch.Tensor):
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
|
||||
squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
|
||||
st_star = dot_product / squared_norm
|
||||
return st_star
|
||||
|
||||
@@ -80,24 +90,13 @@ class UnifiedCFM(torch.nn.Module):
|
||||
cfg_value: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
x (torch.Tensor): random noise
|
||||
t_span (torch.Tensor): n_timesteps interpolated
|
||||
shape: (n_timesteps + 1,)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats)
|
||||
cond: condition -- prefix prompt
|
||||
cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
|
||||
|
||||
sol = []
|
||||
zero_init_steps = max(1, int(len(t_span) * 0.04))
|
||||
for step in range(1, len(t_span)):
|
||||
if use_cfg_zero_star and step <= zero_init_steps:
|
||||
dphi_dt = 0.
|
||||
dphi_dt = torch.zeros_like(x)
|
||||
else:
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
b = x.size(0)
|
||||
@@ -105,7 +104,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
|
||||
dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2 * b, self.in_channels, cond.size(2)], device=x.device, dtype=x.dtype)
|
||||
x_in[:b], x_in[b:] = x, x
|
||||
mu_in[:b] = mu
|
||||
t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
|
||||
@@ -135,3 +134,98 @@ class UnifiedCFM(torch.nn.Module):
|
||||
dt = t - t_span[step + 1]
|
||||
|
||||
return sol[-1]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Training loss
|
||||
# ------------------------------------------------------------------ #
|
||||
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
|
||||
weights = 1.0 / ((losses + epsilon).pow(p))
|
||||
if mask is not None:
|
||||
weights = weights * mask
|
||||
return weights.detach()
|
||||
|
||||
def sample_r_t(self, x: torch.Tensor, mu: float = -0.4, sigma: float = 1.0, ratio_r_neq_t: float = 0.0):
|
||||
batch_size = x.shape[0]
|
||||
if self.t_scheduler == "log-norm":
|
||||
s_r = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
|
||||
s_t = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
|
||||
r = torch.sigmoid(s_r)
|
||||
t = torch.sigmoid(s_t)
|
||||
elif self.t_scheduler == "uniform":
|
||||
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported t_scheduler: {self.t_scheduler}")
|
||||
|
||||
mask = torch.rand(batch_size, device=x.device, dtype=x.dtype) < ratio_r_neq_t
|
||||
r, t = torch.where(
|
||||
mask,
|
||||
torch.stack([torch.min(r, t), torch.max(r, t)], dim=0),
|
||||
torch.stack([t, t], dim=0),
|
||||
)
|
||||
|
||||
return r.squeeze(), t.squeeze()
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
x1: torch.Tensor,
|
||||
mu: torch.Tensor,
|
||||
cond: torch.Tensor | None = None,
|
||||
tgt_mask: torch.Tensor | None = None,
|
||||
progress: float = 0.0,
|
||||
):
|
||||
b, _, _ = x1.shape
|
||||
|
||||
if self.training_cfg_rate > 0:
|
||||
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
||||
mu = mu * cfg_mask.view(-1, 1)
|
||||
|
||||
if cond is None:
|
||||
cond = torch.zeros_like(x1)
|
||||
|
||||
noisy_mask = torch.rand(b, device=x1.device) > (
|
||||
1.0
|
||||
- (
|
||||
self.noise_cond_prob_range[0]
|
||||
+ progress * (self.noise_cond_prob_range[1] - self.noise_cond_prob_range[0])
|
||||
)
|
||||
)
|
||||
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
|
||||
|
||||
ratio_r_neq_t = (
|
||||
self.ratio_r_neq_t_range[0]
|
||||
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||
if self.mean_mode
|
||||
else 0.0
|
||||
)
|
||||
|
||||
r, t = self.sample_r_t(x1, ratio_r_neq_t=ratio_r_neq_t)
|
||||
r_ = r.detach().clone()
|
||||
t_ = t.detach().clone()
|
||||
z = torch.randn_like(x1)
|
||||
y = (1 - t_.view(-1, 1, 1)) * x1 + t_.view(-1, 1, 1) * z
|
||||
v = z - x1
|
||||
|
||||
def model_fn(z_sample, r_sample, t_sample):
|
||||
return self.estimator(z_sample, mu, t_sample, cond, dt=t_sample - r_sample)
|
||||
|
||||
if self.mean_mode:
|
||||
v_r = torch.zeros_like(r)
|
||||
v_t = torch.ones_like(t)
|
||||
from torch.backends.cuda import sdp_kernel
|
||||
|
||||
with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
|
||||
u_pred, dudt = jvp(model_fn, (y, r, t), (v, v_r, v_t))
|
||||
u_tgt = v - (t_ - r_).view(-1, 1, 1) * dudt
|
||||
else:
|
||||
u_pred = model_fn(y, r, t)
|
||||
u_tgt = v
|
||||
|
||||
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
|
||||
if tgt_mask is not None:
|
||||
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
|
||||
loss = (weights * losses).sum() / torch.sum(tgt_mask)
|
||||
else:
|
||||
loss = losses.mean()
|
||||
|
||||
return loss
|
||||
|
||||
@@ -153,7 +153,12 @@ class MiniCPMAttention(nn.Module):
|
||||
cos, sin = position_emb
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
|
||||
# ref: https://github.com/pytorch/pytorch/issues/163597
|
||||
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
||||
query_states = query_states.contiguous()
|
||||
key_states = key_states.contiguous()
|
||||
value_states = value_states.contiguous()
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):
|
||||
|
||||
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
|
||||
|
||||
# ref: https://github.com/pytorch/pytorch/issues/163597
|
||||
# there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
|
||||
query_states = query_states.contiguous()
|
||||
key_cache = key_cache.contiguous()
|
||||
value_cache = value_cache.contiguous()
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_cache,
|
||||
|
||||
28
src/voxcpm/training/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Training utilities for VoxCPM fine-tuning.
|
||||
|
||||
This package mirrors the training mechanics used in the minicpm-audio
|
||||
tooling while relying solely on local audio-text datasets managed via
|
||||
the HuggingFace ``datasets`` library.
|
||||
"""
|
||||
|
||||
from .accelerator import Accelerator
|
||||
from .tracker import TrainingTracker
|
||||
from .data import (
|
||||
load_audio_text_datasets,
|
||||
HFVoxCPMDataset,
|
||||
build_dataloader,
|
||||
BatchProcessor,
|
||||
)
|
||||
from .state import TrainingState
|
||||
|
||||
__all__ = [
|
||||
"Accelerator",
|
||||
"TrainingTracker",
|
||||
"HFVoxCPMDataset",
|
||||
"BatchProcessor",
|
||||
"TrainingState",
|
||||
"load_audio_text_datasets",
|
||||
"build_dataloader",
|
||||
]
|
||||
|
||||
166
src/voxcpm/training/accelerator.py
Normal file
@@ -0,0 +1,166 @@
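# Illustrative training-loop sketch for the Accelerator defined in this file (example only,
# not part of the diff). `build_model`, `MyDataset`, and the hyperparameters are placeholders.
import torch

acc = Accelerator(amp=True, seed=42)
model = acc.prepare_model(build_model())
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loader = acc.prepare_dataloader(MyDataset(), batch_size=8, num_workers=2, shuffle=True)

for batch in loader:
    with acc.autocast():
        loss = model(**batch)
    acc.backward(loss)
    acc.step(optimizer)
    acc.update()
    optimizer.zero_grad(set_to_none=True)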
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
import typing
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
|
||||
class Accelerator:
|
||||
"""
|
||||
Simplified accelerator that mirrors the behaviour of the minicpm-audio
|
||||
training utilities. It initializes a distributed process group when
|
||||
``torchrun`` is used and exposes helpers for AMP, gradient scaling and
|
||||
preparing models/dataloaders for DDP.
|
||||
"""
|
||||
|
||||
def __init__(self, amp: bool = False, seed: int = 42):
|
||||
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
if self.world_size > 1 and not dist.is_initialized():
|
||||
dist.init_process_group("nccl", init_method="env://")
|
||||
|
||||
self.rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||
self.amp = amp
|
||||
|
||||
# Set random seed to ensure model initialization consistency
|
||||
self._set_seed(seed)
|
||||
|
||||
class DummyScaler:
|
||||
def step(self, optimizer):
|
||||
optimizer.step()
|
||||
|
||||
def scale(self, loss):
|
||||
return loss
|
||||
|
||||
def unscale_(self, optimizer):
|
||||
return optimizer
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
|
||||
self.device_ctx = (
|
||||
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
)
|
||||
self._ddp_model = None # For no_sync support
|
||||
|
||||
def _set_seed(self, seed: int):
|
||||
"""Set random seed to ensure model initialization consistency across multiple GPUs"""
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def __enter__(self):
|
||||
if self.device_ctx is not None:
|
||||
self.device_ctx.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if self.device_ctx is not None:
|
||||
self.device_ctx.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
def barrier(self):
|
||||
"""Synchronize all processes"""
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.AVG):
|
||||
"""All-reduce tensor across processes"""
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(tensor, op=op)
|
||||
return tensor
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Model helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
||||
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
|
||||
model.device = self.device
|
||||
model = model.to(self.device)
|
||||
if self.world_size > 1:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
|
||||
self._ddp_model = model # Save DDP model reference for no_sync support
|
||||
return model
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_sync(self):
|
||||
"""
|
||||
Context manager to skip gradient synchronization during gradient accumulation.
|
||||
Only used outside the last micro-batch.
|
||||
"""
|
||||
if self._ddp_model is not None:
|
||||
with self._ddp_model.no_sync():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda", self.local_rank)
|
||||
if torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
return torch.device("cpu")
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# AMP helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def autocast(self, *args, **kwargs):
|
||||
return torch.amp.autocast("cuda", enabled=self.amp, *args, **kwargs)
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
def step(self, optimizer: torch.optim.Optimizer):
|
||||
self.scaler.step(optimizer)
|
||||
|
||||
def update(self):
|
||||
self.scaler.update()
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Data helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_dataloader(
|
||||
self,
|
||||
dataset: typing.Iterable,
|
||||
*,
|
||||
batch_size: int,
|
||||
num_workers: int = 0,
|
||||
shuffle: bool = True,
|
||||
collate_fn=None,
|
||||
drop_last: bool = False,
|
||||
) -> torch.utils.data.DataLoader:
|
||||
if self.world_size > 1:
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
|
||||
)
|
||||
shuffle = False
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle if sampler is None else False,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
drop_last=drop_last,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
|
||||
return model.module if hasattr(model, "module") else model
|
||||
|
||||
40
src/voxcpm/training/config.py
Normal file
@@ -0,0 +1,40 @@
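# Illustrative sketch for the YAML helpers defined in this file (example only, not part of
# the diff). The file name follows the docstring below; the keys are placeholders written
# in argbind's "function.argument" form.
#
#   # conf/voxcpm/finetune.yml
#   load_audio_text_datasets.train_manifest: data/train.jsonl
#   load_audio_text_datasets.val_manifest: data/val.jsonl
import argbind

args = parse_args_with_config("conf/voxcpm/finetune.yml")
with argbind.scope(args):
    pass  # argbind-bound functions now resolve their arguments from the YAML/CLI values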
|
||||
from __future__ import annotations
|
||||
|
||||
import argbind
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def load_yaml_config(path: str | Path) -> Dict[str, Any]:
|
||||
"""
|
||||
Load a YAML configuration file into a dictionary suitable for argbind.
|
||||
"""
|
||||
path = Path(path)
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Configuration file {path} must contain a top-level mapping.")
|
||||
return data
|
||||
|
||||
|
||||
def parse_args_with_config(config_path: str | Path | None = None):
|
||||
"""
|
||||
Helper to unify CLI arguments and YAML configuration.
|
||||
|
||||
Usage mirrors minicpm-audio:
|
||||
args = parse_args_with_config("conf/voxcpm/finetune.yml")
|
||||
with argbind.scope(args):
|
||||
...
|
||||
"""
|
||||
cli_args = argbind.parse_args()
|
||||
if config_path is None:
|
||||
return cli_args
|
||||
|
||||
yaml_args = load_yaml_config(config_path)
|
||||
with argbind.scope(cli_args):
|
||||
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
|
||||
cli_args.update(yaml_args)
|
||||
return cli_args
|
||||
|
||||
|
||||
214
src/voxcpm/training/data.py
Normal file
@@ -0,0 +1,214 @@
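# Illustrative sketch of the manifest format consumed by load_audio_text_datasets in this
# file (example only, not part of the diff). Paths and values are placeholders; each JSONL
# line supplies at least an audio path and its transcript, plus an optional duration:
#
#   {"audio": "wavs/0001.wav", "text": "Hello there.", "duration": 2.1}
train_ds, val_ds = load_audio_text_datasets(
    train_manifest="data/train.jsonl",
    val_manifest="data/val.jsonl",
    sample_rate=16_000,
)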
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import argbind
|
||||
import torch
|
||||
from datasets import Audio, Dataset, DatasetDict, load_dataset
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
|
||||
from ..model.voxcpm import VoxCPMConfig
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
|
||||
|
||||
@argbind.bind()
|
||||
def load_audio_text_datasets(
|
||||
train_manifest: str,
|
||||
val_manifest: str = "",
|
||||
text_column: str = DEFAULT_TEXT_COLUMN,
|
||||
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
||||
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
||||
sample_rate: int = 16_000,
|
||||
num_proc: int = 1,
|
||||
) -> Tuple[Dataset, Optional[Dataset]]:
|
||||
data_files = {"train": train_manifest}
|
||||
if val_manifest:
|
||||
data_files["validation"] = val_manifest
|
||||
|
||||
dataset_dict: DatasetDict = load_dataset("json", data_files=data_files)
|
||||
|
||||
def prepare(ds: Dataset) -> Dataset:
|
||||
if audio_column not in ds.column_names:
|
||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# but for length calculation we might need raw path or duration if available.
|
||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||
if audio_column != DEFAULT_AUDIO_COLUMN:
|
||||
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
||||
if text_column != DEFAULT_TEXT_COLUMN:
|
||||
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
||||
if dataset_id_column and dataset_id_column in ds.column_names:
|
||||
if dataset_id_column != DEFAULT_ID_COLUMN:
|
||||
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
||||
else:
|
||||
ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
|
||||
return ds
|
||||
|
||||
train_ds = prepare(dataset_dict["train"])
|
||||
val_ds = prepare(dataset_dict["validation"]) if "validation" in dataset_dict else None
|
||||
return train_ds, val_ds
|
||||
|
||||
|
||||
def compute_sample_lengths(
|
||||
ds: Dataset,
|
||||
audio_vae_fps: int = 25,
|
||||
patch_size: int = 1,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Estimate each sample's approximate sequence length (text + audio) after packing, used to filter overly long samples.
|
||||
|
||||
逻辑与 AudioFeatureProcessingPacker / AudioVAE 一致:
|
||||
- 文本长度: len(text_ids)
|
||||
- 音频长度:
|
||||
duration(s) * audio_vae_fps -> 近似 VAE 帧数 t_vae
|
||||
t_seq = ceil(t_vae / patch_size)
|
||||
- 序列总长约为: text_len + t_seq + 2
|
||||
"""
|
||||
lengths: List[int] = []
|
||||
|
||||
has_duration = "duration" in ds.column_names
|
||||
|
||||
for i in range(len(ds)):
|
||||
item = ds[i]
|
||||
text_len = len(item["text_ids"])
|
||||
|
||||
# Audio duration (avoid decoding if possible; prefer the 'duration' column if the manifest already has one)
|
||||
if has_duration:
|
||||
duration = float(item["duration"])
|
||||
else:
|
||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||
duration = len(audio["array"]) / float(audio["sampling_rate"])
|
||||
|
||||
t_vae = math.ceil(duration * audio_vae_fps)
|
||||
t_seq = math.ceil(t_vae / patch_size)
|
||||
|
||||
total_len = text_len + t_seq + 2
|
||||
lengths.append(total_len)
|
||||
|
||||
return lengths
|
||||
|
||||
|
||||
class HFVoxCPMDataset(TorchDataset):
|
||||
"""
|
||||
Thin wrapper around a tokenized HuggingFace dataset that returns
|
||||
PyTorch-friendly samples.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.dataset = dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
item = self.dataset[idx]
|
||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||
return {
|
||||
"text_ids": item["text_ids"],
|
||||
"audio_array": audio["array"],
|
||||
"audio_sampling_rate": audio["sampling_rate"],
|
||||
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
|
||||
"is_prompt": item.get("is_prompt", False),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
|
||||
if not seqs:
|
||||
return torch.empty(0)
|
||||
max_len = max(seq.shape[0] for seq in seqs)
|
||||
padded = []
|
||||
for seq in seqs:
|
||||
if seq.shape[0] < max_len:
|
||||
pad_width = (0, max_len - seq.shape[0])
|
||||
seq = torch.nn.functional.pad(seq, pad_width, value=pad_value)
|
||||
padded.append(seq)
|
||||
return torch.stack(padded)
|
||||
|
||||
@classmethod
|
||||
def collate_fn(cls, batch: List[Dict]):
|
||||
text_tensors = [torch.tensor(sample["text_ids"], dtype=torch.int32) for sample in batch]
|
||||
audio_tensors = [torch.tensor(sample["audio_array"], dtype=torch.float32) for sample in batch]
|
||||
dataset_ids = torch.tensor([sample["dataset_id"] for sample in batch], dtype=torch.int32)
|
||||
is_prompts = [bool(sample.get("is_prompt", False)) for sample in batch]
|
||||
|
||||
text_padded = cls.pad_sequences(text_tensors, pad_value=-100)
|
||||
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
|
||||
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
|
||||
|
||||
return {
|
||||
"text_tokens": text_padded,
|
||||
"audio_tokens": audio_padded,
|
||||
"task_ids": task_ids,
|
||||
"dataset_ids": dataset_ids,
|
||||
"is_prompts": is_prompts,
|
||||
}
|
||||
|
||||
|
||||
class BatchProcessor:
|
||||
"""
|
||||
Wraps ``AudioFeatureProcessingPacker`` so the training loop can mirror
|
||||
the minicpm-audio mechanics.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: VoxCPMConfig,
|
||||
audio_vae: AudioVAE,
|
||||
dataset_cnt: int,
|
||||
device: torch.device,
|
||||
):
|
||||
self.device = device
|
||||
self.dataset_cnt = dataset_cnt
|
||||
self.audio_vae = audio_vae
|
||||
self.audio_vae.to(device)
|
||||
self.packer = AudioFeatureProcessingPacker(
|
||||
dataset_cnt=dataset_cnt,
|
||||
max_len=config.max_length,
|
||||
patch_size=config.patch_size,
|
||||
feat_dim=config.feat_dim,
|
||||
audio_vae=self.audio_vae,
|
||||
)
|
||||
|
||||
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
audio_tokens = batch["audio_tokens"].to(self.device)
|
||||
text_tokens = batch["text_tokens"].to(self.device)
|
||||
task_ids = batch["task_ids"].to(self.device)
|
||||
dataset_ids = batch["dataset_ids"].to(self.device)
|
||||
|
||||
packed = self.packer(
|
||||
audio_tokens=audio_tokens,
|
||||
text_tokens=text_tokens,
|
||||
task_ids=task_ids,
|
||||
dataset_ids=dataset_ids,
|
||||
is_prompts=batch["is_prompts"],
|
||||
)
|
||||
return packed
|
||||
|
||||
|
||||
def build_dataloader(
|
||||
hf_dataset: Dataset,
|
||||
*,
|
||||
accelerator,
|
||||
batch_size: int,
|
||||
num_workers: int,
|
||||
drop_last: bool = False,
|
||||
) -> torch.utils.data.DataLoader:
|
||||
torch_dataset = HFVoxCPMDataset(hf_dataset)
|
||||
# Standard padding-based batching; Accelerator will attach DistributedSampler if needed.
|
||||
return accelerator.prepare_dataloader(
|
||||
torch_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=True,
|
||||
collate_fn=HFVoxCPMDataset.collate_fn,
|
||||
drop_last=drop_last,
|
||||
)
|
||||
|
||||
289
src/voxcpm/training/packers.py
Normal file
@@ -0,0 +1,289 @@
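# Illustrative sketch of the packer's input contract (example only, not part of the diff).
# `audio_vae` is assumed to be an initialized AudioVAE; shapes and hyperparameters are
# placeholders. A value of -100 marks padding in text ids and raw audio (see
# _first_pad_position below).
import torch

packer = AudioFeatureProcessingPacker(
    dataset_cnt=1,
    max_len=1000,
    patch_size=2,
    feat_dim=audio_vae.latent_dim,
    audio_vae=audio_vae,
)
packed = packer(
    audio_tokens=torch.randn(2, 16000),                             # [B, T_samples]
    text_tokens=torch.randint(5, 100, (2, 32), dtype=torch.int32),  # [B, T_text]
    task_ids=torch.ones(2, dtype=torch.int32),                      # 1 == "tts"
    dataset_ids=torch.zeros(2, dtype=torch.int32),
    is_prompts=[False, False],
)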
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class AudioFeatureProcessingPacker:
|
||||
"""
|
||||
Adapted from the minicpm-audio training utilities. It converts raw text and
|
||||
audio tokens into the packed multimodal representation required by VoxCPM.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
|
||||
self.audio_start_id = 101
|
||||
self.audio_end_id = 102
|
||||
# unused now
|
||||
self.audio_prompt_start_id = 103
|
||||
self.audio_prompt_end_id = 104
|
||||
self.text_eos_token_id = 2
|
||||
|
||||
        self.patch_size = patch_size
        self.patch_len = audio_vae.hop_length * self.patch_size
        self.feat_dim = feat_dim
        self.dataset_cnt = max(dataset_cnt, 1)
        self.max_len = max_len

        self.audio_vae = audio_vae

        self.process_functions = {"tts": self.process_tts_data}
        self.task_id_map = {"tts": 1}
        self.id_to_task = {idx: usage for usage, idx in self.task_id_map.items()}

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _first_pad_position(tokens: torch.Tensor):
        positions = (tokens == -100).nonzero(as_tuple=True)
        if positions[0].numel() == 0:
            return None
        return int(positions[0][0])

    def unpad_text_tokens(self, tokens: torch.Tensor):
        pad_pos = self._first_pad_position(tokens)
        return tokens if pad_pos is None else tokens[:pad_pos]

    def unpad_audio_tokens(self, tokens: torch.Tensor):
        pad_pos = self._first_pad_position(tokens)
        return tokens if pad_pos is None else tokens[:pad_pos]

    def encode_audio(self, wav: torch.Tensor):
        """
        Encode raw waveform into latent features using AudioVAE.

        AudioVAE.encode expects shape [B, 1, T'] and returns [B, D, T].
        We then transpose to [B, T, D] to match downstream expectations.
        """
        wav = wav.unsqueeze(0)  # [1, T]
        wav = wav.unsqueeze(1)  # [1, 1, T]
        wav_len = wav.size(-1)
        if wav_len % self.patch_len != 0:
            padding_size = self.patch_len - wav_len % self.patch_len
            wav = torch.nn.functional.pad(wav, (0, padding_size))

        with torch.no_grad():
            z = self.audio_vae.encode(wav, self.audio_vae.sample_rate)  # [1, D, T']
        feat = z.transpose(1, 2)  # [1, T', D]
        return feat

    # ------------------------------------------------------------------ #
    # Main entry point
    # ------------------------------------------------------------------ #
    def __call__(
        self,
        audio_tokens: torch.Tensor,
        text_tokens: torch.Tensor,
        task_ids: torch.Tensor,
        dataset_ids: torch.Tensor,
        is_prompts: List[bool],
    ) -> Dict[str, torch.Tensor]:
        """
        Padding-based batching: each sample in the input batch is processed
        independently and then padded to a common length (capped by ``max_len``).
        The result tensors all have shape [B, T, ...].
        """
        device = audio_tokens.device
        max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
        dataset_cnt = max(self.dataset_cnt, max_dataset_id + 1)

        text_tokens_list: List[torch.Tensor] = []
        audio_feats_list: List[torch.Tensor] = []
        text_mask_list: List[torch.Tensor] = []
        audio_mask_list: List[torch.Tensor] = []
        loss_mask_list: List[torch.Tensor] = []
        labels_list: List[torch.Tensor] = []
        audio_task_ids_list: List[torch.Tensor] = []
        audio_dataset_ids_list: List[torch.Tensor] = []
        lengths: List[int] = []

        audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
        text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)

        for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
            audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
        ):
            unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
            unpad_text_token = self.unpad_text_tokens(text_token)
            usage = self.id_to_task[task_id]

            (
                packed_text,
                audio_feat,
                text_mask,
                audio_mask,
                loss_mask,
                labels,
                audio_duration,
                text_token_count,
            ) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)

            audio_duration_consumed[dataset_idx] += audio_duration
            text_token_consumed[dataset_idx] += text_token_count

            audio_task_id = torch.zeros_like(audio_mask)
            audio_task_id[audio_mask == 1] = self.task_id_map[usage]

            audio_dataset_id = torch.zeros_like(audio_mask)
            audio_dataset_id[audio_mask == 1] = dataset_idx + 1

            text_tokens_list.append(packed_text)
            text_mask_list.append(text_mask)
            audio_feats_list.append(audio_feat)
            audio_mask_list.append(audio_mask)
            loss_mask_list.append(loss_mask)
            labels_list.append(labels)
            audio_task_ids_list.append(audio_task_id)
            audio_dataset_ids_list.append(audio_dataset_id)
            lengths.append(packed_text.shape[0])

        # Determine padded length per batch (cap by self.max_len)
        if lengths:
            max_len = min(self.max_len, max(lengths))
        else:
            max_len = self.max_len

        def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
            if x.size(0) >= max_len:
                return x[: max_len]
            pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
            return torch.cat([x, pad], dim=0)

        def pad_3d(x: torch.Tensor) -> torch.Tensor:
            # x: [T, P, D]
            if x.size(0) >= max_len:
                return x[: max_len]
            pad = torch.zeros(
                (max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
            )
            return torch.cat([x, pad], dim=0)
        if lengths:
            text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
            text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
            audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
            audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
            loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
            labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
            audio_task_ids_batch = torch.stack(
                [pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
            )
            audio_dataset_ids_batch = torch.stack(
                [pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
            )

            # Position ids: [B, T], simple 0..L_i-1 then padded with 0
            position_ids_list = []
            for L in lengths:
                L_clip = min(L, max_len)
                pos = torch.arange(0, L_clip, device=device)
                if L_clip < max_len:
                    pad = torch.zeros(max_len - L_clip, dtype=pos.dtype, device=device)
                    pos = torch.cat([pos, pad], dim=0)
                position_ids_list.append(pos)
            position_ids = torch.stack(position_ids_list, dim=0)
        else:
            # Empty batch fallback (shouldn't really happen)
            text_tokens_batch = torch.zeros((0, self.max_len), dtype=torch.int32, device=device)
            text_mask_batch = torch.zeros_like(text_tokens_batch)
            audio_feats_batch = torch.zeros(
                (0, self.max_len, self.patch_size, self.feat_dim), dtype=torch.float32, device=device
            )
            audio_mask_batch = torch.zeros_like(text_tokens_batch)
            loss_mask_batch = torch.zeros_like(text_tokens_batch)
            labels_batch = torch.zeros_like(text_tokens_batch)
            audio_task_ids_batch = torch.zeros_like(text_tokens_batch)
            audio_dataset_ids_batch = torch.zeros_like(text_tokens_batch)
            position_ids = torch.zeros_like(text_tokens_batch)

        audio_duration_consumed = audio_duration_consumed.to(torch.long)
        text_token_consumed = text_token_consumed.to(torch.long)

        return {
            "text_tokens": text_tokens_batch,
            "audio_feats": audio_feats_batch,
            "text_mask": text_mask_batch,
            "audio_mask": audio_mask_batch,
            "loss_mask": loss_mask_batch,
            "position_ids": position_ids,
            "labels": labels_batch,
            "audio_task_ids": audio_task_ids_batch,
            "audio_dataset_ids": audio_dataset_ids_batch,
            "audio_duration_consumed": audio_duration_consumed,
            "text_token_consumed": text_token_consumed,
        }

    # ------------------------------------------------------------------ #
    # Feature extraction helpers
    # ------------------------------------------------------------------ #
    def extract_audio_feats(self, audio_data: torch.Tensor):
        audio_feats = self.encode_audio(audio_data)
        if audio_feats.size(1) % self.patch_size != 0:
            audio_feats_ = audio_feats.transpose(1, 2)
            padding = nn.functional.pad(audio_feats_, (0, self.patch_size - audio_feats.size(1) % self.patch_size))
            audio_feats = padding.transpose(1, 2)

        audio_duration = audio_feats.size(1) / 25
        audio_feats = rearrange(audio_feats, "b (t p) c -> b t p c", p=self.patch_size)
        return audio_feats, audio_duration

    def process_tts_data(self, audio_token: torch.Tensor, text_token: torch.Tensor, is_prompt: bool = False):
        text_token_info = torch.cat(
            [
                text_token,
                torch.tensor(
                    [self.audio_prompt_start_id if is_prompt else self.audio_start_id],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ],
            dim=-1,
        )
        text_token_count = len(text_token)
        text_length = text_token_info.shape[0]
        audio_feat_info, audio_duration = self.extract_audio_feats(audio_token)
        audio_feat_info = audio_feat_info.squeeze(0)
        audio_length = audio_feat_info.shape[0]

        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        text_token_info = torch.cat(
            [
                text_token_info,
                text_pad_token,
                torch.tensor(
                    [self.audio_prompt_end_id if is_prompt else self.audio_end_id],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ]
        )
        audio_pad_feat = torch.zeros(
            (text_length, self.patch_size, audio_feat_info.size(-1)),
            dtype=torch.float32,
            device=text_token.device,
        )
        audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)

        text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
            text_token.device
        )
        audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
            torch.int32
        ).to(text_token.device)
        loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)

        labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
        labels[-2] = 1

        return (
            text_token_info,
            audio_feat_info,
            text_mask,
            audio_mask,
            loss_mask,
            labels,
            audio_duration,
            text_token_count,
        )
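For reference, `process_tts_data` lays each sample out as `[text tokens + audio-start token | audio patches | end position]`, with the masks and the stop label aligned to that layout. The toy sketch below (illustrative values only, not part of the repository) reproduces the mask construction for a sample with 4 text positions and 3 audio patches:

```python
import torch

# Toy sizes: 4 text positions (incl. the appended audio-start token),
# 3 audio patches, plus the single trailing end position.
text_length, audio_length, is_prompt = 4, 3, False

text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).int()
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).int()
loss_mask = torch.cat([
    torch.zeros(text_length),
    torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
    torch.zeros(1),
]).int()
labels = torch.zeros(text_length + audio_length + 1).int()
labels[-2] = 1  # stop label sits on the last audio patch position

print(text_mask.tolist())   # [1, 1, 1, 1, 0, 0, 0, 1]
print(audio_mask.tolist())  # [0, 0, 0, 0, 1, 1, 1, 0]
print(loss_mask.tolist())   # [0, 0, 0, 0, 1, 1, 1, 0]
print(labels.tolist())      # [0, 0, 0, 0, 0, 0, 1, 0]
```

Prompt samples (`is_prompt=True`) keep the same layout but zero out the loss over the audio span, so reference audio conditions the model without contributing to the training loss.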
21
src/voxcpm/training/state.py
Normal file
@@ -0,0 +1,21 @@
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class TrainingState:
    """
    Container that mirrors the object returned in the minicpm-audio training
    loop. It holds persistent references to the model, optimizer, scheduler,
    dataloaders and tracker.
    """

    generator: object
    optimizer: object
    scheduler: object
    train_loader: object
    val_loader: object
    tracker: object
    batch_processor: object
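Because `TrainingState` is a plain dataclass of `object`-typed fields, it can aggregate whatever concrete objects the training script builds. A minimal sketch with placeholder objects (the import path is an assumption based on the file location above):

```python
from types import SimpleNamespace

from voxcpm.training.state import TrainingState  # assumed import path for src/voxcpm/training/state.py

placeholder = SimpleNamespace()
state = TrainingState(
    generator=placeholder,       # the model being trained
    optimizer=placeholder,
    scheduler=placeholder,
    train_loader=placeholder,
    val_loader=placeholder,
    tracker=placeholder,
    batch_processor=placeholder,
)
print(state.tracker is placeholder)  # True
```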
78
src/voxcpm/training/tracker.py
Normal file
@@ -0,0 +1,78 @@
from __future__ import annotations

import contextlib
import time
from pathlib import Path
from typing import Dict, Optional


class TrainingTracker:
    """
    Lightweight tracker inspired by the minicpm-audio training workflow.

    It keeps track of the current global step, prints rank-aware messages,
    optionally writes to TensorBoard via a provided writer, and stores progress
    in a logfile for later inspection.
    """

    def __init__(
        self,
        *,
        writer=None,
        log_file: Optional[str] = None,
        rank: int = 0,
    ):
        self.writer = writer
        self.log_file = Path(log_file) if log_file else None
        if self.log_file:
            self.log_file.parent.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        self.step = 0
        # Record the time of the last log to calculate the interval
        self._last_log_time: float | None = None

    # ------------------------------------------------------------------ #
    # Logging helpers
    # ------------------------------------------------------------------ #
    def print(self, message: str):
        if self.rank == 0:
            print(message, flush=True)
            if self.log_file:
                with self.log_file.open("a", encoding="utf-8") as f:
                    f.write(message + "\n")

    def log_metrics(self, metrics: Dict[str, float], split: str):
        if self.rank == 0:
            now = time.time()
            dt_str = ""
            if self._last_log_time is not None:
                dt = now - self._last_log_time
                dt_str = f", log interval: {dt:.2f}s"
            self._last_log_time = now

            formatted = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
            self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")
            if self.writer is not None:
                for key, value in metrics.items():
                    if isinstance(value, (int, float)):
                        self.writer.add_scalar(f"{split}/{key}", value, self.step)

    def done(self, split: str, message: str):
        self.print(f"[{split}] {message}")

    # ------------------------------------------------------------------ #
    # State dict
    # ------------------------------------------------------------------ #
    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state):
        self.step = int(state.get("step", 0))

    # ------------------------------------------------------------------ #
    # Context manager compatibility (for parity with minicpm-audio code)
    # ------------------------------------------------------------------ #
    @contextlib.contextmanager
    def live(self):
        yield
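A short usage sketch of the tracker, without a TensorBoard writer attached (the import path is an assumption based on the file location above):

```python
from voxcpm.training.tracker import TrainingTracker  # assumed import path

tracker = TrainingTracker(writer=None, log_file="logs/train.log", rank=0)
tracker.step = 100
tracker.log_metrics({"loss": 0.4231, "lr": 1e-4}, split="train")
# prints: [train] step 100: loss: 0.423100, lr: 0.000100

# The step counter survives checkpointing via the state-dict round trip.
snapshot = tracker.state_dict()   # {"step": 100}
resumed = TrainingTracker(rank=0)
resumed.load_state_dict(snapshot)
with resumed.live():              # no-op context manager, kept for API parity
    resumed.print(f"resumed at step {resumed.step}")
```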
@@ -3,41 +3,8 @@ import re
import regex
import inflect
from functools import partial
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
from wetext import Normalizer

def normal_cut_sentence(text):
    # First handle commas inside parentheses by replacing them with a placeholder marker
    text = re.sub(r'([((][^))]*)([,,])([^))]*[))])', r'\1&&&\3', text)
    text = re.sub('([。!,?\?])([^’”])', r'\1\n\2', text)  # ordinary sentence-ending punctuation not followed by a closing quote
    text = re.sub('(\.{6})([^’”])', r'\1\n\2', text)  # English ellipsis not followed by a closing quote
    text = re.sub('(\…{2})([^’”])', r'\1\n\2', text)  # Chinese ellipsis not followed by a closing quote
    text = re.sub('([. ,。!;?\?\.{6}\…{2}][’”])([^’”])', r'\1\n\2', text)  # sentence-ending punctuation plus quote, not followed by another quote
    # Handle splitting of English sentences
    text = re.sub(r'([.,!?])([^’”\'"])', r'\1\n\2', text)  # period/exclamation/question mark not followed by a quote
    text = re.sub(r'([.!?][’”\'"])([^’”\'"])', r'\1\n\2', text)  # the part after a period/exclamation/question mark plus quote
    text = re.sub(r'([((][^))]*)(&&&)([^))]*[))])', r'\1,\3', text)
    text = [t for t in text.split("\n") if t]
    return text


def cut_sentence_with_fix_length(text : str, length : int):
    sentences = normal_cut_sentence(text)
    cur_length = 0
    res = ""
    for sentence in sentences:
        if not sentence:
            continue
        if cur_length > length or cur_length + len(sentence) > length:
            yield res
            res = ""
            cur_length = 0
        res += sentence
        cur_length += len(sentence)
    if res:
        yield res


chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')

# whether contain chinese character
@@ -195,8 +162,8 @@ def clean_text(text):
class TextNormalizer:
    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer
        self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, remove_interjections=False, overwrite_cache=True)
        self.en_tn_model = EnNormalizer()
        self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
        self.en_tn_model = Normalizer(lang="en", operator="tn")
        self.inflect_parser = inflect.engine()

    def normalize(self, text, split=False):
@@ -207,38 +174,12 @@
            text = text.replace("=", "等于")  # fix: "550 + 320 等于 870 千卡。" was being wrongly normalized to "五百五十加三百二十等于八七十千卡."
            if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text):  # avoid English hyphens being wrongly normalized to "minus"
                text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text)  # fix: "x-2" was being normalized to "x negative 2"
                text = self.zh_tn_model.normalize(text)
            text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text)  # fix: "x-2" was being normalized to "x negative 2"
            text = self.zh_tn_model.normalize(text)
            text = replace_blank(text)
            text = replace_corner_mark(text)
            text = remove_bracket(text)
            text = re.sub(r'[,,]+$', '。', text)
        else:
            text = self.en_tn_model.normalize(text)
            text = spell_out_number(text, self.inflect_parser)
        if split is False:
            return text


if __name__ == "__main__":
    text_normalizer = TextNormalizer()
    text = r"""今天我们学习一元二次方程。一元二次方程的标准形式是:
ax2+bx+c=0ax^2 + bx + c = 0ax2+bx+c=0
其中,aaa、bbb 和 ccc 是常数,xxx 是变量。这个方程的解可以通过求根公式来找到。
一元二次方程的解法有几种:
- 因式分解法:通过将方程因式分解来求解。我们首先尝试将方程表达成两个括号的形式,解决方程的解。比如,方程x2−5x+6=0x^2 - 5x + 6 = 0x2−5x+6=0可以因式分解为(x−2)(x−3)=0(x - 2)(x - 3) = 0(x−2)(x−3)=0,因此根为2和3。
- 配方法:通过配方将方程转化为完全平方的形式,从而解出。我们通过加上或减去适当的常数来完成这一过程,使得方程可以直接写成一个完全平方的形式。
- 求根公式:我们可以使用求根公式直接求出方程的解。这个公式适用于所有的一元二次方程,即使我们无法通过因式分解或配方法来解决时,也能使用该公式。
公式:x=−b±b2−4ac2ax = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}x=2a−b±b2−4ac这个公式可以帮助我们求解任何一元二次方程的根。
对于一元二次方程,我们需要了解判别式。判别式的作用是帮助我们判断方程的解的个数和性质。判别式 Δ\DeltaΔ 由下式给出:Δ=b2−4ac\Delta = b^2 - 4acΔ=b2−4ac 根据判别式的值,我们可以知道:
- 如果 Δ>0\Delta > 0Δ>0,方程有两个不相等的实数解。这是因为判别式大于0时,根号内的值是正数,所以我们可以得到两个不同的解。
- 如果 Δ=0\Delta = 0Δ=0,方程有一个实数解。这是因为根号内的值为零,导致两个解相等,也就是说方程有一个解。
- 如果 Δ<0\Delta < 0Δ<0,方程没有实数解。这意味着根号内的值是负数,无法进行实数运算,因此方程没有实数解,可能有复数解。"""
    texts = ["这是一个公式 (a+b)³=a³+3a²b+3ab²+b³ S=(a×b)÷2", "这样的发展为AI仅仅作为“工具”这一观点提出了新的挑战,", "550 + 320 = 870千卡。", "解一元二次方程:3x^2+x-2=0", "你好啊"]
    texts = [text]
    for text in texts:
        text = text_normalizer.normalize(text)
        print(text)
        for t in cut_sentence_with_fix_length(text, 15):
            print(t)
            return text
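The hunks above swap the WeTextProcessing normalizers (`ZhNormalizer`/`EnNormalizer` from the `tn` package) for `wetext.Normalizer`. A minimal sketch of the replacement normalizers in isolation, using the same constructor arguments as the diff (illustrative inputs; exact output depends on the installed `wetext` version):

```python
from wetext import Normalizer

zh_tn = Normalizer(lang="zh", operator="tn", remove_erhua=True)
en_tn = Normalizer(lang="en", operator="tn")

# Chinese example taken from the test list above; English example is illustrative.
print(zh_tn.normalize("解一元二次方程:3x^2+x-2=0"))
print(en_tn.normalize("The show starts at 3:30 PM on Dec 5, 2025."))
```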