Mirror of https://github.com/OpenBMB/VoxCPM, synced 2025-12-12 19:58:12 +00:00

Compare commits: 42 commits
SHA1: d1bb6aaf41, 2eb4d39719, fbf8984d4e, 41752dc0fa, b0714adcaa, 89f4d917a0, 5c5da0dbe6, 961569e76d, 5f56d5ff5d, 169c17ddfd, 996c69a1a8, dc6b6d1d1c, cef6aefb3d, 1a46c5d1ad, 5257ec3dc5, bdd516b579, 11568f0776, e5bcb735f0, f26a1ea2f7, 1fa9e2ca02, 10f48ba330, 639b2272ab, 7e8f754ba1, 032c7fe403, 5390a47862, e7012f1a94, 82332cfc99, 605ac2d8e4, 0fa8d894d1, 776c0d19fb, ed6e6b4dac, e3108d4a12, 59fe3f30a1, 6f2fb45756, 91128d823d, 436e8cd6e5, 11574ae93d, 706403187e, 38a76704ee, dfd487f5af, 081845b35b, 6684b547cc
.gitignore (vendored, new file, +3)
@@ -0,0 +1,3 @@
+launch.json
+__pycache__
+voxcpm.egg-info
README.md (134 lines changed)
@@ -1,25 +1,33 @@
 ## 🎙️ VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning
 
-[](https://github.com/OpenBMB/VoxCPM/) [](hhttps://huggingface.co/openbmb/VoxCPM-0.5B) [](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [](https://thuhcsi.github.io/VoxCPM/)
+[](https://github.com/OpenBMB/VoxCPM/) [](https://arxiv.org/abs/2509.24650) [](https://huggingface.co/openbmb/VoxCPM-0.5B) [](https://modelscope.cn/models/OpenBMB/VoxCPM-0.5B) [](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [](https://openbmb.github.io/VoxCPM-demopage)
 
 <div align="center">
 <img src="assets/voxcpm_logo.png" alt="VoxCPM Logo" width="40%">
 </div>
 
+<div align="center">
+
+👋 Contact us on [WeChat](assets/wechat.png)
+
+</div>
+
 ## News
-* [2025.09.16] 🔥 🔥 🔥 We Open Source the VoxCPM-0.5B weights!
+* [2025.09.30] 🔥 🔥 🔥 We Release the VoxCPM [Technical Report](https://arxiv.org/abs/2509.24650)!
+* [2025.09.16] 🔥 🔥 🔥 We Open Source the VoxCPM-0.5B [weights](https://huggingface.co/openbmb/VoxCPM-0.5B)!
 * [2025.09.16] 🎉 🎉 🎉 We Provide the [Gradio Playground](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) for VoxCPM-0.5B, try it now!
 
 ## Overview
 
 VoxCPM is a novel tokenizer-free Text-to-Speech (TTS) system that redefines realism in speech synthesis. By modeling speech in a continuous space, it overcomes the limitations of discrete tokenization and enables two flagship capabilities: context-aware speech generation and true-to-life zero-shot voice cloning.
 
-Unlike mainstream approaches that convert speech to discrete tokens, VoxCPM uses an end-to-end diffusion autoregressive architecture that directly generates continuous speech representations from text. Built on [MiniCPM-4](https://huggingface.co/openbmb/MiniCPM4-0.5B), it achieves implicit semantic-acoustic decoupling through hierarchical language modeling and FSQ constraints, greatly enhancing both expressiveness and generation stability.
+Unlike mainstream approaches that convert speech to discrete tokens, VoxCPM uses an end-to-end diffusion autoregressive architecture that directly generates continuous speech representations from text. Built on the [MiniCPM-4](https://huggingface.co/openbmb/MiniCPM4-0.5B) backbone, it achieves implicit semantic-acoustic decoupling through hierarchical language modeling and FSQ constraints, greatly enhancing both expressiveness and generation stability.
 
 <div align="center">
-<img src="assets/voxcpm_model.png" alt="VoxCPM Model Architecture" width="500">
+<img src="assets/voxcpm_model.png" alt="VoxCPM Model Architecture" width="90%">
 </div>
@@ -30,6 +38,8 @@ Unlike mainstream approaches that convert speech to discrete tokens, VoxCPM uses
 
 
 ## Quick Start
 
 ### 🔧 Install from PyPI
@@ -41,7 +51,7 @@ By default, when you first run the script, the model will be downloaded automati
 - Download VoxCPM-0.5B
 ```
 from huggingface_hub import snapshot_download
-snapshot_download("openbmb/VoxCPM-0.5B",local_files_only=local_files_only)
+snapshot_download("openbmb/VoxCPM-0.5B")
 ```
 - Download ZipEnhancer and SenseVoice-Small. We use ZipEnhancer to enhance speech prompts and SenseVoice-Small for speech prompt ASR in the web demo.
 ```
@@ -53,25 +63,39 @@ By default, when you first run the script, the model will be downloaded automati
 ### 2. Basic Usage
 ```python
 import soundfile as sf
+import numpy as np
 from voxcpm import VoxCPM
 
 model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
 
+# Non-streaming
 wav = model.generate(
     text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
     prompt_wav_path=None,      # optional: path to a prompt speech for voice cloning
     prompt_text=None,          # optional: reference text
-    cfg_value=2.0,
-    inference_timesteps=10,
-    normalize=True,
-    denoise=True,
-    retry_badcase=True,        # optional: enable retrying mode
-    retry_badcase_max_times=3,
-    retry_badcase_ratio_threshold=6.0,
+    cfg_value=2.0,             # LM guidance scale on LocDiT; higher values follow the prompt more closely but may hurt quality
+    inference_timesteps=10,    # LocDiT inference timesteps; higher for better quality, lower for faster synthesis
+    normalize=True,            # enable the external text-normalization tool
+    denoise=True,              # enable the external denoising tool
+    retry_badcase=True,        # enable retrying for bad cases (e.g., unstoppable generation)
+    retry_badcase_max_times=3, # maximum number of retries
+    retry_badcase_ratio_threshold=6.0, # maximum audio-to-text length ratio for bad-case detection (simple but effective); raise it for slow-paced speech
 )
 
 sf.write("output.wav", wav, 16000)
 print("saved: output.wav")
+
+# Streaming
+chunks = []
+for chunk in model.generate_streaming(
+    text="Streaming text to speech is easy with VoxCPM!",
+    # supports the same args as above
+):
+    chunks.append(chunk)
+wav = np.concatenate(chunks)
+
+sf.write("output_streaming.wav", wav, 16000)
+print("saved: output_streaming.wav")
 ```
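Because `generate_streaming` yields chunks as they are produced, the output can also be written incrementally instead of buffered in memory; a minimal sketch assuming the same 16 kHz mono output as above (the file name is illustrative):

```python
import soundfile as sf
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")

# Open the file once and append each chunk as it arrives, so a player
# tailing the file can start before synthesis finishes.
with sf.SoundFile("output_incremental.wav", mode="w", samplerate=16000, channels=1) as f:
    for chunk in model.generate_streaming(text="Incremental writing demo."):
        f.write(chunk)
```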
 ### 3. CLI Usage
@@ -80,15 +104,22 @@ After installation, the entry point is `voxcpm` (or use `python -m voxcpm.cli`).
 
 ```bash
 # 1) Direct synthesis (single text)
-voxcpm --text "Hello VoxCPM" --output out.wav
+voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." --output out.wav
 
 # 2) Voice cloning (reference audio + transcript)
-voxcpm --text "Hello" \
+voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
     --prompt-audio path/to/voice.wav \
     --prompt-text "reference transcript" \
     --output out.wav \
     --denoise
 
+# (Optional) Voice cloning (reference audio + transcript file)
+voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
+    --prompt-audio path/to/voice.wav \
+    --prompt-file "/path/to/text-file" \
+    --output out.wav \
+    --denoise
+
 # 3) Batch processing (one text per line)
 voxcpm --input examples/input.txt --output-dir outs
 # (optional) Batch + cloning
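The hunk cuts off at the batch-plus-cloning comment, so the full command is not shown. Combining the flags documented above and in the CLI parser shown further below, a plausible invocation (paths illustrative) would be:

```bash
# hypothetical combination of the documented flags
voxcpm --input examples/input.txt --output-dir outs \
    --prompt-audio path/to/voice.wav \
    --prompt-text "reference transcript" \
    --denoise
```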
@@ -165,6 +196,19 @@ Happy creating! 🎉 Start with the default settings and tweak from there to sui
 ---
 
+
+## 🌟 Community Projects
+
+We're excited to see the VoxCPM community growing! Here are some amazing projects and features built by our community:
+
+- **[ComfyUI-VoxCPM](https://github.com/wildminder/ComfyUI-VoxCPM)**
+- **[ComfyUI-VoxCPMTTS](https://github.com/1038lab/ComfyUI-VoxCPMTTS)**
+- **[WebUI-VoxCPM](https://github.com/rsxdalv/tts_webui_extension.vox_cpm)**
+- **[PR: Streaming API Support (by AbrahamSanders)](https://github.com/OpenBMB/VoxCPM/pull/26)**
+
+*Have you built something cool with VoxCPM? We'd love to feature it here! Please open an issue or pull request to add your project.*
+
 ## 📊 Performance Highlights
 
@@ -175,41 +219,41 @@ VoxCPM achieves competitive results on public zero-shot TTS benchmarks:
 | Model | Parameters | Open-Source | test-EN | | test-ZH | | test-Hard | |
 |------|------|------|:------------:|:--:|:------------:|:--:|:-------------:|:--:|
 | | | | WER/%⬇ | SIM/%⬆ | CER/%⬇ | SIM/%⬆ | CER/%⬇ | SIM/%⬆ |
+| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
+| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
+| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
+| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
+| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
+| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
 | CosyVoice | 0.3B | ✅ | 4.29 | 60.9 | 3.63 | 72.3 | 11.75 | 70.9 |
-| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | 6.83 | 72.4 |
+| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | **6.83** | 72.4 |
 | F5-TTS | 0.3B | ✅ | 2.00 | 67.0 | 1.53 | 76.0 | 8.67 | 71.3 |
 | SparkTTS | 0.5B | ✅ | 3.14 | 57.3 | 1.54 | 66.0 | - | - |
 | FireRedTTS | 0.5B | ✅ | 3.82 | 46.0 | 1.51 | 63.5 | 17.45 | 62.1 |
 | FireRedTTS-2 | 1.5B | ✅ | 1.95 | 66.5 | 1.14 | 73.6 | - | - |
-| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | 74.7 |
+| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | **74.7** |
 | OpenAudio-s1-mini | 0.5B | ✅ | 1.94 | 55.0 | 1.18 | 68.5 | - | - |
 | IndexTTS2 | 1.5B | ✅ | 2.23 | 70.6 | 1.03 | 76.5 | - | - |
 | VibeVoice | 1.5B | ✅ | 3.04 | 68.9 | 1.16 | 74.4 | - | - |
 | HiggsAudio-v2 | 3B | ✅ | 2.44 | 67.7 | 1.50 | 74.0 | - | - |
-| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
-| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
-| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
-| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
-| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
-| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
-| **VoxCPM** | **0.5B** | **✅** | **1.85** | **72.9** | **0.93** | **77.2** | 8.87 | 73.0 |
+| **VoxCPM** | 0.5B | ✅ | **1.85** | **72.9** | **0.93** | **77.2** | 8.87 | 73.0 |
 ### CV3-eval Benchmark
 
-| Model | zh | en | hard-zh | | | hard-en | | | |
-|-------|:--:|:--:|:-------:|:--:|:--:|:-------:|:--:|:--:|:--:|
-| | CER/%⬇ | WER/%⬇ | CER/%⬇ | SIM/%⬆ | DNSMOS⬆ | WER/%⬇ | SIM/%⬆ | DNSMOS⬆ | |
-| F5-TTS | 5.47 | 8.90 | - | - | - | - | - | - | |
-| SparkTTS | 5.15 | 11.0 | - | - | - | - | - | - | |
-| GPT-SoVits | 7.34 | 12.5 | - | - | - | - | - | - | |
-| CosyVoice2 | 4.08 | 6.32 | 12.58 | 72.6 | 3.81 | 11.96 | 66.7 | 3.95 | |
-| OpenAudio-s1-mini | 4.00 | 5.54 | 18.1 | 58.2 | 3.77 | 12.4 | 55.7 | 3.89 | |
-| IndexTTS2 | 3.58 | 4.45 | 12.8 | 74.6 | 3.65 | fail | fail | fail | |
-| HiggsAudio-v2 | 9.54 | 7.89 | 41.0 | 60.2 | 3.39 | 10.3 | 61.8 | 3.68 | |
-| CosyVoice3-0.5B | 3.89 | 5.24 | 14.15 | 78.6 | 3.75 | 9.04 | 75.9 | 3.92 | |
-| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 78.5 | 3.79 | 10.55 | 76.1 | 3.95 | |
-| **VoxCPM** | **3.40** | **4.04** | 12.9 | 66.1 | 3.59 | **7.89** | 64.3 | 3.74 | |
+| Model | zh | en | hard-zh | | | hard-en | | |
+|-------|:--:|:--:|:-------:|:--:|:--:|:-------:|:--:|:--:|
+| | CER/%⬇ | WER/%⬇ | CER/%⬇ | SIM/%⬆ | DNSMOS⬆ | WER/%⬇ | SIM/%⬆ | DNSMOS⬆ |
+| F5-TTS | 5.47 | 8.90 | - | - | - | - | - | - |
+| SparkTTS | 5.15 | 11.0 | - | - | - | - | - | - |
+| GPT-SoVits | 7.34 | 12.5 | - | - | - | - | - | - |
+| CosyVoice2 | 4.08 | 6.32 | 12.58 | 72.6 | 3.81 | 11.96 | 66.7 | 3.95 |
+| OpenAudio-s1-mini | 4.00 | 5.54 | 18.1 | 58.2 | 3.77 | 12.4 | 55.7 | 3.89 |
+| IndexTTS2 | 3.58 | 4.45 | 12.8 | 74.6 | 3.65 | - | - | - |
+| HiggsAudio-v2 | 9.54 | 7.89 | 41.0 | 60.2 | 3.39 | 10.3 | 61.8 | 3.68 |
+| CosyVoice3-0.5B | 3.89 | 5.24 | 14.15 | 78.6 | 3.75 | 9.04 | 75.9 | 3.92 |
+| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 78.5 | 3.79 | 10.55 | 76.1 | 3.95 |
+| **VoxCPM** | **3.40** | **4.04** | 12.9 | 66.1 | 3.59 | **7.89** | 64.3 | 3.74 |
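For reference on the headline metrics (standard definitions, not specific to this repository): WER and CER are word- and character-level edit-distance error rates against the reference transcript (lower is better),

$$\mathrm{WER} = \frac{S + D + I}{N}$$

where $S$, $D$, $I$ count substitutions, deletions, and insertions and $N$ is the reference length; SIM is speaker-embedding cosine similarity between the generated audio and the voice prompt, and DNSMOS is a neural estimate of perceptual quality (both higher is better).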
@@ -231,6 +275,13 @@ VoxCPM achieves competitive results on public zero-shot TTS benchmarks:
 
 
+## 📝 TO-DO List
+Please stay tuned for updates!
+- [x] Release the VoxCPM technical report.
+- [ ] Support higher sampling rate (next version).
+
+
 ## 📄 License
 The VoxCPM model weights and code are open-sourced under the [Apache-2.0](LICENSE) license.
 
@@ -251,6 +302,8 @@ This project is developed by the following institutions:
 - <img src="assets/thuhcsi_logo.png" width="28px"> [THUHCSI](https://github.com/thuhcsi)
 
+## ⭐ Star History
+[](https://star-history.com/#OpenBMB/VoxCPM&Date)
 
 ## 📚 Citation
@@ -258,11 +311,10 @@ This project is developed by the following institutions:
 If you find our model helpful, please consider citing our projects 📝 and starring us ⭐️!
 
 ```bib
-@misc{voxcpm2025,
-  author = {{Yixuan Zhou, Guoyang Zeng, Xin Liu, Xiang Li, Renjie Yu, Ziyang Wang, Runchuan Ye, Weiyue Sun, Jiancheng Gui, Kehan Li, Zhiyong Wu, Zhiyuan Liu}},
-  title = {{VoxCPM}},
+@article{voxcpm2025,
+  title = {VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning},
+  author = {Zhou, Yixuan and Zeng, Guoyang and Liu, Xin and Li, Xiang and Yu, Renjie and Wang, Ziyang and Ye, Runchuan and Sun, Weiyue and Gui, Jiancheng and Li, Kehan and Wu, Zhiyong and Liu, Zhiyuan},
+  journal = {arXiv preprint arXiv:2509.24650},
   year = {2025},
-  publish = {\url{https://github.com/OpenBMB/VoxCPM}},
-  note = {GitHub repository}
 }
 ```
app.py (11 lines changed)
@@ -170,7 +170,7 @@ def create_demo_interface(demo: VoxCPMDemo):
 
     # Pro Tips
     with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
-        gr.Markdown(f"""
+        gr.Markdown("""
        ### Prompt Speech Enhancement|参考语音降噪
        - **Enable** to remove background noise for a clean, studio-like voice, with an external ZipEnhancer component.
          **启用**:通过 ZipEnhancer 组件消除背景噪音,获得更好的音质。
@@ -194,10 +194,6 @@ def create_demo_interface(demo: VoxCPMDemo):
          **调低**:合成速度更快。
        - **Higher** for better synthesis quality.
          **调高**:合成质量更佳。
-
-        ### Long Text (e.g., >5 min speech)|长文本 (如 >5分钟的合成语音)
-        While VoxCPM can handle long texts directly, we recommend using empty lines to break very long content into paragraphs; the model will then synthesize each paragraph individually.
-        虽然 VoxCPM 支持直接生成长文本,但如果目标文本过长,我们建议使用换行符将内容分段;模型将对每个段落分别合成。
        """)
 
     # Main controls
@@ -206,7 +202,7 @@ def create_demo_interface(demo: VoxCPMDemo):
        prompt_wav = gr.Audio(
            sources=["upload", 'microphone'],
            type="filepath",
-           label="Prompt Speech",
+           label="Prompt Speech (Optional, or let VoxCPM improvise)",
            value="./examples/example.wav",
        )
        DoDenoisePromptAudio = gr.Checkbox(
@@ -244,14 +240,13 @@ def create_demo_interface(demo: VoxCPMDemo):
        text = gr.Textbox(
            value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
            label="Target Text",
-           info="Default processing splits text on \\n into paragraphs; each is synthesized as a chunk and then concatenated into the final audio."
        )
        with gr.Row():
            DoNormalizeText = gr.Checkbox(
                value=False,
                label="Text Normalization",
                elem_id="chk_normalize",
-               info="We use WeTextPorcessing library to normalize the input text."
+               info="We use the wetext library to normalize the input text."
            )
        audio_output = gr.Audio(label="Output Audio")
 

assets/wechat.png (BIN, new file; binary not shown, 9.5 KiB)
pyproject.toml
@@ -20,29 +20,26 @@ classifiers = [
     "Intended Audience :: Developers",
     "Operating System :: OS Independent",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
 ]
-requires-python = ">=3.8"
+requires-python = ">=3.10"
 dependencies = [
-    "torch==2.5.1",
-    "torchaudio==2.5.1",
-    "transformers==4.50.1",
+    "torch>=2.5.0",
+    "torchaudio>=2.5.0",
+    "transformers>=4.36.2",
     "einops",
     "gradio",
     "inflect",
-    "WeTextProcessing",
     "addict",
-    "modelscope==1.22.0",
-    "simplejson",
-    "datasets==2.18.0",
-    "sortedcontainers",
-    "librosa",
+    "wetext",
+    "modelscope>=1.22.0",
+    "datasets>=3,<4",
     "huggingface-hub",
     "pydantic",
     "tqdm",
+    "simplejson",
+    "sortedcontainers",
     "soundfile",
     "funasr",
     "spaces"
@@ -69,7 +66,7 @@ Documentation = "https://github.com/OpenBMB/VoxCPM#readme"
 
 [tool.setuptools.packages.find]
 where = ["src"]
-include = ["voxcpm"]
+include = ["voxcpm*"]
 
 [tool.setuptools.package-dir]
 "" = "src"
@@ -79,7 +76,7 @@ version_scheme = "post-release"
 
 [tool.black]
 line-length = 120
-target-version = ['py38']
+target-version = ['py310']
 include = '\.pyi?$'
 extend-exclude = '''
 /(

requirements.txt (deleted)
@@ -1,16 +0,0 @@
-torch==2.5.1
-torchaudio==2.5.1
-transformers==4.50.1
-einops
-gradio
-inflect
-WeTextProcessing
-addicts
-modelscope==1.22.0
-simplejson
-datasets==2.18.0
-addicts
-sortedcontainers
-librosa
-huggingface-hub
-spaces
src/voxcpm/cli.py
@@ -240,6 +240,7 @@ Examples:
     # Prompt audio (for voice cloning)
     parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
     parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
+    parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio")
     parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
 
     # Generation parameters
@@ -279,6 +280,12 @@ def main():
 
     # If prompt audio+text provided → voice cloning
     if args.prompt_audio or args.prompt_text:
+        if not args.prompt_text and args.prompt_file:
+            assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."
+
+            with open(args.prompt_file, 'r', encoding='utf-8') as f:
+                args.prompt_text = f.read()
+
         if not args.prompt_audio or not args.prompt_text:
             print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
             sys.exit(1)
@@ -1,19 +1,17 @@
-import torch
-import torchaudio
 import os
+import re
 import tempfile
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
+import numpy as np
+from typing import Generator
 from huggingface_hub import snapshot_download
 from .model.voxcpm import VoxCPMModel
-from .utils.text_normalize import TextNormalizer
 
 
 class VoxCPM:
     def __init__(self,
            voxcpm_model_path : str,
            zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
            enable_denoiser : bool = True,
+           optimize: bool = True,
        ):
        """Initialize VoxCPM TTS pipeline.
 
@@ -24,19 +22,20 @@ class VoxCPM:
            zipenhancer_model_path: ModelScope acoustic noise suppression model
                id or local path. If None, denoiser will not be initialized.
            enable_denoiser: Whether to initialize the denoiser pipeline.
+           optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
        """
        print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
-       self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
-       self.text_normalizer = TextNormalizer()
+       self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
+       self.text_normalizer = None
        if enable_denoiser and zipenhancer_model_path is not None:
-           self.denoiser = pipeline(
-               Tasks.acoustic_noise_suppression,
-               model=zipenhancer_model_path)
+           from .zipenhancer import ZipEnhancer
+           self.denoiser = ZipEnhancer(zipenhancer_model_path)
        else:
            self.denoiser = None
        print("Warm up VoxCPMModel...")
        self.tts_model.generate(
-           target_text="Hello, this is the first test sentence."
+           target_text="Hello, this is the first test sentence.",
+           max_len=10,
        )
 
    @classmethod
@@ -46,17 +45,20 @@ class VoxCPM:
            zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
            cache_dir: str = None,
            local_files_only: bool = False,
+           **kwargs,
        ):
        """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
 
        Args:
-           hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo").
+           hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path.
            load_denoiser: Whether to initialize the denoiser pipeline.
            zipenhancer_model_id: Denoiser model id or path for ModelScope
                acoustic noise suppression.
            cache_dir: Custom cache directory for the snapshot.
            local_files_only: If True, only use local files and do not attempt
                to download.
+       Kwargs:
+           Additional keyword arguments passed to the ``VoxCPM`` constructor.
 
        Returns:
            VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
@@ -67,28 +69,34 @@ class VoxCPM:
            ``hf_model_id`` is provided.
        """
        repo_id = hf_model_id
-       if not repo_id or repo_id.strip() == "":
-           raise ValueError("You must provide a valid hf_model_id")
+       if not repo_id:
+           raise ValueError("You must provide hf_model_id")
 
-       local_path = snapshot_download(
-           repo_id=repo_id,
-           cache_dir=cache_dir,
-           local_files_only=local_files_only,
-       )
+       # Load from local path if provided
+       if os.path.isdir(repo_id):
+           local_path = repo_id
+       else:
+           # Otherwise download (or reuse a cached) Hub snapshot
+           local_path = snapshot_download(
+               repo_id=repo_id,
+               cache_dir=cache_dir,
+               local_files_only=local_files_only,
+           )
 
        return cls(
            voxcpm_model_path=local_path,
            zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
            enable_denoiser=load_denoiser,
+           **kwargs,
        )
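With this change `from_pretrained` accepts either a Hub repository id or an existing local directory, and forwards extra keyword arguments to the constructor; a usage sketch (the local path is illustrative):

```python
from voxcpm import VoxCPM

# Hub id: a snapshot is downloaded or served from the local cache
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")

# Local directory: used directly, no network access required
model = VoxCPM.from_pretrained("/models/VoxCPM-0.5B", load_denoiser=False)

# Extra kwargs reach the constructor, e.g. disabling torch.compile
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B", optimize=False)
```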
-    def _normalize_loudness(self, wav_path: str):
-        audio, sr = torchaudio.load(wav_path)
-        loudness = torchaudio.functional.loudness(audio, sr)
-        normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
-        torchaudio.save(wav_path, normalized_audio, sr)
+    def generate(self, *args, **kwargs) -> np.ndarray:
+        return next(self._generate(*args, streaming=False, **kwargs))
 
-    def generate(self,
+    def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
+        return self._generate(*args, streaming=True, **kwargs)
+
+    def _generate(self,
            text : str,
            prompt_wav_path : str = None,
            prompt_text : str = None,
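The refactor folds the streaming and non-streaming paths into a single generator: `_generate` always yields, and the non-streaming `generate` simply takes the one item it produces with `next(...)`. A standalone sketch of the pattern (toy function, not VoxCPM code):

```python
from typing import Generator

def _work(n: int, streaming: bool = False) -> Generator[int, None, None]:
    if streaming:
        # yield partial results as they are produced
        for i in range(n):
            yield i
    else:
        # yield exactly one final result
        yield sum(range(n))

def work(n: int) -> int:
    return next(_work(n, streaming=False))

def work_streaming(n: int) -> Generator[int, None, None]:
    return _work(n, streaming=True)

print(work(5))                  # 10
print(list(work_streaming(5)))  # [0, 1, 2, 3, 4]
```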
@@ -100,7 +108,8 @@ class VoxCPM:
            retry_badcase : bool = True,
            retry_badcase_max_times : int = 3,
            retry_badcase_ratio_threshold : float = 6.0,
-       ):
+           streaming: bool = False,
+       ) -> Generator[np.ndarray, None, None]:
        """Synthesize speech for the given text and return a single waveform.
 
        This method optionally builds and reuses a prompt cache. If an external
@@ -122,12 +131,24 @@ class VoxCPM:
            retry_badcase: Whether to retry badcase.
            retry_badcase_max_times: Maximum number of times to retry badcase.
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
+           streaming: Whether to return a generator of audio chunks.
        Returns:
-           numpy.ndarray: 1D waveform array (float32) on CPU.
+           Generator of numpy.ndarray: 1D waveform arrays (float32) on CPU.
+               Yields an audio chunk for each generation step if ``streaming=True``,
+               otherwise yields a single array containing the final audio.
        """
-       texts = text.split("\n")
-       texts = [t.strip() for t in texts if t.strip()]
-       final_wav = []
+       if not isinstance(text, str) or not text.strip():
+           raise ValueError("target text must be a non-empty string")
+
+       if prompt_wav_path is not None:
+           if not os.path.exists(prompt_wav_path):
+               raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
+
+       if (prompt_wav_path is None) != (prompt_text is None):
+           raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
+
+       text = text.replace("\n", " ")
+       text = re.sub(r'\s+', ' ', text)
        temp_prompt_wav_path = None
 
        try:
@@ -135,9 +156,7 @@ class VoxCPM:
            if denoise and self.denoiser is not None:
                with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
                    temp_prompt_wav_path = tmp_file.name
-               self.denoiser(prompt_wav_path, output_path=temp_prompt_wav_path)
-               self._normalize_loudness(temp_prompt_wav_path)
+               self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
                prompt_wav_path = temp_prompt_wav_path
                fixed_prompt_cache = self.tts_model.build_prompt_cache(
                    prompt_wav_path=prompt_wav_path,
@@ -146,32 +165,27 @@ class VoxCPM:
            else:
                fixed_prompt_cache = None  # will be built from the first inference
 
-           for sub_text in texts:
-               if sub_text.strip() == "":
-                   continue
-               print("sub_text:", sub_text)
-               if normalize:
-                   sub_text = self.text_normalizer.normalize(sub_text)
-               wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
-                   target_text=sub_text,
-                   prompt_cache=fixed_prompt_cache,
-                   min_len=2,
-                   max_len=max_length,
-                   inference_timesteps=inference_timesteps,
-                   cfg_value=cfg_value,
-                   retry_badcase=retry_badcase,
-                   retry_badcase_max_times=retry_badcase_max_times,
-                   retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
-               )
-               if fixed_prompt_cache is None:
-                   fixed_prompt_cache = self.tts_model.merge_prompt_cache(
-                       original_cache=None,
-                       new_text_token=target_text_token,
-                       new_audio_feat=generated_audio_feat
-                   )
-               final_wav.append(wav)
-
-           return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
+           if normalize:
+               if self.text_normalizer is None:
+                   from .utils.text_normalize import TextNormalizer
+                   self.text_normalizer = TextNormalizer()
+               text = self.text_normalizer.normalize(text)
+
+           generate_result = self.tts_model._generate_with_prompt_cache(
+               target_text=text,
+               prompt_cache=fixed_prompt_cache,
+               min_len=2,
+               max_len=max_length,
+               inference_timesteps=inference_timesteps,
+               cfg_value=cfg_value,
+               retry_badcase=retry_badcase,
+               retry_badcase_max_times=retry_badcase_max_times,
+               retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+               streaming=streaming,
+           )
+
+           for wav, _, _ in generate_result:
+               yield wav.squeeze(0).cpu().numpy()
 
        finally:
            if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
src/voxcpm/model/voxcpm.py
@@ -19,11 +19,12 @@ limitations under the License.
 """
 
 import os
-from typing import Dict, Optional, Tuple, Union
+from typing import Tuple, Union, Generator, List
 
 import torch
 import torch.nn as nn
 import torchaudio
+import warnings
 from einops import rearrange
 from pydantic import BaseModel
 from tqdm import tqdm
@@ -84,11 +85,15 @@ class VoxCPMModel(nn.Module):
        self.patch_size = config.patch_size
        self.device = config.device
        if not torch.cuda.is_available():
-           self.device = "cpu"
+           if torch.backends.mps.is_available():
+               self.device = "mps"
+           else:
+               self.device = "cpu"
+       print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
 
        # Text-Semantic LM
        self.base_lm = MiniCPMModel(config.lm_config)
-       self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+       self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
 
        self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
        self.audio_start_token = 101
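The device fallback now probes Apple's MPS backend before settling on CPU, so the effective order is the configured CUDA device, then MPS, then CPU. The same probe in isolation (plain PyTorch, not VoxCPM code):

```python
import torch

def pick_device() -> str:
    # mirrors the fallback above: CUDA if available, then MPS, then CPU
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(pick_device())
```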
@@ -99,7 +104,7 @@ class VoxCPMModel(nn.Module):
        residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
        residual_lm_config.vocab_size = 0
        self.residual_lm = MiniCPMModel(residual_lm_config)
-       self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+       self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
 
        # Local Encoder
        encoder_config = config.lm_config.model_copy(deep=True)
@@ -147,13 +152,23 @@ class VoxCPMModel(nn.Module):
        self.sample_rate = audio_vae.sample_rate
 
-   def optimize(self):
-       if self.device == "cuda":
+   def optimize(self, disable: bool = False):
+       try:
+           if disable:
+               raise ValueError("Optimization disabled by user")
+           if self.device != "cuda":
+               raise ValueError("VoxCPMModel can only be optimized on a CUDA device")
+           try:
+               import triton
+           except ImportError:
+               raise ValueError("triton is not installed")
            self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
            self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
            self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
-       else:
+       except Exception as e:
+           print(f"Error: {e}")
+           print("Warning: VoxCPMModel cannot be optimized by torch.compile, using original forward_step functions")
            self.base_lm.forward_step = self.base_lm.forward_step
            self.residual_lm.forward_step = self.residual_lm.forward_step
            self.feat_encoder_step = self.feat_encoder
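`optimize()` now funnels every precondition failure (user opt-out, non-CUDA device, missing triton) through a single exception path and falls back to the eager `forward_step` functions. Note that `torch.compile(mode="reduce-overhead")` compiles lazily on first use, which is why `VoxCPM.__init__` above issues a short warm-up generation. A toy timing sketch of that effect (hypothetical module, not VoxCPM code):

```python
import time
import torch

# stand-in for a compiled forward_step
step = torch.compile(torch.nn.Linear(64, 64), mode="reduce-overhead")
x = torch.randn(1, 64)

t0 = time.perf_counter()
step(x)  # first call pays the compilation cost
print(f"first call:  {time.perf_counter() - t0:.3f}s")

t0 = time.perf_counter()
step(x)  # later calls reuse the cached compiled graph
print(f"second call: {time.perf_counter() - t0:.3f}s")
```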
@@ -161,8 +176,14 @@ class VoxCPMModel(nn.Module):
        return self
 
+   def generate(self, *args, **kwargs) -> torch.Tensor:
+       return next(self._generate(*args, streaming=False, **kwargs))
+
+   def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
+       return self._generate(*args, streaming=True, **kwargs)
+
    @torch.inference_mode()
-   def generate(
+   def _generate(
        self,
        target_text: str,
        prompt_text: str = "",
@@ -174,7 +195,11 @@ class VoxCPMModel(nn.Module):
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,  # acceptable ratio of audio length to text length (for bad-case detection)
-   ):
+       streaming: bool = False,
+   ) -> Generator[torch.Tensor, None, None]:
+       if retry_badcase and streaming:
+           warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+           retry_badcase = False
        if len(prompt_wav_path) == 0:
            text = target_text
            text_token = torch.LongTensor(self.text_tokenizer(text))
@@ -250,14 +275,14 @@ class VoxCPMModel(nn.Module):
 
        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
-       audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+       audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
        audio_mask = audio_mask.unsqueeze(0).to(self.device)
 
        target_text_length = len(self.text_tokenizer(target_text))
 
        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
-           latent_pred, pred_audio_feat = self.inference(
+           inference_result = self._inference(
                text_token,
                text_mask,
                audio_feat,
@@ -266,17 +291,31 @@ class VoxCPMModel(nn.Module):
                max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
+               streaming=streaming,
            )
-           if retry_badcase:
-               if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                   print(f"    Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
-                   retry_badcase_times += 1
-                   continue
-               else:
-                   break
-           else:
-               break
-       return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+           if streaming:
+               patch_len = self.patch_size * self.chunk_size
+               for latent_pred, _ in inference_result:
+                   decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                   decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                   yield decode_audio
+               break
+           else:
+               latent_pred, pred_audio_feat = next(inference_result)
+               if retry_badcase:
+                   if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
+                       print(f"    Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                       retry_badcase_times += 1
+                       continue
+                   else:
+                       break
+               else:
+                   break
+       if not streaming:
+           decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+           decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
+           yield decode_audio
 
    @torch.inference_mode()
    def build_prompt_cache(
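In streaming mode each step re-decodes a short trailing window of latents and keeps only the newest `patch_size * chunk_size` samples, so every emitted chunk is decoded with context from preceding frames (`_inference` below yields the last three latent frames for exactly this reason). A toy sketch of the slicing logic (hypothetical context-free decoder, not the real VAE):

```python
import numpy as np

PATCH_LEN = 4  # stands in for patch_size * chunk_size

def decode(latents):
    # hypothetical decoder: each latent frame expands to PATCH_LEN samples
    return np.repeat(np.asarray(latents, dtype=np.float32), PATCH_LEN)

latent_seq, chunks = [], []
for step_latent in [1, 2, 3, 4, 5]:
    latent_seq.append(step_latent)
    window = decode(latent_seq[-3:])    # re-decode a trailing window for context
    chunks.append(window[-PATCH_LEN:])  # emit only the newest samples

audio = np.concatenate(chunks)
assert np.array_equal(audio, decode(latent_seq))  # chunks stitch into the full signal
```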
@@ -314,7 +353,7 @@ class VoxCPMModel(nn.Module):
        audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
 
        # extract audio features
-       audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
+       audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
 
        audio_feat = audio_feat.view(
            self.audio_vae.latent_dim,
@@ -366,8 +405,16 @@ class VoxCPMModel(nn.Module):
 
        return merged_cache
 
+   def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+       return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
+
+   def generate_with_prompt_cache_streaming(
+       self, *args, **kwargs
+   ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
+       return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
+
    @torch.inference_mode()
-   def generate_with_prompt_cache(
+   def _generate_with_prompt_cache(
        self,
        target_text: str,
        prompt_cache: dict,
@@ -378,7 +425,8 @@ class VoxCPMModel(nn.Module):
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
-   ):
+       streaming: bool = False,
+   ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
        """
        Generate audio using pre-built prompt cache.
 
@@ -392,10 +440,17 @@ class VoxCPMModel(nn.Module):
            retry_badcase: Whether to retry on bad cases
            retry_badcase_max_times: Maximum retry attempts
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
+           streaming: Whether to return a generator of audio chunks
 
        Returns:
-           tuple: (decoded audio tensor, new text tokens, new audio features)
+           Generator of tuples containing:
+           - Decoded audio tensor for the current step if ``streaming=True``, else the final decoded audio tensor
+           - Tensor of new text tokens
+           - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
        """
+       if retry_badcase and streaming:
+           warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+           retry_badcase = False
        # get prompt from cache
        if prompt_cache is None:
            prompt_text_token = torch.empty(0, dtype=torch.int32)
@@ -433,14 +488,14 @@ class VoxCPMModel(nn.Module):
 
        text_token = text_token.unsqueeze(0).to(self.device)
        text_mask = text_mask.unsqueeze(0).to(self.device)
-       audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+       audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
        audio_mask = audio_mask.unsqueeze(0).to(self.device)
 
        # run inference
        target_text_length = len(self.text_tokenizer(target_text))
        retry_badcase_times = 0
        while retry_badcase_times < retry_badcase_max_times:
-           latent_pred, pred_audio_feat = self.inference(
+           inference_result = self._inference(
                text_token,
                text_mask,
                audio_feat,
@@ -449,26 +504,48 @@ class VoxCPMModel(nn.Module):
                max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
+               streaming=streaming,
            )
-           if retry_badcase:
-               if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                   print(f"    Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
-                   retry_badcase_times += 1
-                   continue
-               else:
-                   break
-           else:
-               break
-       decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
-
-       return (
-           decode_audio,
-           target_text_token,
-           pred_audio_feat
-       )
+           if streaming:
+               patch_len = self.patch_size * self.chunk_size
+               for latent_pred, pred_audio_feat in inference_result:
+                   decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                   decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                   yield (
+                       decode_audio,
+                       target_text_token,
+                       pred_audio_feat
+                   )
+               break
+           else:
+               latent_pred, pred_audio_feat = next(inference_result)
+               if retry_badcase:
+                   if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
+                       print(f"    Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                       retry_badcase_times += 1
+                       continue
+                   else:
+                       break
+               else:
+                   break
+       if not streaming:
+           decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+           decode_audio = decode_audio[..., 640:-640]  # trick: trim the start and end of the audio
+           yield (
+               decode_audio,
+               target_text_token,
+               pred_audio_feat
+           )
+
+   def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+       return next(self._inference(*args, streaming=False, **kwargs))
+
+   def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
+       return self._inference(*args, streaming=True, **kwargs)
 
    @torch.inference_mode()
-   def inference(
+   def _inference(
        self,
        text: torch.Tensor,
        text_mask: torch.Tensor,
@@ -478,7 +555,8 @@ class VoxCPMModel(nn.Module):
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
-   ) -> Tuple[torch.Tensor, torch.Tensor]:
+       streaming: bool = False,
+   ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
        """Core inference method for audio generation.
 
        This is the main inference loop that generates audio features
@@ -493,11 +571,12 @@ class VoxCPMModel(nn.Module):
            max_len: Maximum generation length
            inference_timesteps: Number of diffusion steps
            cfg_value: Classifier-free guidance value
+           streaming: Whether to yield each step's latent features or just the final result
 
        Returns:
-           Tuple containing:
-           - Predicted latent features
-           - Predicted audio feature sequence
+           Generator of tuples containing:
+           - Predicted latent features at the current step if ``streaming=True``, else the final latent features
+           - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
        """
        B, T, P, D = feat.shape
 
@@ -555,6 +634,12 @@ class VoxCPMModel(nn.Module):
|
|||||||
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
|
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
|
||||||
prefix_feat_cond = pred_feat
|
prefix_feat_cond = pred_feat
|
||||||
|
|
||||||
|
if streaming:
|
||||||
|
# return the last three predicted latent features to provide enough context for smooth decoding
|
||||||
|
pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
|
||||||
|
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||||
|
yield feat_pred, pred_feat_seq
|
||||||
|
|
||||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||||
if i > min_len and stop_flag == 1:
|
if i > min_len and stop_flag == 1:
|
||||||
break
|
break
|
||||||
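Streaming callers receive an overlapping three-patch window at every step, so the VAE decoder always has left context; the consumer keeps only the newly generated tail. A hedged consumption sketch — `model`, the argument list of `inference_streaming`, and the one-third-window trimming are assumptions, not the repository's exact bookkeeping:

```python
import torch

def collect_streaming_audio(model, stream):
    """Decode each yielded three-patch latent window and keep only the new tail.

    Assumes `model.audio_vae.decode` returns (B, 1, samples); the two-thirds
    trim below is illustrative overlap bookkeeping, not the repository's logic.
    """
    pcm_chunks = []
    for feat_pred, _pred_feat_seq in stream:
        audio = model.audio_vae.decode(feat_pred.to(torch.float32))  # decode the 3-patch window
        tail = audio[..., audio.shape[-1] * 2 // 3:]  # keep roughly the newest patch
        pcm_chunks.append(tail.squeeze(1).cpu())
    return torch.cat(pcm_chunks, dim=-1)

# Hypothetical call:
#   stream = model.inference_streaming(text, text_mask, feat, feat_mask, ...)
#   waveform = collect_streaming_audio(model, stream)
```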
@@ -569,14 +654,14 @@ class VoxCPMModel(nn.Module):
                 lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
             ).clone()
 
-        pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
-
-        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
-        feat_pred = feat_pred[..., 1:-1]  # trick: remove the first and last token
-        return feat_pred, pred_feat_seq.squeeze(0).cpu()
+        if not streaming:
+            pred_feat_seq = torch.cat(pred_feat_seq, dim=1)  # b, t, p, d
+            feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+            yield feat_pred, pred_feat_seq.squeeze(0).cpu()
 
     @classmethod
-    def from_local(cls, path: str):
+    def from_local(cls, path: str, optimize: bool = True):
         config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
 
         tokenizer = LlamaTokenizerFast.from_pretrained(path)
@@ -589,7 +674,7 @@ class VoxCPMModel(nn.Module):
         )["state_dict"]
 
         model = cls(config, tokenizer, audio_vae)
-        lm_dtype = get_dtype(config.dtype)
+        lm_dtype = get_dtype(model.config.dtype)
         model = model.to(lm_dtype)
         model.audio_vae = model.audio_vae.to(torch.float32)
 
@@ -602,4 +687,4 @@ class VoxCPMModel(nn.Module):
         for kw, val in vae_state_dict.items():
             model_state_dict[f"audio_vae.{kw}"] = val
         model.load_state_dict(model_state_dict, strict=True)
-        return model.to(model.device).eval().optimize()
+        return model.to(model.device).eval().optimize(disable=not optimize)
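With the new `optimize` flag, `from_local` can skip the final optimization pass, which the last hunk forwards as `optimize(disable=not optimize)`. A usage sketch (the import path and checkpoint path are illustrative):

```python
from voxcpm.model import VoxCPMModel  # import path assumed for illustration

# Load without the optimization pass, e.g. for debugging or on backends
# where the optimized path is unsupported. The checkpoint path is illustrative.
model = VoxCPMModel.from_local("checkpoints/VoxCPM-0.5B", optimize=False)
```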
@@ -88,7 +88,7 @@ class UnifiedCFM(torch.nn.Module):
                 shape: (n_timesteps + 1,)
             mu (torch.Tensor): output of encoder
                 shape: (batch_size, n_feats)
-            cond: Not used but kept for future purposes
+            cond: condition -- prefix prompt
             cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
         """
         t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
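For context, this docstring belongs to a flow-matching solver that integrates a learned velocity field over `t_span`, with `cond` now documented as the prefix-prompt condition for classifier-free guidance. A generic sketch of one guided Euler step under the usual CFG blend — the `estimator` signature is an assumption, not the repository's API:

```python
import torch

def guided_euler_step(estimator, x, t, dt, mu, cond, cfg_value: float = 1.0):
    """One Euler step of guided flow matching (sketch).

    `estimator(x, t, mu, cond)` returning a velocity is an assumed signature;
    the guidance blend is the common CFG formulation, not necessarily VoxCPM's.
    """
    v_cond = estimator(x, t, mu, cond)    # velocity with the prefix-prompt condition
    v_uncond = estimator(x, t, mu, None)  # velocity with the condition dropped
    v = v_uncond + cfg_value * (v_cond - v_uncond)
    return x + dt * v                     # explicit Euler update
```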
@@ -154,6 +154,11 @@ class MiniCPMAttention(nn.Module):
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):
 
         attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
 
+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_cache = key_cache.contiguous()
+        value_cache = value_cache.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_cache,
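Both hunks apply the same workaround: on Apple's MPS backend, `scaled_dot_product_attention` can misbehave on non-contiguous inputs (pytorch/pytorch#163597), so the tensors are made contiguous first. A self-contained sketch of the pattern with illustrative shapes:

```python
import torch
import torch.nn.functional as F

device = "mps" if torch.backends.mps.is_available() else "cpu"
# A transposed view is non-contiguous -- exactly the case the workaround targets.
q = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)  # (B, heads, T, head_dim)
k = torch.randn(1, 8, 16, 64, device=device)
v = torch.randn(1, 8, 16, 64, device=device)

# Force contiguous memory before SDPA (see pytorch/pytorch#163597).
out = F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())
print(out.shape)  # torch.Size([1, 8, 16, 64])
```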
@@ -3,40 +3,7 @@ import re
 import regex
 import inflect
 from functools import partial
-from tn.chinese.normalizer import Normalizer as ZhNormalizer
-from tn.english.normalizer import Normalizer as EnNormalizer
+from wetext import Normalizer
 
 
-def normal_cut_sentence(text):
-    # First handle commas inside parentheses by replacing them with a placeholder marker
-    text = re.sub(r'([((][^))]*)([,,])([^))]*[))])', r'\1&&&\3', text)
-    text = re.sub('([。!,?\?])([^’”])', r'\1\n\2', text)  # common sentence-ending punctuation not followed by a closing quote
-    text = re.sub('(\.{6})([^’”])', r'\1\n\2', text)  # English ellipsis not followed by a closing quote
-    text = re.sub('(\…{2})([^’”])', r'\1\n\2', text)  # Chinese ellipsis not followed by a closing quote
-    text = re.sub('([. ,。!;?\?\.{6}\…{2}][’”])([^’”])', r'\1\n\2', text)  # sentence-ending punctuation plus quote, not followed by another quote
-    # Handle splitting of English sentences
-    text = re.sub(r'([.,!?])([^’”\'"])', r'\1\n\2', text)  # period/exclamation/question mark not followed by a quote
-    text = re.sub(r'([.!?][’”\'"])([^’”\'"])', r'\1\n\2', text)  # the part after a period/exclamation/question mark plus a quote
-    text = re.sub(r'([((][^))]*)(&&&)([^))]*[))])', r'\1,\3', text)
-    text = [t for t in text.split("\n") if t]
-    return text
-
-
-def cut_sentence_with_fix_length(text: str, length: int):
-    sentences = normal_cut_sentence(text)
-    cur_length = 0
-    res = ""
-    for sentence in sentences:
-        if not sentence:
-            continue
-        if cur_length > length or cur_length + len(sentence) > length:
-            yield res
-            res = ""
-            cur_length = 0
-        res += sentence
-        cur_length += len(sentence)
-    if res:
-        yield res
-
-
 chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
 
@@ -195,8 +162,8 @@ def clean_text(text):
 class TextNormalizer:
     def __init__(self, tokenizer=None):
         self.tokenizer = tokenizer
-        self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, remove_interjections=False, overwrite_cache=True)
-        self.en_tn_model = EnNormalizer()
+        self.zh_tn_model = Normalizer(lang="zh", operator="tn", remove_erhua=True)
+        self.en_tn_model = Normalizer(lang="en", operator="tn")
         self.inflect_parser = inflect.engine()
 
     def normalize(self, text, split=False):
@@ -207,38 +174,12 @@ class TextNormalizer:
         text = text.replace("=", "等于")  # fix: "550 + 320 等于 870 千卡。" being wrongly normalized to "五百五十加三百二十等于八七十千卡."
         if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text):  # avoid English hyphens being wrongly normalized to "minus"
             text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text)  # fix: "x-2" being normalized to "x 负 2"
-            text = self.zh_tn_model.normalize(text)
-            text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text)  # fix: "x-2" being normalized to "x 负 2"
             text = self.zh_tn_model.normalize(text)
             text = replace_blank(text)
             text = replace_corner_mark(text)
             text = remove_bracket(text)
-            text = re.sub(r'[,,]+$', '。', text)
         else:
             text = self.en_tn_model.normalize(text)
             text = spell_out_number(text, self.inflect_parser)
         if split is False:
             return text
-
-
-if __name__ == "__main__":
-    text_normalizer = TextNormalizer()
-    text = r"""今天我们学习一元二次方程。一元二次方程的标准形式是:
-ax2+bx+c=0ax^2 + bx + c = 0ax2+bx+c=0
-其中,aaa、bbb 和 ccc 是常数,xxx 是变量。这个方程的解可以通过求根公式来找到。
-一元二次方程的解法有几种:
-- 因式分解法:通过将方程因式分解来求解。我们首先尝试将方程表达成两个括号的形式,解决方程的解。比如,方程x2−5x+6=0x^2 - 5x + 6 = 0x2−5x+6=0可以因式分解为(x−2)(x−3)=0(x - 2)(x - 3) = 0(x−2)(x−3)=0,因此根为2和3。
-- 配方法:通过配方将方程转化为完全平方的形式,从而解出。我们通过加上或减去适当的常数来完成这一过程,使得方程可以直接写成一个完全平方的形式。
-- 求根公式:我们可以使用求根公式直接求出方程的解。这个公式适用于所有的一元二次方程,即使我们无法通过因式分解或配方法来解决时,也能使用该公式。
-公式:x=−b±b2−4ac2ax = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}x=2a−b±b2−4ac这个公式可以帮助我们求解任何一元二次方程的根。
-对于一元二次方程,我们需要了解判别式。判别式的作用是帮助我们判断方程的解的个数和性质。判别式 Δ\DeltaΔ 由下式给出:Δ=b2−4ac\Delta = b^2 - 4acΔ=b2−4ac 根据判别式的值,我们可以知道:
-- 如果 Δ>0\Delta > 0Δ>0,方程有两个不相等的实数解。这是因为判别式大于0时,根号内的值是正数,所以我们可以得到两个不同的解。
-- 如果 Δ=0\Delta = 0Δ=0,方程有一个实数解。这是因为根号内的值为零,导致两个解相等,也就是说方程有一个解。
-- 如果 Δ<0\Delta < 0Δ<0,方程没有实数解。这意味着根号内的值是负数,无法进行实数运算,因此方程没有实数解,可能有复数解。"""
-    texts = ["这是一个公式 (a+b)³=a³+3a²b+3ab²+b³ S=(a×b)÷2", "这样的发展为AI仅仅作为“工具”这一观点提出了新的挑战,", "550 + 320 = 870千卡。", "解一元二次方程:3x^2+x-2=0", "你好啊"]
-    texts = [text]
-    for text in texts:
-        text = text_normalizer.normalize(text)
-        print(text)
-        for t in cut_sentence_with_fix_length(text, 15):
-            print(t)
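The move from the two `tn` normalizers to `wetext` puts Chinese and English normalization behind a single class. A minimal usage sketch based only on the constructor arguments visible in this diff (`lang`, `operator`, `remove_erhua`):

```python
from wetext import Normalizer

zh_tn = Normalizer(lang="zh", operator="tn", remove_erhua=True)
en_tn = Normalizer(lang="en", operator="tn")

print(zh_tn.normalize("解一元二次方程:3x^2+x-2=0"))          # digits and symbols read out in Chinese
print(en_tn.normalize("The meeting is at 3:30pm on 12/05."))  # times and dates expanded in English
```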
76
src/voxcpm/zipenhancer.py
Normal file
@@ -0,0 +1,76 @@
+"""
+ZipEnhancer Module - Audio Denoising Enhancer
+
+Provides on-demand import of ZipEnhancer functionality for audio denoising.
+Related dependencies are imported only when denoising functionality is needed.
+"""
+
+import os
+import tempfile
+from typing import Optional, Union
+
+import torchaudio
+import torch
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+
+class ZipEnhancer:
+    """ZipEnhancer Audio Denoising Enhancer"""
+
+    def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
+        """
+        Initialize ZipEnhancer.
+
+        Args:
+            model_path: ModelScope model id or local path
+        """
+        self.model_path = model_path
+        self._pipeline = pipeline(
+            Tasks.acoustic_noise_suppression,
+            model=self.model_path
+        )
+
+    def _normalize_loudness(self, wav_path: str):
+        """
+        Normalize audio loudness in place.
+
+        Args:
+            wav_path: Audio file path
+        """
+        audio, sr = torchaudio.load(wav_path)
+        loudness = torchaudio.functional.loudness(audio, sr)
+        normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
+        torchaudio.save(wav_path, normalized_audio, sr)
+
+    def enhance(self, input_path: str, output_path: Optional[str] = None,
+                normalize_loudness: bool = True) -> str:
+        """
+        Denoise an audio file.
+
+        Args:
+            input_path: Input audio file path
+            output_path: Output audio file path (optional; a temporary file is created by default)
+            normalize_loudness: Whether to perform loudness normalization
+
+        Returns:
+            str: Output audio file path
+
+        Raises:
+            RuntimeError: If the pipeline fails during processing
+        """
+        if not os.path.exists(input_path):
+            raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
+        # Create a temporary file if no output path is specified
+        if output_path is None:
+            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
+                output_path = tmp_file.name
+        try:
+            # Perform denoising
+            self._pipeline(input_path, output_path=output_path)
+            # Loudness normalization
+            if normalize_loudness:
+                self._normalize_loudness(output_path)
+            return output_path
+        except Exception as e:
+            # Clean up any temporary file that may have been created
+            if output_path and os.path.exists(output_path):
+                try:
+                    os.unlink(output_path)
+                except OSError:
+                    pass
+            raise RuntimeError(f"Audio denoising processing failed: {e}")
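A short usage sketch for the new module; the input file name is illustrative:

```python
from voxcpm.zipenhancer import ZipEnhancer

enhancer = ZipEnhancer()  # fetches iic/speech_zipenhancer_ans_multiloss_16k_base on first use
clean_path = enhancer.enhance("noisy_prompt.wav")  # denoised, loudness-normalized temp file
print(clean_path)
```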