mirror of
https://github.com/OpenBMB/VoxCPM
synced 2025-12-12 03:48:12 +00:00
init
This commit is contained in:
201
LICENSE
Normal file
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright OpenBMB
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
268
README.md
Normal file
268
README.md
Normal file
@@ -0,0 +1,268 @@
|
||||
## 🎙️ VoxCPM: Tokenizer-Free TTS for Context-Aware Speech Generation and True-to-Life Voice Cloning
|
||||
|
||||
|
||||
[](https://github.com/OpenBMB/VoxCPM/) [](hhttps://huggingface.co/openbmb/VoxCPM-0.5B) [](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) [](https://thuhcsi.github.io/VoxCPM/)
|
||||
|
||||
|
||||
<div align="center">
|
||||
<img src="assets/voxcpm_logo.png" alt="VoxCPM Logo" width="40%">
|
||||
</div>
|
||||
|
||||
## News
|
||||
* [2025.09.16] 🔥 🔥 🔥 We Open Source the VoxCPM-0.5B weights!
|
||||
* [2025.09.16] 🎉 🎉 🎉 We Provide the [Gradio PlayGround](https://huggingface.co/spaces/OpenBMB/VoxCPM-Demo) for VoxCPM-0.5B, try it now!
|
||||
|
||||
## Overview
|
||||
|
||||
VoxCPM is a novel tokenizer-free Text-to-Speech (TTS) system that redefines realism in speech synthesis. By modeling speech in a continuous space, it overcomes the limitations of discrete tokenization and enables two flagship capabilities: context-aware speech generation and true-to-life zero-shot voice cloning.
|
||||
|
||||
Unlike mainstream approaches that convert speech to discrete tokens, VoxCPM uses an end-to-end diffusion autoregressive architecture that directly generates continuous speech representations from text. Built on [MiniCPM-4](https://huggingface.co/openbmb/MiniCPM4-0.5B), it achieves implicit semantic-acoustic decoupling through hierachical language modeling and FSQ constraints, greatly enhancing both expressiveness and generation stability.
|
||||
|
||||
<div align="center">
|
||||
<img src="assets/voxcpm_model.png" alt="VoxCPM Model Architecture" width="500">
|
||||
</div>
|
||||
|
||||
|
||||
### 🚀 Key Features
|
||||
- **Context-Aware, Expressive Speech Generation** - VoxCPM comprehends text to infer and generate appropriate prosody, delivering speech with remarkable expressiveness and natural flow. It spontaneously adapts speaking style based on content, producing highly fitting vocal expression trained on a massive 1.8 million-hour bilingual corpus.
|
||||
- **True-to-Life Voice Cloning** - With only a short reference audio clip, VoxCPM performs accurate zero-shot voice cloning, capturing not only the speaker’s timbre but also fine-grained characteristics such as accent, emotional tone, rhythm, and pacing to create a faithful and natural replica.
|
||||
- **High-Efficiency Synthesis** - VoxCPM supports streaming synthesis with a Real-Time Factor (RTF) as low as 0.17 on a consumer-grade NVIDIA RTX 4090 GPU, making it possible for real-time applications.
|
||||
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 🔧 Install from PyPI
|
||||
``` sh
|
||||
pip install voxcpm
|
||||
```
|
||||
### 1. Model Download (Optional)
|
||||
By default, when you first run the script, the model will be downloaded automatically, but you can also download the model in advance.
|
||||
- Download VoxCPM-0.5B
|
||||
```
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("openbmb/VoxCPM-0.5B",local_files_only=local_files_only)
|
||||
```
|
||||
- Download ZipEnhancer and SenseVoice-Small. We use ZipEnhancer to enhance speech prompts and SenseVoice-Small for speech prompt ASR in the web demo.
|
||||
```
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download('iic/speech_zipenhancer_ans_multiloss_16k_base')
|
||||
snapshot_download('iic/SenseVoiceSmall')
|
||||
```
|
||||
|
||||
### 2. Basic Usage
|
||||
```python
|
||||
import soundfile as sf
|
||||
from voxcpm import VoxCPM
|
||||
|
||||
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
|
||||
|
||||
wav = model.generate(
|
||||
text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
|
||||
prompt_wav_path=None, # optional: path to a prompt speech for voice cloning
|
||||
prompt_text=None, # optional: reference text
|
||||
cfg_value=2.0,
|
||||
inference_timesteps=10,
|
||||
normalize=True,
|
||||
denoise=True,
|
||||
retry_badcase=True, # optional: enable retrying mode
|
||||
retry_badcase_max_times=3,
|
||||
retry_badcase_ratio_threshold=6.0,
|
||||
)
|
||||
|
||||
sf.write("output.wav", wav, 16000)
|
||||
print("saved: output.wav")
|
||||
```
|
||||
|
||||
### 3. CLI Usage
|
||||
|
||||
After installation, the entry point is `voxcpm` (or use `python -m voxcpm.cli`).
|
||||
|
||||
```bash
|
||||
# 1) Direct synthesis (single text)
|
||||
voxcpm --text "Hello VoxCPM" --output out.wav
|
||||
|
||||
# 2) Voice cloning (reference audio + transcript)
|
||||
voxcpm --text "Hello" \
|
||||
--prompt-audio path/to/voice.wav \
|
||||
--prompt-text "reference transcript" \
|
||||
--output out.wav \
|
||||
--denoise
|
||||
|
||||
# 3) Batch processing (one text per line)
|
||||
voxcpm --input examples/input.txt --output-dir outs
|
||||
# (optional) Batch + cloning
|
||||
voxcpm --input examples/input.txt --output-dir outs \
|
||||
--prompt-audio path/to/voice.wav \
|
||||
--prompt-text "reference transcript" \
|
||||
--denoise
|
||||
|
||||
# 4) Inference parameters (quality/speed)
|
||||
voxcpm --text "..." --output out.wav \
|
||||
--cfg-value 2.0 --inference-timesteps 10 --normalize
|
||||
|
||||
# 5) Model loading
|
||||
# Prefer local path
|
||||
voxcpm --text "..." --output out.wav --model-path /path/to/VoxCPM_model_dir
|
||||
# Or from Hugging Face (auto download/cache)
|
||||
voxcpm --text "..." --output out.wav \
|
||||
--hf-model-id openbmb/VoxCPM-0.5B --cache-dir ~/.cache/huggingface --local-files-only
|
||||
|
||||
# 6) Denoiser control
|
||||
voxcpm --text "..." --output out.wav \
|
||||
--no-denoiser --zipenhancer-path iic/speech_zipenhancer_ans_multiloss_16k_base
|
||||
|
||||
# 7) Help
|
||||
voxcpm --help
|
||||
python -m voxcpm.cli --help
|
||||
```
|
||||
|
||||
### 4. Start web demo
|
||||
|
||||
You can start the UI interface by running `python app.py`, which allows you to perform Voice Cloning and Voice Creation.
|
||||
|
||||
|
||||
|
||||
## 👩🍳 A Voice Chef's Guide
|
||||
Welcome to the VoxCPM kitchen! Follow this recipe to cook up perfect generated speech. Let’s begin.
|
||||
|
||||
---
|
||||
### 🥚 Step 1: Prepare Your Base Ingredients (Content)
|
||||
|
||||
First, choose how you’d like to input your text:.
|
||||
1. Regular Text (Classic Mode)
|
||||
- ✅ Keep "Text Normalization" ON. Type naturally (e.g., "Hello, world! 123"). The system will automatically process numbers, abbreviations, and punctuation using WeTextProcessing library.
|
||||
2. Phoneme Input (Native Mode)
|
||||
- ❌ Turn "Text Normalization" OFF. Enter phoneme text like {HH AH0 L OW1} (EN) or {ni3}{hao3} (ZH) for precise pronunciation control. In this mode, VoxCPM also supports native understanding of other complex non-normalized text—try it out!
|
||||
|
||||
---
|
||||
### 🍳 Step 2: Choose Your Flavor Profile (Voice Style)
|
||||
|
||||
This is the secret sauce that gives your audio its unique sound.
|
||||
1. Cooking with a Prompt Speech (Following a Famous Recipe)
|
||||
- A prompt speech provides the desired acoustic characteristics for VoxCPM. The speaker's timbre, speaking style, and even the background sounds and ambiance will be replicated.
|
||||
- For a Clean, Studio-Quality Voice:
|
||||
- ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone.
|
||||
2. Cooking au Naturel (Letting the Model Improvise)
|
||||
- If no reference is provided, VoxCPM becomes a creative chef! It will infer a fitting speaking style based on the text itself, thanks to the text-smartness of its foundation model, MiniCPM-4.
|
||||
- Pro Tip: Challenge VoxCPM with any text—poetry, song lyrics, dramatic monologues—it may deliver some interesting results!
|
||||
|
||||
---
|
||||
### 🧂 Step 3: The Final Seasoning (Fine-Tuning Your Results)
|
||||
You're ready to serve! But for master chefs who want to tweak the flavor, here are two key spices.
|
||||
- CFG Value (How Closely to Follow the Recipe)
|
||||
- Default: A great starting point.
|
||||
- Voice sounds strained or weird? Lower this value. It tells the model to be more relaxed and improvisational, great for expressive prompts.
|
||||
- Need maximum clarity and adherence to the text? Raise it slightly to keep the model on a tighter leash.
|
||||
- Inference Timesteps (Simmering Time: Quality vs. Speed)
|
||||
- Need a quick snack? Use a lower number. Perfect for fast drafts and experiments.
|
||||
- Cooking a gourmet meal? Use a higher number. This lets the model "simmer" longer, refining the audio for superior detail and naturalness.
|
||||
|
||||
---
|
||||
Happy creating! 🎉 Start with the default settings and tweak from there to suit your project. The kitchen is yours!
|
||||
|
||||
|
||||
---
|
||||
|
||||
|
||||
|
||||
## 📊 Performance Highlights
|
||||
|
||||
VoxCPM achieves competitive results on public zero-shot TTS benchmarks:
|
||||
|
||||
### Seed-TTS-eval Benchmark
|
||||
|
||||
| Model | Parameters | Open-Source | test-EN | | test-ZH | | test-Hard | |
|
||||
|------|------|------|:------------:|:--:|:------------:|:--:|:-------------:|:--:|
|
||||
| | | | WER/%⬇ | SIM/%⬆| CER/%⬇| SIM/%⬆ | CER/%⬇ | SIM/%⬆ |
|
||||
| CosyVoice | 0.3B | ✅ | 4.29 | 60.9 | 3.63 | 72.3 | 11.75 | 70.9 |
|
||||
| CosyVoice2 | 0.5B | ✅ | 3.09 | 65.9 | 1.38 | 75.7 | 6.83 | 72.4 |
|
||||
| F5-TTS | 0.3B | ✅ | 2.00 | 67.0 | 1.53 | 76.0 | 8.67 | 71.3 |
|
||||
| SparkTTS | 0.5B | ✅ | 3.14 | 57.3 | 1.54 | 66.0 | - | - |
|
||||
| FireRedTTS | 0.5B | ✅ | 3.82 | 46.0 | 1.51 | 63.5 | 17.45 | 62.1 |
|
||||
| FireRedTTS-2 | 1.5B | ✅ | 1.95 | 66.5 | 1.14 | 73.6 | - | - |
|
||||
| Qwen2.5-Omni | 7B | ✅ | 2.72 | 63.2 | 1.70 | 75.2 | 7.97 | 74.7 |
|
||||
| OpenAudio-s1-mini | 0.5B | ✅ | 1.94 | 55.0 | 1.18 | 68.5 | - | - |
|
||||
| IndexTTS2 | 1.5B | ✅ | 2.23 | 70.6 | 1.03 | 76.5 | - | - |
|
||||
| VibeVoice | 1.5B | ✅ | 3.04 | 68.9 | 1.16 | 74.4 | - | - |
|
||||
| HiggsAudio-v2 | 3B | ✅ | 2.44 | 67.7 | 1.50 | 74.0 | - | - |
|
||||
| CosyVoice3 | 0.5B | ❌ | 2.02 | 71.8 | 1.16 | 78.0 | 6.08 | 75.8 |
|
||||
| CosyVoice3 | 1.5B | ❌ | 2.22 | 72.0 | 1.12 | 78.1 | 5.83 | 75.8 |
|
||||
| MegaTTS3 | 0.5B | ❌ | 2.79 | 77.1 | 1.52 | 79.0 | - | - |
|
||||
| DiTAR | 0.6B | ❌ | 1.69 | 73.5 | 1.02 | 75.3 | - | - |
|
||||
| Seed-TTS | - | ❌ | 2.25 | 76.2 | 1.12 | 79.6 | 7.59 | 77.6 |
|
||||
| MiniMax-Speech | - | ❌ | 1.65 | 69.2 | 0.83 | 78.3 | - | - |
|
||||
| **VoxCPM** | **0.5B** | **✅** | **1.85** | **72.9** | **0.93** | **77.2** | 8.87 | 73.0 |
|
||||
|
||||
|
||||
### CV3-eval Benchmark
|
||||
|
||||
| Model | zh | en | hard-zh | | | hard-en | | | |
|
||||
|-------|:--:|:--:|:-------:|:--:|:--:|:-------:|:--:|:--:|:--:|
|
||||
| | CER/%⬇ | WER/%⬇ | CER/%⬇ | SIM/%⬆ | DNSMOS⬆ | WER/%⬇ | SIM/%⬆ | DNSMOS⬆ | |
|
||||
| F5-TTS | 5.47 | 8.90 | - | - | - | - | - | - | |
|
||||
| SparkTTS | 5.15 | 11.0 | - | - | - | - | - | - | |
|
||||
| GPT-SoVits | 7.34 | 12.5 | - | - | - | - | - | - | |
|
||||
| CosyVoice2 | 4.08 | 6.32 | 12.58 | 72.6 | 3.81 | 11.96 | 66.7 | 3.95 | |
|
||||
| OpenAudio-s1-mini | 4.00 | 5.54 | 18.1 | 58.2 | 3.77 | 12.4 | 55.7 | 3.89 | |
|
||||
| IndexTTS2 | 3.58 | 4.45 | 12.8 | 74.6 | 3.65 | fail | fail | fail | |
|
||||
| HiggsAudio-v2 | 9.54 | 7.89 | 41.0 | 60.2 | 3.39 | 10.3 | 61.8 | 3.68 | |
|
||||
| CosyVoice3-0.5B | 3.89 | 5.24 | 14.15 | 78.6 | 3.75 | 9.04 | 75.9 | 3.92 | |
|
||||
| CosyVoice3-1.5B | 3.91 | 4.99 | 9.77 | 78.5 | 3.79 | 10.55 | 76.1 | 3.95 | |
|
||||
| **VoxCPM** | **3.40** | **4.04** | 12.9 | 66.1 | 3.59 | **7.89** | 64.3 | 3.74 | |
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## ⚠️ Risks and limitations
|
||||
- General Model Behavior: While VoxCPM has been trained on a large-scale dataset, it may still produce outputs that are unexpected, biased, or contain artifacts.
|
||||
- Potential for Misuse of Voice Cloning: VoxCPM's powerful zero-shot voice cloning capability can generate highly realistic synthetic speech. This technology could be misused for creating convincing deepfakes for purposes of impersonation, fraud, or spreading disinformation. Users of this model must not use it to create content that infringes upon the rights of individuals. It is strictly forbidden to use VoxCPM for any illegal or unethical purposes. We strongly recommend that any publicly shared content generated with this model be clearly marked as AI-generated.
|
||||
- Current Technical Limitations: Although generally stable, the model may occasionally exhibit instability, especially with very long or expressive inputs. Furthermore, the current version offers limited direct control over specific speech attributes like emotion or speaking style.
|
||||
- Bilingual Model: VoxCPM is trained primarily on Chinese and English data. Performance on other languages is not guaranteed and may result in unpredictable or low-quality audio.
|
||||
- This model is released for research and development purposes only. We do not recommend its use in production or commercial applications without rigorous testing and safety evaluations. Please use VoxCPM responsibly.
|
||||
|
||||
|
||||
|
||||
## 📄 License
|
||||
The VoxCPM model weights and code are open-sourced under the [Apache-2.0](LICENSE) license.
|
||||
|
||||
## 🙏 Acknowledgments
|
||||
|
||||
We extend our sincere gratitude to the following works and resources for their inspiration and contributions:
|
||||
|
||||
- [DiTAR](https://arxiv.org/abs/2502.03930) for the diffusion autoregressive backbone used in speech generation
|
||||
- [MiniCPM-4](https://github.com/OpenBMB/MiniCPM) for serving as the language model foundation
|
||||
- [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for the implementation of Flow Matching-based LocDiT
|
||||
- [DAC](https://github.com/descriptinc/descript-audio-codec) for providing the Audio VAE backbone
|
||||
|
||||
## Institutions
|
||||
|
||||
This project is developed by the following institutions:
|
||||
- <img src="assets/modelbest_logo.png" width="28px"> [ModelBest](https://modelbest.cn/)
|
||||
|
||||
- <img src="assets/thuhcsi_logo.png" width="28px"> [THUHCSI](https://github.com/thuhcsi)
|
||||
|
||||
|
||||
|
||||
|
||||
## 📚 Citation
|
||||
|
||||
If you find our model helpful, please consider citing our projects 📝 and staring us ⭐️!
|
||||
|
||||
```bib
|
||||
@misc{voxcpm2025,
|
||||
author = {{Yixuan Zhou, Guoyang Zeng, Xin Liu, Xiang Li, Renjie Yu, Ziyang Wang, Runchuan Ye, Weiyue Sun, Jiancheng Gui, Kehan Li, Zhiyong Wu, Zhiyuan Liu}},
|
||||
title = {{VoxCPM}},
|
||||
year = {2025},
|
||||
publish = {\url{https://github.com/OpenBMB/VoxCPM}},
|
||||
note = {GitHub repository}
|
||||
}
|
||||
```
|
||||
279
app.py
Normal file
279
app.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
import spaces
|
||||
from typing import Optional, Tuple
|
||||
from funasr import AutoModel
|
||||
from pathlib import Path
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM-0.5B"
|
||||
|
||||
import voxcpm
|
||||
|
||||
|
||||
class VoxCPMDemo:
|
||||
def __init__(self) -> None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"🚀 Running on device: {self.device}")
|
||||
|
||||
# ASR model for prompt text recognition
|
||||
self.asr_model_id = "iic/SenseVoiceSmall"
|
||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||
model=self.asr_model_id,
|
||||
disable_update=True,
|
||||
log_level='DEBUG',
|
||||
device="cuda:0" if self.device == "cuda" else "cpu",
|
||||
)
|
||||
|
||||
# TTS model (lazy init)
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.default_local_model_dir = "./models/VoxCPM-0.5B"
|
||||
|
||||
# ---------- Model helpers ----------
|
||||
def _resolve_model_dir(self) -> str:
|
||||
"""
|
||||
Resolve model directory:
|
||||
1) Use local checkpoint directory if exists
|
||||
2) If HF_REPO_ID env is set, download into models/{repo}
|
||||
3) Fallback to 'models'
|
||||
"""
|
||||
if os.path.isdir(self.default_local_model_dir):
|
||||
return self.default_local_model_dir
|
||||
|
||||
repo_id = os.environ.get("HF_REPO_ID", "").strip()
|
||||
if len(repo_id) > 0:
|
||||
target_dir = os.path.join("models", repo_id.replace("/", "__"))
|
||||
if not os.path.isdir(target_dir):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...")
|
||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||
except Exception as e:
|
||||
print(f"Warning: HF download failed: {e}. Falling back to 'data'.")
|
||||
return "models"
|
||||
return target_dir
|
||||
return "models"
|
||||
|
||||
def get_or_load_voxcpm(self) -> voxcpm.VoxCPM:
|
||||
if self.voxcpm_model is not None:
|
||||
return self.voxcpm_model
|
||||
print("Model not loaded, initializing...")
|
||||
model_dir = self._resolve_model_dir()
|
||||
print(f"Using model dir: {model_dir}")
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
||||
print("Model loaded successfully.")
|
||||
return self.voxcpm_model
|
||||
|
||||
# ---------- Functional endpoints ----------
|
||||
def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str:
|
||||
if prompt_wav is None:
|
||||
return ""
|
||||
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
||||
text = res[0]["text"].split('|>')[-1]
|
||||
return text
|
||||
|
||||
def generate_tts_audio(
|
||||
self,
|
||||
text_input: str,
|
||||
prompt_wav_path_input: Optional[str] = None,
|
||||
prompt_text_input: Optional[str] = None,
|
||||
cfg_value_input: float = 2.0,
|
||||
inference_timesteps_input: int = 10,
|
||||
do_normalize: bool = True,
|
||||
denoise: bool = True,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
"""
|
||||
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
|
||||
Returns (sample_rate, waveform_numpy)
|
||||
"""
|
||||
current_model = self.get_or_load_voxcpm()
|
||||
|
||||
text = (text_input or "").strip()
|
||||
if len(text) == 0:
|
||||
raise ValueError("Please input text to synthesize.")
|
||||
|
||||
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
|
||||
prompt_text = prompt_text_input if prompt_text_input else None
|
||||
|
||||
print(f"Generating audio for text: '{text[:60]}...'")
|
||||
wav = current_model.generate(
|
||||
text=text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
cfg_value=float(cfg_value_input),
|
||||
inference_timesteps=int(inference_timesteps_input),
|
||||
normalize=do_normalize,
|
||||
denoise=denoise,
|
||||
)
|
||||
return (16000, wav)
|
||||
|
||||
|
||||
# ---------- UI Builders ----------
|
||||
|
||||
def create_demo_interface(demo: VoxCPMDemo):
|
||||
"""Build the Gradio UI for VoxCPM demo."""
|
||||
# static assets (logo path)
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
|
||||
|
||||
with gr.Blocks(
|
||||
theme=gr.themes.Soft(
|
||||
primary_hue="blue",
|
||||
secondary_hue="gray",
|
||||
neutral_hue="slate",
|
||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
|
||||
),
|
||||
css="""
|
||||
.logo-container {
|
||||
text-align: center;
|
||||
margin: 0.5rem 0 1rem 0;
|
||||
}
|
||||
.logo-container img {
|
||||
height: 80px;
|
||||
width: auto;
|
||||
max-width: 200px;
|
||||
display: inline-block;
|
||||
}
|
||||
/* Bold accordion labels */
|
||||
#acc_quick details > summary,
|
||||
#acc_tips details > summary {
|
||||
font-weight: 600 !important;
|
||||
font-size: 1.1em !important;
|
||||
}
|
||||
/* Bold labels for specific checkboxes */
|
||||
#chk_denoise label,
|
||||
#chk_denoise span,
|
||||
#chk_normalize label,
|
||||
#chk_normalize span {
|
||||
font-weight: 600;
|
||||
}
|
||||
"""
|
||||
) as interface:
|
||||
# Header logo
|
||||
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
|
||||
|
||||
# Quick Start
|
||||
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
|
||||
gr.Markdown("""
|
||||
### How to Use |使用说明
|
||||
1. **(Optional) Provide a Voice Prompt** - Upload or record an audio clip to provide the desired voice characteristics for synthesis.
|
||||
**(可选)提供参考声音** - 上传或录制一段音频,为声音合成提供音色、语调和情感等个性化特征
|
||||
2. **(Optional) Enter prompt text** - If you provided a voice prompt, enter the corresponding transcript here (auto-recognition available).
|
||||
**(可选项)输入参考文本** - 如果提供了参考语音,请输入其对应的文本内容(支持自动识别)。
|
||||
3. **Enter target text** - Type the text you want the model to speak.
|
||||
**输入目标文本** - 输入您希望模型朗读的文字内容。
|
||||
4. **Generate Speech** - Click the "Generate" button to create your audio.
|
||||
**生成语音** - 点击"生成"按钮,即可为您创造出音频。
|
||||
""")
|
||||
|
||||
# Pro Tips
|
||||
with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
|
||||
gr.Markdown(f"""
|
||||
### Prompt Speech Enhancement|参考语音降噪
|
||||
- **Enable** to remove background noise for a clean, studio-like voice, with an external ZipEnhancer component.
|
||||
**启用**:通过 ZipEnhancer 组件消除背景噪音,获得更好的音质。
|
||||
- **Disable** to preserve the original audio's background atmosphere.
|
||||
**禁用**:保留原始音频的背景环境声,如果想复刻相应声学环境。
|
||||
|
||||
### Text Normalization|文本正则化
|
||||
- **Enable** to process general text with an external WeTextProcessing component.
|
||||
**启用**:使用 WeTextProcessing 组件,可处理常见文本。
|
||||
- **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input ({HH AH0 L OW1}), try it!
|
||||
**禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如 {da4}{jia1}好)和公式符号合成,尝试一下!
|
||||
|
||||
### CFG Value|CFG 值
|
||||
- **Lower CFG** if the voice prompt sounds strained or expressive.
|
||||
**调低**:如果提示语音听起来不自然或过于夸张。
|
||||
- **Higher CFG** for better adherence to the prompt speech style or input text.
|
||||
**调高**:为更好地贴合提示音频的风格或输入文本。
|
||||
|
||||
### Inference Timesteps|推理时间步
|
||||
- **Lower** for faster synthesis speed.
|
||||
**调低**:合成速度更快。
|
||||
- **Higher** for better synthesis quality.
|
||||
**调高**:合成质量更佳。
|
||||
|
||||
### Long Text (e.g., >5 min speech)|长文本 (如 >5分钟的合成语音)
|
||||
While VoxCPM can handle long texts directly, we recommend using empty lines to break very long content into paragraphs; the model will then synthesize each paragraph individually.
|
||||
虽然 VoxCPM 支持直接生成长文本,但如果目标文本过长,我们建议使用换行符将内容分段;模型将对每个段落分别合成。
|
||||
""")
|
||||
|
||||
# Main controls
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
prompt_wav = gr.Audio(
|
||||
sources=["upload", 'microphone'],
|
||||
type="filepath",
|
||||
label="Prompt Speech",
|
||||
value="./examples/example.wav",
|
||||
)
|
||||
DoDenoisePromptAudio = gr.Checkbox(
|
||||
value=False,
|
||||
label="Prompt Speech Enhancement",
|
||||
elem_id="chk_denoise",
|
||||
info="We use ZipEnhancer model to denoise the prompt audio."
|
||||
)
|
||||
with gr.Row():
|
||||
prompt_text = gr.Textbox(
|
||||
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
|
||||
label="Prompt Text",
|
||||
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
|
||||
)
|
||||
run_btn = gr.Button("Generate Speech", variant="primary")
|
||||
|
||||
with gr.Column():
|
||||
cfg_value = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=3.0,
|
||||
value=2.0,
|
||||
step=0.1,
|
||||
label="CFG Value (Guidance Scale)",
|
||||
info="Higher values increase adherence to prompt, lower values allow more creativity"
|
||||
)
|
||||
inference_timesteps = gr.Slider(
|
||||
minimum=4,
|
||||
maximum=30,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Inference Timesteps",
|
||||
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
|
||||
)
|
||||
with gr.Row():
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
info="Default processing splits text on \\n into paragraphs; each is synthesized as a chunk and then concatenated into the final audio."
|
||||
)
|
||||
with gr.Row():
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="We use WeTextPorcessing library to normalize the input text."
|
||||
)
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
# Wiring
|
||||
run_btn.click(
|
||||
fn=demo.generate_tts_audio,
|
||||
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
|
||||
outputs=[audio_output],
|
||||
show_progress=True,
|
||||
api_name="generate",
|
||||
)
|
||||
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
|
||||
demo = VoxCPMDemo()
|
||||
interface = create_demo_interface(demo)
|
||||
# Recommended to enable queue on Spaces for better throughput
|
||||
interface.queue(max_size=10).launch(server_name=server_name, server_port=server_port, show_error=show_error)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_demo()
|
||||
BIN
assets/modelbest_logo.png
Normal file
BIN
assets/modelbest_logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 54 KiB |
BIN
assets/thuhcsi_logo.png
Normal file
BIN
assets/thuhcsi_logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.9 KiB |
BIN
assets/voxcpm_logo.png
Normal file
BIN
assets/voxcpm_logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 24 KiB |
BIN
assets/voxcpm_model.png
Normal file
BIN
assets/voxcpm_model.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 142 KiB |
BIN
examples/example.wav
Normal file
BIN
examples/example.wav
Normal file
Binary file not shown.
96
pyproject.toml
Normal file
96
pyproject.toml
Normal file
@@ -0,0 +1,96 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "voxcpm"
|
||||
dynamic = ["version"]
|
||||
description = "A Python package for VoxCPM"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
authors = [
|
||||
{name = "OpenBMB", email = "openbmb@gmail.com"}
|
||||
]
|
||||
maintainers = [
|
||||
{name = "OpenBMB", email = "openbmb@gmail.com"}
|
||||
]
|
||||
keywords = ["voxcpm", "text-to-speech", "tts", "speech-synthesis", "voice-cloning", "ai", "deep-learning", "pytorch"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
]
|
||||
requires-python = ">=3.8"
|
||||
dependencies = [
|
||||
"torch==2.5.1",
|
||||
"torchaudio==2.5.1",
|
||||
"transformers==4.50.1",
|
||||
"einops",
|
||||
"gradio",
|
||||
"inflect",
|
||||
"WeTextProcessing",
|
||||
"addict",
|
||||
"modelscope==1.22.0",
|
||||
"simplejson",
|
||||
"datasets==2.18.0",
|
||||
"sortedcontainers",
|
||||
"librosa",
|
||||
"huggingface-hub",
|
||||
"pydantic",
|
||||
"tqdm",
|
||||
"soundfile",
|
||||
"funasr",
|
||||
"spaces"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=6.0",
|
||||
"pytest-cov>=2.0",
|
||||
"black>=21.0",
|
||||
"flake8>=3.8",
|
||||
"mypy>=0.800",
|
||||
"pre-commit>=2.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
voxcpm = "voxcpm.cli:main"
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/OpenBMB/VoxCPM"
|
||||
Repository = "https://github.com/OpenBMB/VoxCPM.git"
|
||||
Documentation = "https://github.com/OpenBMB/VoxCPM#readme"
|
||||
"Bug Tracker" = "https://github.com/OpenBMB/VoxCPM/issues"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
include = ["voxcpm"]
|
||||
|
||||
[tool.setuptools.package-dir]
|
||||
"" = "src"
|
||||
|
||||
[tool.setuptools_scm]
|
||||
version_scheme = "post-release"
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ['py38']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
/(
|
||||
# directories
|
||||
\.eggs
|
||||
| \.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| build
|
||||
| dist
|
||||
)/
|
||||
'''
|
||||
16
requirements.txt
Normal file
16
requirements.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
torch==2.5.1
|
||||
torchaudio==2.5.1
|
||||
transformers==4.50.1
|
||||
einops
|
||||
gradio
|
||||
inflect
|
||||
WeTextProcessing
|
||||
addicts
|
||||
modelscope==1.22.0
|
||||
simplejson
|
||||
datasets==2.18.0
|
||||
addicts
|
||||
sortedcontainers
|
||||
librosa
|
||||
huggingface-hub
|
||||
spaces
|
||||
5
src/voxcpm/__init__.py
Normal file
5
src/voxcpm/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .core import VoxCPM
|
||||
|
||||
__all__ = [
|
||||
"VoxCPM",
|
||||
]
|
||||
292
src/voxcpm/cli.py
Normal file
292
src/voxcpm/cli.py
Normal file
@@ -0,0 +1,292 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
VoxCPM Command Line Interface
|
||||
|
||||
Unified CLI for voice cloning, direct TTS synthesis, and batch processing.
|
||||
|
||||
Usage examples:
|
||||
# Direct synthesis (single sample)
|
||||
voxcpm --text "Hello world" --output output.wav
|
||||
|
||||
# Voice cloning (with reference audio and text)
|
||||
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output output.wav --denoise
|
||||
|
||||
# Batch processing (each line in the file is one sample)
|
||||
voxcpm --input texts.txt --output-dir ./outputs/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
import soundfile as sf
|
||||
|
||||
from voxcpm.core import VoxCPM
|
||||
|
||||
|
||||
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
||||
"""Validate that a file exists."""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"{file_type} '{file_path}' does not exist")
|
||||
return path
|
||||
|
||||
|
||||
def validate_output_path(output_path: str) -> Path:
|
||||
"""Validate the output path and create parent directories if needed."""
|
||||
path = Path(output_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def load_model(args) -> VoxCPM:
|
||||
"""Load VoxCPM model.
|
||||
|
||||
Prefer --model-path if provided; otherwise use from_pretrained (Hub).
|
||||
"""
|
||||
print("Loading VoxCPM model...")
|
||||
|
||||
# 兼容旧参数:ZIPENHANCER_MODEL_PATH 环境变量作为默认
|
||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
||||
"ZIPENHANCER_MODEL_PATH", None
|
||||
)
|
||||
|
||||
# Load from local path if provided
|
||||
if getattr(args, "model_path", None):
|
||||
try:
|
||||
model = VoxCPM(
|
||||
voxcpm_model_path=args.model_path,
|
||||
zipenhancer_model_path=zipenhancer_path,
|
||||
enable_denoiser=not getattr(args, "no_denoiser", False),
|
||||
)
|
||||
print("Model loaded (local).")
|
||||
return model
|
||||
except Exception as e:
|
||||
print(f"Failed to load model (local): {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Otherwise, try from_pretrained (Hub); exit on failure
|
||||
try:
|
||||
model = VoxCPM.from_pretrained(
|
||||
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM-0.5B"),
|
||||
load_denoiser=not getattr(args, "no_denoiser", False),
|
||||
zipenhancer_model_id=zipenhancer_path,
|
||||
cache_dir=getattr(args, "cache_dir", None),
|
||||
local_files_only=getattr(args, "local_files_only", False),
|
||||
)
|
||||
print("Model loaded (from_pretrained).")
|
||||
return model
|
||||
except Exception as e:
|
||||
print(f"Failed to load model (from_pretrained): {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def cmd_clone(args):
|
||||
"""Voice cloning command."""
|
||||
# Validate inputs
|
||||
if not args.text:
|
||||
print("Error: Please provide text to synthesize (--text)")
|
||||
sys.exit(1)
|
||||
|
||||
if not args.prompt_audio:
|
||||
print("Error: Voice cloning requires a reference audio (--prompt-audio)")
|
||||
sys.exit(1)
|
||||
|
||||
if not args.prompt_text:
|
||||
print("Error: Voice cloning requires a reference text (--prompt-text)")
|
||||
sys.exit(1)
|
||||
|
||||
# Validate files
|
||||
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
||||
output_path = validate_output_path(args.output)
|
||||
|
||||
# Load model
|
||||
model = load_model(args)
|
||||
|
||||
# Generate audio
|
||||
print(f"Synthesizing text: {args.text}")
|
||||
print(f"Reference audio: {prompt_audio_path}")
|
||||
print(f"Reference text: {args.prompt_text}")
|
||||
|
||||
audio_array = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=str(prompt_audio_path),
|
||||
prompt_text=args.prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
denoise=args.denoise
|
||||
)
|
||||
|
||||
# Save audio
|
||||
sf.write(str(output_path), audio_array, 16000)
|
||||
print(f"Saved audio to: {output_path}")
|
||||
|
||||
# Stats
|
||||
duration = len(audio_array) / 16000
|
||||
print(f"Duration: {duration:.2f}s")
|
||||
|
||||
|
||||
def cmd_synthesize(args):
|
||||
"""Direct TTS synthesis command."""
|
||||
# Validate inputs
|
||||
if not args.text:
|
||||
print("Error: Please provide text to synthesize (--text)")
|
||||
sys.exit(1)
|
||||
# Validate output path
|
||||
output_path = validate_output_path(args.output)
|
||||
# Load model
|
||||
model = load_model(args)
|
||||
# Generate audio
|
||||
print(f"Synthesizing text: {args.text}")
|
||||
|
||||
audio_array = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=None,
|
||||
prompt_text=None,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
denoise=False # 无参考音频时不需要降噪
|
||||
)
|
||||
|
||||
# Save audio
|
||||
sf.write(str(output_path), audio_array, 16000)
|
||||
print(f"Saved audio to: {output_path}")
|
||||
|
||||
# Stats
|
||||
duration = len(audio_array) / 16000
|
||||
print(f"Duration: {duration:.2f}s")
|
||||
|
||||
|
||||
def cmd_batch(args):
|
||||
"""Batch synthesis command."""
|
||||
# Validate input file
|
||||
input_file = validate_file_exists(args.input, "input file")
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
texts = [line.strip() for line in f if line.strip()]
|
||||
except Exception as e:
|
||||
print(f"Failed to read input file: {e}")
|
||||
sys.exit(1)
|
||||
if not texts:
|
||||
print("Error: Input file is empty or contains no valid lines")
|
||||
sys.exit(1)
|
||||
print(f"Found {len(texts)} lines to process")
|
||||
|
||||
model = load_model(args)
|
||||
prompt_audio_path = None
|
||||
if args.prompt_audio:
|
||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
||||
|
||||
success_count = 0
|
||||
for i, text in enumerate(texts, 1):
|
||||
print(f"\nProcessing {i}/{len(texts)}: {text[:50]}...")
|
||||
|
||||
try:
|
||||
audio_array = model.generate(
|
||||
text=text,
|
||||
prompt_wav_path=prompt_audio_path,
|
||||
prompt_text=args.prompt_text,
|
||||
cfg_value=args.cfg_value,
|
||||
inference_timesteps=args.inference_timesteps,
|
||||
normalize=args.normalize,
|
||||
denoise=args.denoise and prompt_audio_path is not None
|
||||
)
|
||||
output_file = output_dir / f"output_{i:03d}.wav"
|
||||
sf.write(str(output_file), audio_array, 16000)
|
||||
|
||||
duration = len(audio_array) / 16000
|
||||
print(f" Saved: {output_file} ({duration:.2f}s)")
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f" Failed: {e}")
|
||||
continue
|
||||
|
||||
print(f"\nBatch finished: {success_count}/{len(texts)} succeeded")
|
||||
|
||||
def _build_unified_parser():
|
||||
"""Build unified argument parser (no subcommands, route by args)."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="VoxCPM CLI (single parser) - voice cloning, direct TTS, and batch processing",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Direct synthesis (single sample)
|
||||
voxcpm --text "Hello world" --output out.wav
|
||||
|
||||
# Voice cloning (reference audio + text)
|
||||
voxcpm --text "Hello world" --prompt-audio voice.wav --prompt-text "reference text" --output out.wav --denoise
|
||||
|
||||
# Batch processing
|
||||
voxcpm --input texts.txt --output-dir ./outs
|
||||
|
||||
# Select model (from Hub)
|
||||
voxcpm --text "Hello" --output out.wav --hf-model-id openbmb/VoxCPM-0.5B
|
||||
"""
|
||||
)
|
||||
|
||||
# Task selection (automatic routing by presence of args)
|
||||
parser.add_argument("--input", "-i", help="Input text file (one line per sample)")
|
||||
parser.add_argument("--output-dir", "-od", help="Output directory (for batch mode)")
|
||||
parser.add_argument("--text", "-t", help="Text to synthesize (single-sample mode)")
|
||||
parser.add_argument("--output", "-o", help="Output audio file path (single-sample mode)")
|
||||
|
||||
# Prompt audio (for voice cloning)
|
||||
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
|
||||
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
|
||||
|
||||
# Generation parameters
|
||||
parser.add_argument("--cfg-value", type=float, default=2.0, help="CFG guidance scale (default: 2.0)")
|
||||
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (default: 10)")
|
||||
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
||||
|
||||
# Model loading parameters
|
||||
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
|
||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM-0.5B", help="Hugging Face repo id (e.g., openbmb/VoxCPM-0.5B)")
|
||||
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
||||
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
|
||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main():
|
||||
"""Unified CLI entrypoint: route by provided arguments."""
|
||||
parser = _build_unified_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Routing: prefer batch → single (clone/direct)
|
||||
if args.input:
|
||||
if not args.output_dir:
|
||||
print("Error: Batch mode requires --output-dir")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
return cmd_batch(args)
|
||||
|
||||
# Single-sample mode
|
||||
if not args.text or not args.output:
|
||||
print("Error: Single-sample mode requires --text and --output")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
# If prompt audio+text provided → voice cloning
|
||||
if args.prompt_audio or args.prompt_text:
|
||||
if not args.prompt_audio or not args.prompt_text:
|
||||
print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
||||
sys.exit(1)
|
||||
return cmd_clone(args)
|
||||
|
||||
# Otherwise → direct synthesis
|
||||
return cmd_synthesize(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
181
src/voxcpm/core.py
Normal file
181
src/voxcpm/core.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import torch
|
||||
import torchaudio
|
||||
import os
|
||||
import tempfile
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from huggingface_hub import snapshot_download
|
||||
from .model.voxcpm import VoxCPMModel
|
||||
from .utils.text_normalize import TextNormalizer
|
||||
|
||||
|
||||
class VoxCPM:
|
||||
def __init__(self,
|
||||
voxcpm_model_path : str,
|
||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser : bool = True,
|
||||
):
|
||||
"""Initialize VoxCPM TTS pipeline.
|
||||
|
||||
Args:
|
||||
voxcpm_model_path: Local filesystem path to the VoxCPM model assets
|
||||
(weights, configs, etc.). Typically the directory returned by
|
||||
a prior download step.
|
||||
zipenhancer_model_path: ModelScope acoustic noise suppression model
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
"""
|
||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
|
||||
self.text_normalizer = TextNormalizer()
|
||||
if enable_denoiser and zipenhancer_model_path is not None:
|
||||
self.denoiser = pipeline(
|
||||
Tasks.acoustic_noise_suppression,
|
||||
model=zipenhancer_model_path)
|
||||
else:
|
||||
self.denoiser = None
|
||||
print("Warm up VoxCPMModel...")
|
||||
self.tts_model.generate(
|
||||
target_text="Hello, this is the first test sentence."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
hf_model_id: str = "openbmb/VoxCPM-0.5B",
|
||||
load_denoiser: bool = True,
|
||||
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
):
|
||||
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
||||
|
||||
Args:
|
||||
hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo").
|
||||
load_denoiser: Whether to initialize the denoiser pipeline.
|
||||
zipenhancer_model_id: Denoiser model id or path for ModelScope
|
||||
acoustic noise suppression.
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
|
||||
Returns:
|
||||
VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
|
||||
the downloaded snapshot directory.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither a valid ``hf_model_id`` nor a resolvable
|
||||
``hf_model_id`` is provided.
|
||||
"""
|
||||
repo_id = hf_model_id
|
||||
if not repo_id or repo_id.strip() == "":
|
||||
raise ValueError("You must provide a valid hf_model_id")
|
||||
|
||||
local_path = snapshot_download(
|
||||
repo_id=repo_id,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
return cls(
|
||||
voxcpm_model_path=local_path,
|
||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||
enable_denoiser=load_denoiser,
|
||||
)
|
||||
|
||||
def _normalize_loudness(self, wav_path: str):
|
||||
audio, sr = torchaudio.load(wav_path)
|
||||
loudness = torchaudio.functional.loudness(audio, sr)
|
||||
normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
|
||||
torchaudio.save(wav_path, normalized_audio, sr)
|
||||
|
||||
def generate(self,
|
||||
text : str,
|
||||
prompt_wav_path : str = None,
|
||||
prompt_text : str = None,
|
||||
cfg_value : float = 2.0,
|
||||
inference_timesteps : int = 10,
|
||||
max_length : int = 4096,
|
||||
normalize : bool = True,
|
||||
denoise : bool = True,
|
||||
retry_badcase : bool = True,
|
||||
retry_badcase_max_times : int = 3,
|
||||
retry_badcase_ratio_threshold : float = 6.0,
|
||||
):
|
||||
"""Synthesize speech for the given text and return a single waveform.
|
||||
|
||||
This method optionally builds and reuses a prompt cache. If an external
|
||||
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
|
||||
used for all sub-sentences. Otherwise, the prompt cache is built from
|
||||
the first generated result and reused for the remaining text chunks.
|
||||
|
||||
Args:
|
||||
text: Input text. Can include newlines; each non-empty line is
|
||||
treated as a sub-sentence.
|
||||
prompt_wav_path: Path to a reference audio file for prompting.
|
||||
prompt_text: Text content corresponding to the prompt audio.
|
||||
cfg_value: Guidance scale for the generation model.
|
||||
inference_timesteps: Number of inference steps.
|
||||
max_length: Maximum token length during generation.
|
||||
normalize: Whether to run text normalization before generation.
|
||||
denoise: Whether to denoise the prompt audio if a denoiser is
|
||||
available.
|
||||
retry_badcase: Whether to retry badcase.
|
||||
retry_badcase_max_times: Maximum number of times to retry badcase.
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
|
||||
Returns:
|
||||
numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||
"""
|
||||
texts = text.split("\n")
|
||||
texts = [t.strip() for t in texts if t.strip()]
|
||||
final_wav = []
|
||||
temp_prompt_wav_path = None
|
||||
|
||||
try:
|
||||
if prompt_wav_path is not None and prompt_text is not None:
|
||||
if denoise and self.denoiser is not None:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
||||
temp_prompt_wav_path = tmp_file.name
|
||||
|
||||
self.denoiser(prompt_wav_path, output_path=temp_prompt_wav_path)
|
||||
self._normalize_loudness(temp_prompt_wav_path)
|
||||
prompt_wav_path = temp_prompt_wav_path
|
||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
prompt_text=prompt_text
|
||||
)
|
||||
else:
|
||||
fixed_prompt_cache = None # will be built from the first inference
|
||||
|
||||
for sub_text in texts:
|
||||
if sub_text.strip() == "":
|
||||
continue
|
||||
print("sub_text:", sub_text)
|
||||
if normalize:
|
||||
sub_text = self.text_normalizer.normalize(sub_text)
|
||||
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
|
||||
target_text=sub_text,
|
||||
prompt_cache=fixed_prompt_cache,
|
||||
min_len=2,
|
||||
max_len=max_length,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
retry_badcase=retry_badcase,
|
||||
retry_badcase_max_times=retry_badcase_max_times,
|
||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||
)
|
||||
if fixed_prompt_cache is None:
|
||||
fixed_prompt_cache = self.tts_model.merge_prompt_cache(
|
||||
original_cache=None,
|
||||
new_text_token=target_text_token,
|
||||
new_audio_feat=generated_audio_feat
|
||||
)
|
||||
final_wav.append(wav)
|
||||
|
||||
return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
|
||||
|
||||
finally:
|
||||
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
|
||||
try:
|
||||
os.unlink(temp_prompt_wav_path)
|
||||
except OSError:
|
||||
pass
|
||||
3
src/voxcpm/model/__init__.py
Normal file
3
src/voxcpm/model/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .voxcpm import VoxCPMModel
|
||||
|
||||
__all__ = ["VoxCPMModel"]
|
||||
122
src/voxcpm/model/utils.py
Normal file
122
src/voxcpm/model/utils.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from typing import List
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
||||
"""Create a tokenizer wrapper that converts multi-character Chinese tokens to single characters.
|
||||
|
||||
This function creates a wrapper around the provided tokenizer that automatically
|
||||
splits multi-character Chinese tokens into individual characters. This is useful
|
||||
for ensuring consistent tokenization of Chinese text.
|
||||
|
||||
Args:
|
||||
tokenizer: The base tokenizer to wrap
|
||||
|
||||
Returns:
|
||||
A CharTokenizerWrapper instance that handles multi-character Chinese tokens
|
||||
|
||||
Example:
|
||||
>>> from transformers import LlamaTokenizerFast
|
||||
>>> tokenizer = LlamaTokenizerFast.from_pretrained("path/to/tokenizer")
|
||||
>>> wrapped_tokenizer = mask_multichar_chinese_tokens(tokenizer)
|
||||
>>> tokens = wrapped_tokenizer("你好世界")
|
||||
"""
|
||||
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
|
||||
multichar_tokens = {
|
||||
token for token in tokenizer.vocab.keys()
|
||||
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
||||
}
|
||||
|
||||
class CharTokenizerWrapper:
|
||||
"""Wrapper class for tokenizers that handles multi-character Chinese tokens.
|
||||
|
||||
This wrapper automatically splits multi-character Chinese tokens into
|
||||
individual characters while preserving the original tokenizer's interface.
|
||||
"""
|
||||
|
||||
def __init__(self, base_tokenizer: PreTrainedTokenizer) -> None:
|
||||
"""Initialize the wrapper with a base tokenizer.
|
||||
|
||||
Args:
|
||||
base_tokenizer: The tokenizer to wrap
|
||||
"""
|
||||
self.tokenizer = base_tokenizer
|
||||
self.multichar_tokens = multichar_tokens
|
||||
|
||||
def tokenize(self, text: str, **kwargs) -> List[str]:
|
||||
"""Tokenize text and split multi-character Chinese tokens into single characters.
|
||||
|
||||
Args:
|
||||
text: Input text to tokenize
|
||||
**kwargs: Additional arguments passed to the base tokenizer
|
||||
|
||||
Returns:
|
||||
List of processed tokens with multi-character Chinese tokens split
|
||||
|
||||
Example:
|
||||
>>> wrapper = CharTokenizerWrapper(tokenizer)
|
||||
>>> tokens = wrapper.tokenize("你好世界")
|
||||
>>> # Returns ["你", "好", "世", "界"] instead of ["你好", "世界"]
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
raise TypeError(f"Expected string input, got {type(text)}")
|
||||
|
||||
tokens = self.tokenizer.tokenize(text, **kwargs)
|
||||
processed = []
|
||||
|
||||
for token in tokens:
|
||||
# Remove possible subword prefix
|
||||
clean_token = token.replace("▁", "")
|
||||
|
||||
if clean_token in self.multichar_tokens:
|
||||
# Split multi-character token into single characters
|
||||
chars = list(clean_token)
|
||||
processed.extend(chars)
|
||||
else:
|
||||
processed.append(token)
|
||||
|
||||
return processed
|
||||
|
||||
def __call__(self, text: str, **kwargs) -> List[int]:
|
||||
"""Call the tokenizer and return token IDs.
|
||||
|
||||
This method provides the same interface as the original tokenizer
|
||||
but with multi-character Chinese token handling.
|
||||
|
||||
Args:
|
||||
text: Input text to tokenize
|
||||
**kwargs: Additional arguments passed to the base tokenizer
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
|
||||
Raises:
|
||||
TypeError: If input is not a string
|
||||
ValueError: If tokenization fails
|
||||
"""
|
||||
try:
|
||||
tokens = self.tokenize(text, **kwargs)
|
||||
result = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise ValueError(f"Tokenization failed: {str(e)}") from e
|
||||
|
||||
return CharTokenizerWrapper(tokenizer)
|
||||
|
||||
|
||||
def get_dtype(dtype: str):
|
||||
if dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
elif dtype == "bf16":
|
||||
return torch.bfloat16
|
||||
elif dtype == "float16":
|
||||
return torch.float16
|
||||
elif dtype == "fp16":
|
||||
return torch.float16
|
||||
elif dtype == "float32":
|
||||
return torch.float32
|
||||
elif dtype == "fp32":
|
||||
return torch.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {dtype}")
|
||||
605
src/voxcpm/model/voxcpm.py
Normal file
605
src/voxcpm/model/voxcpm.py
Normal file
@@ -0,0 +1,605 @@
|
||||
"""
|
||||
VoxCPM: A Tokenizer-free speech generation model
|
||||
|
||||
This module contains the main VoxCPM model implementation, including configuration classes
|
||||
and the core VoxCPMModel for text-to-speech generation.
|
||||
|
||||
Copyright 2025 OpenBMB
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from einops import rearrange
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from ..modules.layers import ScalarQuantizationLayer
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||
|
||||
|
||||
class VoxCPMEncoderConfig(BaseModel):
|
||||
hidden_dim: int = 1024
|
||||
ffn_dim: int = 4096
|
||||
num_heads: int = 16
|
||||
num_layers: int = 4
|
||||
kv_channels: int = None
|
||||
|
||||
|
||||
class VoxCPMDitConfig(BaseModel):
|
||||
hidden_dim: int = 1024
|
||||
ffn_dim: int = 4096
|
||||
num_heads: int = 16
|
||||
num_layers: int = 4
|
||||
kv_channels: int = None
|
||||
|
||||
cfm_config: CfmConfig
|
||||
|
||||
|
||||
class VoxCPMConfig(BaseModel):
|
||||
lm_config: MiniCPM4Config
|
||||
patch_size: int = 2
|
||||
feat_dim: int = 64
|
||||
residual_lm_num_layers: int = 6
|
||||
scalar_quantization_latent_dim: int = 256
|
||||
scalar_quantization_scale: int = 9
|
||||
|
||||
encoder_config: VoxCPMEncoderConfig
|
||||
dit_config: VoxCPMDitConfig
|
||||
|
||||
max_length: int = 4096
|
||||
device: str = "cuda"
|
||||
dtype: str = "bfloat16"
|
||||
|
||||
|
||||
class VoxCPMModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: VoxCPMConfig,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAE,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
if not torch.cuda.is_available():
|
||||
self.device = "cpu"
|
||||
|
||||
# Text-Semantic LM
|
||||
self.base_lm = MiniCPMModel(config.lm_config)
|
||||
self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
|
||||
|
||||
self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
|
||||
self.audio_start_token = 101
|
||||
self.audio_end_token = 102
|
||||
|
||||
# Residual Acoustic LM
|
||||
residual_lm_config = config.lm_config.model_copy(deep=True)
|
||||
residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
|
||||
residual_lm_config.vocab_size = 0
|
||||
self.residual_lm = MiniCPMModel(residual_lm_config)
|
||||
self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
|
||||
|
||||
# Local Encoder
|
||||
encoder_config = config.lm_config.model_copy(deep=True)
|
||||
encoder_config.hidden_size = config.encoder_config.hidden_dim
|
||||
encoder_config.intermediate_size = config.encoder_config.ffn_dim
|
||||
encoder_config.num_attention_heads = config.encoder_config.num_heads
|
||||
encoder_config.num_hidden_layers = config.encoder_config.num_layers
|
||||
encoder_config.kv_channels = config.encoder_config.kv_channels
|
||||
encoder_config.vocab_size = 0
|
||||
self.feat_encoder = VoxCPMLocEnc(encoder_config, input_dim=config.feat_dim)
|
||||
|
||||
# Local DiT
|
||||
decoder_config = config.lm_config.model_copy(deep=True)
|
||||
decoder_config.hidden_size = config.dit_config.hidden_dim
|
||||
decoder_config.intermediate_size = config.dit_config.ffn_dim
|
||||
decoder_config.num_attention_heads = config.dit_config.num_heads
|
||||
decoder_config.num_hidden_layers = config.dit_config.num_layers
|
||||
decoder_config.kv_channels = config.dit_config.kv_channels
|
||||
decoder_config.vocab_size = 0
|
||||
self.feat_decoder = UnifiedCFM(
|
||||
in_channels=config.feat_dim,
|
||||
cfm_params=config.dit_config.cfm_config,
|
||||
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
|
||||
)
|
||||
|
||||
# Projection layers
|
||||
self.fsq_layer = ScalarQuantizationLayer(
|
||||
config.lm_config.hidden_size,
|
||||
config.lm_config.hidden_size,
|
||||
config.scalar_quantization_latent_dim,
|
||||
config.scalar_quantization_scale
|
||||
)
|
||||
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
|
||||
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||
self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||
|
||||
# Stop Predictor
|
||||
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
|
||||
self.stop_actn = nn.SiLU()
|
||||
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
|
||||
|
||||
# Audio VAE
|
||||
self.audio_vae = audio_vae
|
||||
self.chunk_size = audio_vae.chunk_size
|
||||
self.sample_rate = audio_vae.sample_rate
|
||||
|
||||
|
||||
def optimize(self):
|
||||
if self.device == "cuda":
|
||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||
else:
|
||||
self.base_lm.forward_step = self.base_lm.forward_step
|
||||
self.residual_lm.forward_step = self.residual_lm.forward_step
|
||||
self.feat_encoder_step = self.feat_encoder
|
||||
self.feat_decoder.estimator = self.feat_decoder.estimator
|
||||
return self
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
target_text: str,
|
||||
prompt_text: str = "",
|
||||
prompt_wav_path: str = "",
|
||||
min_len: int = 2,
|
||||
max_len: int = 2000,
|
||||
inference_timesteps: int = 10,
|
||||
cfg_value: float = 2.0,
|
||||
retry_badcase: bool = False,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
|
||||
):
|
||||
if len(prompt_wav_path) == 0:
|
||||
text = target_text
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
text_token = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
torch.tensor(
|
||||
[self.audio_start_token],
|
||||
dtype=torch.int32,
|
||||
device=text_token.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
text_length = text_token.shape[0]
|
||||
|
||||
audio_feat = torch.zeros(
|
||||
(text_length, self.patch_size, self.audio_vae.latent_dim),
|
||||
dtype=torch.float32,
|
||||
device=text_token.device,
|
||||
)
|
||||
text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
|
||||
audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)
|
||||
|
||||
else:
|
||||
text = prompt_text + target_text
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
text_token = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
text_length = text_token.shape[0]
|
||||
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
|
||||
# (B, D, T)
|
||||
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
|
||||
|
||||
audio_feat = audio_feat.view(
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
audio_length = audio_feat.size(0)
|
||||
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
audio_pad_feat = torch.zeros(
|
||||
(text_length, self.patch_size, self.audio_vae.latent_dim),
|
||||
dtype=torch.float32,
|
||||
device=text_token.device,
|
||||
)
|
||||
audio_feat = torch.cat([audio_pad_feat, audio_feat], dim=0)
|
||||
text_mask = (
|
||||
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
audio_mask = (
|
||||
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
)
|
||||
|
||||
text_token = text_token.unsqueeze(0).to(self.device)
|
||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
|
||||
retry_badcase_times = 0
|
||||
while retry_badcase_times < retry_badcase_max_times:
|
||||
latent_pred, pred_audio_feat = self.inference(
|
||||
text_token,
|
||||
text_mask,
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
|
||||
@torch.inference_mode()
|
||||
def build_prompt_cache(
|
||||
self,
|
||||
prompt_text: str,
|
||||
prompt_wav_path: str,
|
||||
):
|
||||
"""
|
||||
Build prompt cache for subsequent fast generation.
|
||||
|
||||
Args:
|
||||
prompt_text: prompt text (required)
|
||||
prompt_wav_path: prompt audio path (required)
|
||||
|
||||
Returns:
|
||||
prompt_cache: dict with text tokens and audio features
|
||||
"""
|
||||
if not prompt_text or not prompt_wav_path:
|
||||
raise ValueError("prompt_text and prompt_wav_path are required")
|
||||
|
||||
# build text tokens
|
||||
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
|
||||
|
||||
# load audio
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
|
||||
# extract audio features
|
||||
audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
|
||||
|
||||
audio_feat = audio_feat.view(
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0) # (D, T, P)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
# build prompt cache
|
||||
prompt_cache = {
|
||||
"text_token": text_token,
|
||||
"audio_feat": audio_feat,
|
||||
}
|
||||
|
||||
return prompt_cache
|
||||
|
||||
|
||||
def merge_prompt_cache(
|
||||
self,
|
||||
original_cache: dict,
|
||||
new_text_token: torch.Tensor,
|
||||
new_audio_feat: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Merge original prompt cache with newly generated content to stabilize voice.
|
||||
|
||||
Args:
|
||||
original_cache: original prompt cache
|
||||
new_text_token: newly generated text tokens
|
||||
new_audio_feat: newly generated audio features
|
||||
|
||||
Returns:
|
||||
merged_cache: merged cache
|
||||
"""
|
||||
if original_cache is None:
|
||||
return {
|
||||
"text_token": new_text_token,
|
||||
"audio_feat": new_audio_feat,
|
||||
}
|
||||
original_text_token = original_cache["text_token"]
|
||||
original_audio_feat = original_cache["audio_feat"]
|
||||
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
|
||||
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
|
||||
|
||||
# build new cache
|
||||
merged_cache = {
|
||||
"text_token": merged_text_token,
|
||||
"audio_feat": merged_audio_feat,
|
||||
}
|
||||
|
||||
return merged_cache
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate_with_prompt_cache(
|
||||
self,
|
||||
target_text: str,
|
||||
prompt_cache: dict,
|
||||
min_len: int = 2,
|
||||
max_len: int = 2000,
|
||||
inference_timesteps: int = 10,
|
||||
cfg_value: float = 2.0,
|
||||
retry_badcase: bool = False,
|
||||
retry_badcase_max_times: int = 3,
|
||||
retry_badcase_ratio_threshold: float = 6.0,
|
||||
):
|
||||
"""
|
||||
Generate audio using pre-built prompt cache.
|
||||
|
||||
Args:
|
||||
target_text: Text to convert to speech
|
||||
prompt_cache: Cache built by build_prompt_cache (can be None)
|
||||
min_len: Minimum audio length to avoid very short audio
|
||||
max_len: Maximum audio length
|
||||
inference_timesteps: Number of diffusion sampling steps
|
||||
cfg_value: Classifier-free guidance value
|
||||
retry_badcase: Whether to retry on bad cases
|
||||
retry_badcase_max_times: Maximum retry attempts
|
||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
|
||||
|
||||
Returns:
|
||||
tuple: (decoded audio tensor, new text tokens, new audio features)
|
||||
"""
|
||||
# get prompt from cache
|
||||
if prompt_cache is None:
|
||||
prompt_text_token = torch.empty(0, dtype=torch.int32)
|
||||
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
|
||||
else:
|
||||
prompt_text_token = prompt_cache["text_token"]
|
||||
prompt_audio_feat = prompt_cache["audio_feat"]
|
||||
# build target text tokens
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
|
||||
text_token = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
torch.tensor(
|
||||
[self.audio_start_token],
|
||||
dtype=torch.int32,
|
||||
device=text_token.device,
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
audio_length = prompt_audio_feat.size(0)
|
||||
text_length = text_token.shape[0]
|
||||
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
|
||||
audio_pad_feat = torch.zeros(
|
||||
(text_token.shape[0], self.patch_size, self.audio_vae.latent_dim),
|
||||
dtype=torch.float32,
|
||||
device=text_token.device,
|
||||
)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
|
||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||
|
||||
text_token = text_token.unsqueeze(0).to(self.device)
|
||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||
audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
|
||||
audio_mask = audio_mask.unsqueeze(0).to(self.device)
|
||||
|
||||
# run inference
|
||||
target_text_length = len(self.text_tokenizer(target_text))
|
||||
retry_badcase_times = 0
|
||||
while retry_badcase_times < retry_badcase_max_times:
|
||||
latent_pred, pred_audio_feat = self.inference(
|
||||
text_token,
|
||||
text_mask,
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
)
|
||||
if retry_badcase:
|
||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
|
||||
retry_badcase_times += 1
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
break
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
|
||||
return (
|
||||
decode_audio,
|
||||
target_text_token,
|
||||
pred_audio_feat
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_mask: torch.Tensor,
|
||||
feat: torch.Tensor,
|
||||
feat_mask: torch.Tensor,
|
||||
min_len: int = 2,
|
||||
max_len: int = 2000,
|
||||
inference_timesteps: int = 10,
|
||||
cfg_value: float = 2.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Core inference method for audio generation.
|
||||
|
||||
This is the main inference loop that generates audio features
|
||||
using the language model and diffusion transformer.
|
||||
|
||||
Args:
|
||||
text: Input text tokens
|
||||
text_mask: Mask for text tokens
|
||||
feat: Input audio features
|
||||
feat_mask: Mask for audio features
|
||||
min_len: Minimum generation length
|
||||
max_len: Maximum generation length
|
||||
inference_timesteps: Number of diffusion steps
|
||||
cfg_value: Classifier-free guidance value
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- Predicted latent features
|
||||
- Predicted audio feature sequence
|
||||
"""
|
||||
B, T, P, D = feat.shape
|
||||
|
||||
feat_embed = self.feat_encoder(feat) # [b, t, h_feat]
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
if self.config.lm_config.use_mup:
|
||||
scale_emb = self.config.lm_config.scale_emb
|
||||
else:
|
||||
scale_emb = 1.0
|
||||
|
||||
text_embed = self.base_lm.embed_tokens(text) * scale_emb
|
||||
combined_embed = text_mask.unsqueeze(-1) * text_embed + feat_mask.unsqueeze(-1) * feat_embed
|
||||
|
||||
prefix_feat_cond = feat[:, -1, ...] # b, p, d
|
||||
pred_feat_seq = [] # b, t, p, d
|
||||
curr_embed = None
|
||||
|
||||
enc_outputs, kv_cache_tuple = self.base_lm(
|
||||
inputs_embeds=combined_embed,
|
||||
is_causal=True,
|
||||
)
|
||||
self.base_lm.kv_cache.fill_caches(kv_cache_tuple)
|
||||
|
||||
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
|
||||
lm_hidden = enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
|
||||
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
|
||||
is_causal=True,
|
||||
)
|
||||
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
|
||||
residual_hidden = residual_enc_outputs[:, -1, :]
|
||||
|
||||
|
||||
for i in tqdm(range(max_len)):
|
||||
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
|
||||
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
|
||||
dit_hidden = dit_hidden_1 + dit_hidden_2 # [b, h_dit]
|
||||
|
||||
pred_feat = self.feat_decoder(
|
||||
mu=dit_hidden,
|
||||
patch_size=self.patch_size,
|
||||
cond=prefix_feat_cond.transpose(1, 2).contiguous(),
|
||||
n_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
).transpose(
|
||||
1, 2
|
||||
) # [b, p, d]
|
||||
|
||||
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.enc_to_lm_proj(curr_embed)
|
||||
|
||||
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
|
||||
prefix_feat_cond = pred_feat
|
||||
|
||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||
if i > min_len and stop_flag == 1:
|
||||
break
|
||||
|
||||
lm_hidden = self.base_lm.forward_step(
|
||||
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
|
||||
).clone()
|
||||
|
||||
|
||||
lm_hidden = self.fsq_layer(lm_hidden)
|
||||
residual_hidden = self.residual_lm.forward_step(
|
||||
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
|
||||
).clone()
|
||||
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
feat_pred = feat_pred[..., 1:-1] # trick: remove the first and last token
|
||||
return feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
|
||||
audio_vae = AudioVAE()
|
||||
vae_state_dict = torch.load(
|
||||
os.path.join(path, "audiovae.pth"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
|
||||
model = cls(config, tokenizer, audio_vae)
|
||||
lm_dtype = get_dtype(config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
model.audio_vae = model.audio_vae.to(torch.float32)
|
||||
|
||||
model_state_dict = torch.load(
|
||||
os.path.join(path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
|
||||
for kw, val in vae_state_dict.items():
|
||||
model_state_dict[f"audio_vae.{kw}"] = val
|
||||
model.load_state_dict(model_state_dict, strict=True)
|
||||
return model.to(model.device).eval().optimize()
|
||||
0
src/voxcpm/modules/__init__.py
Normal file
0
src/voxcpm/modules/__init__.py
Normal file
1
src/voxcpm/modules/audiovae/__init__.py
Normal file
1
src/voxcpm/modules/audiovae/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .audio_vae import AudioVAE
|
||||
359
src/voxcpm/modules/audiovae/audio_vae.py
Normal file
359
src/voxcpm/modules/audiovae/audio_vae.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, **kwargs):
|
||||
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||
|
||||
|
||||
class CausalConv1d(nn.Conv1d):
|
||||
def __init__(self, *args, padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
|
||||
def forward(self, x):
|
||||
x_pad = F.pad(x, (self.__padding * 2, 0))
|
||||
return super().forward(x_pad)
|
||||
|
||||
|
||||
class CausalTransposeConv1d(nn.ConvTranspose1d):
|
||||
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.__padding = padding
|
||||
self.__output_padding = output_padding
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
|
||||
|
||||
|
||||
def WNCausalConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
def WNCausalTransposeConv1d(*args, **kwargs):
|
||||
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
|
||||
|
||||
|
||||
# Scripting this brings model speed up 1.4x
|
||||
@torch.jit.script
|
||||
def snake(x, alpha):
|
||||
shape = x.shape
|
||||
x = x.reshape(shape[0], shape[1], -1)
|
||||
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
||||
x = x.reshape(shape)
|
||||
return x
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return snake(x, self.alpha)
|
||||
|
||||
|
||||
def init_weights(m):
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
class CausalResidualUnit(nn.Module):
|
||||
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
|
||||
super().__init__()
|
||||
pad = ((7 - 1) * dilation) // 2
|
||||
self.block = nn.Sequential(
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel,
|
||||
dilation=dilation,
|
||||
padding=pad,
|
||||
groups=groups,
|
||||
),
|
||||
Snake1d(dim),
|
||||
WNCausalConv1d(dim, dim, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.block(x)
|
||||
pad = (x.shape[-1] - y.shape[-1]) // 2
|
||||
assert pad == 0
|
||||
if pad > 0:
|
||||
x = x[..., pad:-pad]
|
||||
return x + y
|
||||
|
||||
|
||||
class CausalEncoderBlock(nn.Module):
|
||||
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
|
||||
super().__init__()
|
||||
input_dim = input_dim or output_dim // 2
|
||||
self.block = nn.Sequential(
|
||||
CausalResidualUnit(input_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(input_dim, dilation=9, groups=groups),
|
||||
Snake1d(input_dim),
|
||||
WNCausalConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class CausalEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 64,
|
||||
latent_dim: int = 32,
|
||||
strides: list = [2, 4, 8, 8],
|
||||
depthwise: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
# Create first convolution
|
||||
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
|
||||
|
||||
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||
for stride in strides:
|
||||
d_model *= 2
|
||||
groups = d_model // 2 if depthwise else 1
|
||||
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
|
||||
|
||||
groups = d_model if depthwise else 1
|
||||
|
||||
# Create two convolution, for mu and logvar
|
||||
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||
|
||||
# Wrap black into nn.Sequential
|
||||
self.block = nn.Sequential(*self.block)
|
||||
self.enc_dim = d_model
|
||||
|
||||
def forward(self, x):
|
||||
hidden_state = self.block(x)
|
||||
return {
|
||||
"hidden_state": hidden_state,
|
||||
"mu": self.fc_mu(hidden_state),
|
||||
"logvar": self.fc_logvar(hidden_state),
|
||||
}
|
||||
|
||||
|
||||
class NoiseBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, T = x.shape
|
||||
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
|
||||
h = self.linear(x)
|
||||
n = noise * h
|
||||
x = x + n
|
||||
return x
|
||||
|
||||
|
||||
class CausalDecoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int = 16,
|
||||
output_dim: int = 8,
|
||||
stride: int = 1,
|
||||
groups=1,
|
||||
use_noise_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
layers = [
|
||||
Snake1d(input_dim),
|
||||
WNCausalTransposeConv1d(
|
||||
input_dim,
|
||||
output_dim,
|
||||
kernel_size=2 * stride,
|
||||
stride=stride,
|
||||
padding=math.ceil(stride / 2),
|
||||
output_padding=stride % 2,
|
||||
),
|
||||
]
|
||||
if use_noise_block:
|
||||
layers.append(NoiseBlock(output_dim))
|
||||
layers.extend(
|
||||
[
|
||||
CausalResidualUnit(output_dim, dilation=1, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=3, groups=groups),
|
||||
CausalResidualUnit(output_dim, dilation=9, groups=groups),
|
||||
]
|
||||
)
|
||||
self.block = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class TransposeLastTwoDim(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.transpose(x, -1, -2)
|
||||
|
||||
|
||||
class CausalDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_channel,
|
||||
channels,
|
||||
rates,
|
||||
depthwise: bool = False,
|
||||
d_out: int = 1,
|
||||
use_noise_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add first conv layer
|
||||
if depthwise:
|
||||
layers = [
|
||||
WNCausalConv1d(
|
||||
input_channel,
|
||||
input_channel,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
groups=input_channel,
|
||||
),
|
||||
WNCausalConv1d(input_channel, channels, kernel_size=1),
|
||||
]
|
||||
else:
|
||||
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
||||
|
||||
# Add upsampling + MRF blocks
|
||||
for i, stride in enumerate(rates):
|
||||
input_dim = channels // 2**i
|
||||
output_dim = channels // 2 ** (i + 1)
|
||||
groups = output_dim if depthwise else 1
|
||||
layers += [
|
||||
CausalDecoderBlock(
|
||||
input_dim,
|
||||
output_dim,
|
||||
stride,
|
||||
groups=groups,
|
||||
use_noise_block=use_noise_block,
|
||||
)
|
||||
]
|
||||
|
||||
# Add final conv layer
|
||||
layers += [
|
||||
Snake1d(output_dim),
|
||||
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
||||
nn.Tanh(),
|
||||
]
|
||||
|
||||
self.model = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class AudioVAE(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int = 128,
|
||||
encoder_rates: List[int] = [2, 5, 8, 8],
|
||||
latent_dim: int = 64,
|
||||
decoder_dim: int = 1536,
|
||||
decoder_rates: List[int] = [8, 8, 5, 2],
|
||||
depthwise: bool = True,
|
||||
sample_rate: int = 16000,
|
||||
use_noise_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
self.decoder_rates = decoder_rates
|
||||
self.depthwise = depthwise
|
||||
|
||||
self.use_noise_block = use_noise_block
|
||||
|
||||
if latent_dim is None:
|
||||
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.hop_length = np.prod(encoder_rates)
|
||||
self.encoder = CausalEncoder(
|
||||
encoder_dim,
|
||||
latent_dim,
|
||||
encoder_rates,
|
||||
depthwise=depthwise,
|
||||
)
|
||||
|
||||
self.decoder = CausalDecoder(
|
||||
latent_dim,
|
||||
decoder_dim,
|
||||
decoder_rates,
|
||||
depthwise=depthwise,
|
||||
use_noise_block=use_noise_block,
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
self.chunk_size = math.prod(encoder_rates)
|
||||
|
||||
def preprocess(self, audio_data, sample_rate):
|
||||
if sample_rate is None:
|
||||
sample_rate = self.sample_rate
|
||||
assert sample_rate == self.sample_rate
|
||||
pad_to = self.hop_length
|
||||
length = audio_data.shape[-1]
|
||||
right_pad = math.ceil(length / pad_to) * pad_to - length
|
||||
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
||||
|
||||
return audio_data
|
||||
|
||||
def decode(self, z: torch.Tensor):
|
||||
"""Decode given latent codes and return audio data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor[B x D x T]
|
||||
Quantized continuous representation of input
|
||||
length : int, optional
|
||||
Number of samples in output audio, by default None
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A dictionary with the following keys:
|
||||
"audio" : Tensor[B x 1 x length]
|
||||
Decoded audio data.
|
||||
"""
|
||||
return self.decoder(z)
|
||||
|
||||
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
||||
"""
|
||||
Args:
|
||||
audio_data: Tensor[B x 1 x T]
|
||||
sample_rate: int
|
||||
Returns:
|
||||
z: Tensor[B x D x T]
|
||||
"""
|
||||
if audio_data.ndim == 2:
|
||||
audio_data = audio_data.unsqueeze(1)
|
||||
|
||||
audio_data = self.preprocess(audio_data, sample_rate)
|
||||
return self.encoder(audio_data)["mu"]
|
||||
1
src/voxcpm/modules/layers/__init__.py
Normal file
1
src/voxcpm/modules/layers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .scalar_quantization_layer import ScalarQuantizationLayer
|
||||
26
src/voxcpm/modules/layers/scalar_quantization_layer.py
Normal file
26
src/voxcpm/modules/layers/scalar_quantization_layer.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ScalarQuantizationLayer(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
self.latent_dim = latent_dim
|
||||
self.scale = scale
|
||||
|
||||
self.in_proj = nn.Linear(in_dim, latent_dim)
|
||||
self.out_proj = nn.Linear(latent_dim, out_dim)
|
||||
|
||||
def forward(self, hidden):
|
||||
hidden = self.in_proj(hidden)
|
||||
hidden = torch.tanh(hidden)
|
||||
|
||||
if self.training:
|
||||
quantized = torch.round(hidden * self.scale) / self.scale
|
||||
hidden = hidden + (quantized - hidden).detach()
|
||||
else:
|
||||
hidden = torch.round(hidden * self.scale) / self.scale
|
||||
|
||||
return self.out_proj(hidden)
|
||||
2
src/voxcpm/modules/locdit/__init__.py
Normal file
2
src/voxcpm/modules/locdit/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .unified_cfm import UnifiedCFM, CfmConfig
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
114
src/voxcpm/modules/locdit/local_dit.py
Normal file
114
src/voxcpm/modules/locdit/local_dit.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
|
||||
class SinusoidalPosEmb(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
||||
|
||||
def forward(self, x, scale=1000):
|
||||
if x.ndim < 1:
|
||||
x = x.unsqueeze(0)
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
|
||||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
out_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
|
||||
self.act = nn.SiLU()
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
sample = self.act(sample)
|
||||
sample = self.linear_2(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class VoxCPMLocDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MiniCPM4Config,
|
||||
in_channels: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.config = config
|
||||
|
||||
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
|
||||
|
||||
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=config.hidden_size,
|
||||
time_embed_dim=config.hidden_size,
|
||||
)
|
||||
self.delta_time_mlp = TimestepEmbedding(
|
||||
in_channels=config.hidden_size,
|
||||
time_embed_dim=config.hidden_size,
|
||||
)
|
||||
|
||||
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
|
||||
self.decoder = MiniCPMModel(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mu: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
cond: torch.Tensor,
|
||||
dt: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, T) tensor of inputs
|
||||
mu: (N, C) tensor of hidden embedding
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
cond: (N, C, T') tensor of prefix conditions
|
||||
dt: (N,) used for mean velocity (may be supported in the future...)
|
||||
"""
|
||||
x = self.in_proj(x.transpose(1, 2).contiguous())
|
||||
|
||||
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
|
||||
prefix = cond.size(1)
|
||||
|
||||
t = self.time_embeddings(t).to(x.dtype)
|
||||
t = self.time_mlp(t)
|
||||
dt = self.time_embeddings(dt).to(x.dtype)
|
||||
dt = self.delta_time_mlp(dt)
|
||||
t = t + dt
|
||||
|
||||
x = torch.cat([(mu + t).unsqueeze(1), cond, x], dim=1)
|
||||
hidden, _ = self.decoder(x, is_causal=False)
|
||||
hidden = hidden[:, prefix + 1 :, :]
|
||||
hidden = self.out_proj(hidden)
|
||||
|
||||
return hidden.transpose(1, 2).contiguous()
|
||||
137
src/voxcpm/modules/locdit/unified_cfm.py
Normal file
137
src/voxcpm/modules/locdit/unified_cfm.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import torch
|
||||
from typing import List
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
import math
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CfmConfig(BaseModel):
|
||||
sigma_min: float = 1e-06
|
||||
solver: str = "euler"
|
||||
t_scheduler: str = "log-norm"
|
||||
|
||||
|
||||
class UnifiedCFM(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
cfm_params: CfmConfig,
|
||||
estimator: VoxCPMLocDiT,
|
||||
mean_mode: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.solver = cfm_params.solver
|
||||
self.sigma_min = cfm_params.sigma_min
|
||||
self.t_scheduler = cfm_params.t_scheduler
|
||||
self.in_channels = in_channels
|
||||
self.mean_mode = mean_mode
|
||||
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
mu: torch.Tensor,
|
||||
n_timesteps: int,
|
||||
patch_size: int,
|
||||
cond: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
cfg_value: float = 1.0,
|
||||
sway_sampling_coef: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats)
|
||||
n_timesteps (int): number of diffusion steps
|
||||
cond: Not used but kept for future purposes
|
||||
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
b, c = mu.shape
|
||||
t = patch_size
|
||||
z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
|
||||
|
||||
t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
# Sway sampling strategy
|
||||
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
|
||||
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)
|
||||
|
||||
def optimized_scale(self, positive_flat, negative_flat):
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
|
||||
st_star = dot_product / squared_norm
|
||||
return st_star
|
||||
|
||||
def solve_euler(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t_span: torch.Tensor,
|
||||
mu: torch.Tensor,
|
||||
cond: torch.Tensor,
|
||||
cfg_value: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
x (torch.Tensor): random noise
|
||||
t_span (torch.Tensor): n_timesteps interpolated
|
||||
shape: (n_timesteps + 1,)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats)
|
||||
cond: Not used but kept for future purposes
|
||||
cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
|
||||
|
||||
sol = []
|
||||
zero_init_steps = max(1, int(len(t_span) * 0.04))
|
||||
for step in range(1, len(t_span)):
|
||||
if use_cfg_zero_star and step <= zero_init_steps:
|
||||
dphi_dt = 0.
|
||||
else:
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
b = x.size(0)
|
||||
x_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
|
||||
dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
x_in[:b], x_in[b:] = x, x
|
||||
mu_in[:b] = mu
|
||||
t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
|
||||
dt_in[:b], dt_in[b:] = dt.unsqueeze(0), dt.unsqueeze(0)
|
||||
# not used now
|
||||
if not self.mean_mode:
|
||||
dt_in = torch.zeros_like(dt_in)
|
||||
cond_in[:b], cond_in[b:] = cond, cond
|
||||
|
||||
dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
|
||||
if use_cfg_zero_star:
|
||||
positive_flat = dphi_dt.view(b, -1)
|
||||
negative_flat = cfg_dphi_dt.view(b, -1)
|
||||
st_star = self.optimized_scale(positive_flat, negative_flat)
|
||||
st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
|
||||
else:
|
||||
st_star = 1.0
|
||||
|
||||
dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)
|
||||
|
||||
x = x - dt * dphi_dt
|
||||
t = t - dt
|
||||
sol.append(x)
|
||||
if step < len(t_span) - 1:
|
||||
dt = t - t_span[step + 1]
|
||||
|
||||
return sol[-1]
|
||||
1
src/voxcpm/modules/locenc/__init__.py
Normal file
1
src/voxcpm/modules/locenc/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .local_encoder import VoxCPMLocEnc
|
||||
30
src/voxcpm/modules/locenc/local_encoder.py
Normal file
30
src/voxcpm/modules/locenc/local_encoder.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class VoxCPMLocEnc(nn.Module):
|
||||
def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
|
||||
self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)
|
||||
|
||||
assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
|
||||
self.encoder = MiniCPMModel(config)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: [B, T, P, D]
|
||||
"""
|
||||
B, T, P, D = x.shape
|
||||
|
||||
x = self.in_proj(x)
|
||||
special_tokens = self.special_token.expand(B, T, 1, -1)
|
||||
x = torch.cat([special_tokens, x], dim=2)
|
||||
x = rearrange(x, "b t p c -> (b t) p c")
|
||||
outputs, _ = self.encoder(x, is_causal=False)
|
||||
cls_output = outputs[:, 0, :]
|
||||
|
||||
return rearrange(cls_output, "(b t) c -> b t c", b=B)
|
||||
3
src/voxcpm/modules/minicpm4/__init__.py
Normal file
3
src/voxcpm/modules/minicpm4/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .config import MiniCPM4Config
|
||||
from .model import MiniCPMModel
|
||||
from .cache import StaticKVCache
|
||||
47
src/voxcpm/modules/minicpm4/cache.py
Normal file
47
src/voxcpm/modules/minicpm4/cache.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import List, Tuple
|
||||
import torch
|
||||
|
||||
|
||||
class StaticKVCache:
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
num_kv_heads: int,
|
||||
dim_kv_head: int,
|
||||
batch_size: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
max_length: int = 8192,
|
||||
):
|
||||
self.max_length = max_length
|
||||
self.num_layers = num_layers
|
||||
|
||||
self.kv_cache = torch.zeros(
|
||||
2,
|
||||
num_layers,
|
||||
batch_size,
|
||||
num_kv_heads,
|
||||
max_length,
|
||||
dim_kv_head,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.current_length = 0
|
||||
|
||||
def get_layer_cache(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.kv_cache[0, layer_idx], self.kv_cache[1, layer_idx]
|
||||
|
||||
def step(self) -> int:
|
||||
if self.current_length >= self.max_length:
|
||||
raise ValueError("KV cache is full")
|
||||
|
||||
ret = self.current_length
|
||||
self.current_length += 1
|
||||
return ret
|
||||
|
||||
def fill_caches(self, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]]):
|
||||
self.current_length = kv_caches[0][0].size(2)
|
||||
self.kv_cache.zero_()
|
||||
for i in range(self.num_layers):
|
||||
self.kv_cache[0, i, :, :, : self.current_length, :] = kv_caches[i][0]
|
||||
self.kv_cache[1, i, :, :, : self.current_length, :] = kv_caches[i][1]
|
||||
29
src/voxcpm/modules/minicpm4/config.py
Normal file
29
src/voxcpm/modules/minicpm4/config.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
|
||||
class RopeScalingConfig(BaseModel):
|
||||
type: str
|
||||
long_factor: List[float]
|
||||
short_factor: List[float]
|
||||
original_max_position_embeddings: int
|
||||
|
||||
|
||||
class MiniCPM4Config(BaseModel):
|
||||
bos_token_id: int
|
||||
eos_token_id: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
max_position_embeddings: int
|
||||
num_attention_heads: int
|
||||
num_hidden_layers: int
|
||||
num_key_value_heads: int
|
||||
rms_norm_eps: float
|
||||
rope_scaling: RopeScalingConfig
|
||||
vocab_size: int
|
||||
use_mup: bool = True
|
||||
scale_emb: float
|
||||
dim_model_base: int
|
||||
scale_depth: float
|
||||
rope_theta: float
|
||||
kv_channels: int = None
|
||||
411
src/voxcpm/modules/minicpm4/model.py
Normal file
411
src/voxcpm/modules/minicpm4/model.py
Normal file
@@ -0,0 +1,411 @@
|
||||
from .config import MiniCPM4Config
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List, Tuple
|
||||
import math
|
||||
from .cache import StaticKVCache
|
||||
|
||||
|
||||
def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
|
||||
old_dtype = hidden.dtype
|
||||
variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
|
||||
hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
|
||||
return hidden * weight
|
||||
|
||||
|
||||
class MiniCPMRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
MiniCPMRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
q: Tensor(batch_size, num_heads, seq_len, head_dim)
|
||||
k: Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
|
||||
cos: Tensor(seq_len, head_dim)
|
||||
sin: Tensor(seq_len, head_dim)
|
||||
Returns:
|
||||
Tensor(batch_size, num_heads, seq_len, head_dim), Tensor(batch_size, num_key_value_heads, seq_len, head_dim)
|
||||
"""
|
||||
orig_dtype = q.dtype
|
||||
q = q.to(torch.float32)
|
||||
k = k.to(torch.float32)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed.to(orig_dtype), k_embed.to(orig_dtype)
|
||||
|
||||
|
||||
class MiniCPMLongRoPE(nn.Module):
|
||||
"""MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||
|
||||
def __init__(self, config: MiniCPM4Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dim = config.kv_channels if config.kv_channels else config.hidden_size // config.num_attention_heads
|
||||
self.base = config.rope_theta
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
self.short_factor = config.rope_scaling.short_factor
|
||||
self.long_factor = config.rope_scaling.long_factor
|
||||
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
|
||||
|
||||
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
|
||||
self.scaling_factor = math.sqrt(
|
||||
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
|
||||
)
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
self.max_seq_len_cached = 0
|
||||
|
||||
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
|
||||
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
|
||||
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=self.max_position_embeddings,
|
||||
device=self.inv_freq.device,
|
||||
dtype=torch.float32
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
"""设置cos和sin缓存"""
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
||||
|
||||
if seq_len > self.original_max_position_embeddings:
|
||||
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=device)
|
||||
else:
|
||||
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
|
||||
|
||||
freqs = torch.mul(
|
||||
torch.outer(t, 1.0 / ext_factors).to(device=device),
|
||||
self.inv_freq.to(device=device).to(dtype)
|
||||
)
|
||||
|
||||
# 创建embeddings
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
self.cos_cached = emb.cos().to(dtype) * self.scaling_factor
|
||||
self.sin_cached = emb.sin().to(dtype) * self.scaling_factor
|
||||
|
||||
def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
position_ids: Tensor(seq_len) 或 Tensor(batch_size, seq_len)
|
||||
Returns:
|
||||
Tensor(seq_len, head_dim), Tensor(seq_len, head_dim)
|
||||
"""
|
||||
cos = self.cos_cached[position_ids]
|
||||
sin = self.sin_cached[position_ids]
|
||||
|
||||
return cos, sin
|
||||
|
||||
|
||||
class MiniCPMAttention(nn.Module):
|
||||
def __init__(self, config: MiniCPM4Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = 10000.0
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
||||
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_emb: Tuple[torch.Tensor, torch.Tensor],
|
||||
is_causal: bool,
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_emb
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
is_causal=is_causal,
|
||||
enable_gqa=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
past_key_value = (key_states, value_states)
|
||||
return attn_output, past_key_value
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_emb: Tuple[torch.Tensor, torch.Tensor],
|
||||
position_id: int,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
bsz, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, 1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, 1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = position_emb
|
||||
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
key_cache, value_cache = kv_cache
|
||||
|
||||
key_cache[:, :, position_id, :] = key_states
|
||||
value_cache[:, :, position_id, :] = value_states
|
||||
|
||||
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_mask=attn_mask,
|
||||
enable_gqa=True,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, self.num_heads * self.head_dim)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class MiniCPMMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class MiniCPMDecoderLayer(nn.Module):
|
||||
def __init__(self, config: MiniCPM4Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = MiniCPMAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = MiniCPMMLP(config)
|
||||
self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.scale_depth = config.scale_depth
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.use_mup = config.use_mup
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_emb: Tuple[torch.Tensor, torch.Tensor],
|
||||
is_causal: bool,
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
position_ids (`torch.LongTensor`): position ids of shape `(batch_size, seq_len)`
|
||||
is_causal (`bool`): whether the attention mask is causal
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_emb=position_emb,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
if self.use_mup:
|
||||
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
|
||||
else:
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
if self.use_mup:
|
||||
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
|
||||
else:
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states, present_key_value
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_emb: Tuple[torch.Tensor, torch.Tensor],
|
||||
position_id: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn.forward_step(
|
||||
hidden_states=hidden_states,
|
||||
position_emb=position_emb,
|
||||
position_id=position_id,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
|
||||
if self.use_mup:
|
||||
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
|
||||
else:
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
if self.use_mup:
|
||||
hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers))
|
||||
else:
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MiniCPMModel(nn.Module):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: MiniCPMConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: MiniCPM4Config):
|
||||
super().__init__()
|
||||
self.vocab_size = config.vocab_size
|
||||
self.config = config
|
||||
|
||||
if config.vocab_size > 0:
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
else:
|
||||
self.embed_tokens = nn.Identity()
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rope_emb = MiniCPMLongRoPE(config)
|
||||
|
||||
self.kv_cache = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
is_causal: bool = True,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
"""
|
||||
Args:
|
||||
inputs_embeds: Tensor(batch_size, seq_length, hidden_size)
|
||||
is_causal: bool, whether the attention mask is causal
|
||||
Returns:
|
||||
hidden_states: Tensor(batch_size, seq_length, hidden_size)
|
||||
next_decoder_cache: List[(batch_size, num_heads, seq_length, head_dim), (batch_size, num_heads, seq_length, head_dim)]
|
||||
"""
|
||||
position_ids = torch.arange(0, inputs_embeds.size(1), dtype=torch.long, device=inputs_embeds.device)
|
||||
position_emb = self.rope_emb(position_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
next_decoder_cache = []
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
|
||||
hidden_states, this_cache = decoder_layer(
|
||||
hidden_states,
|
||||
position_emb,
|
||||
is_causal,
|
||||
)
|
||||
next_decoder_cache.append(this_cache)
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states, next_decoder_cache
|
||||
|
||||
def forward_step(
|
||||
self,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_id: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
inputs_embeds: Tensor(batch_size, hidden_size)
|
||||
Returns:
|
||||
hidden_states: Tensor(batch_size, hidden_size)
|
||||
"""
|
||||
assert self.kv_cache is not None, "KV cache is not setup"
|
||||
|
||||
position_emb = self.rope_emb(position_id)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for i, decoder_layer in enumerate(self.layers):
|
||||
hidden_states = decoder_layer.forward_step(
|
||||
hidden_states,
|
||||
position_emb,
|
||||
position_id,
|
||||
self.kv_cache.get_layer_cache(i),
|
||||
)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def setup_cache(self, batch_size: int, max_length: int, device, dtype: torch.dtype):
|
||||
self.kv_cache = StaticKVCache(
|
||||
num_layers=self.config.num_hidden_layers,
|
||||
num_kv_heads=self.config.num_key_value_heads,
|
||||
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
|
||||
batch_size=batch_size,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
max_length=max_length,
|
||||
)
|
||||
244
src/voxcpm/utils/text_normalize.py
Normal file
244
src/voxcpm/utils/text_normalize.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# some functions are copied from https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/utils/frontend_utils.py
|
||||
import re
|
||||
import regex
|
||||
import inflect
|
||||
from functools import partial
|
||||
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
||||
from tn.english.normalizer import Normalizer as EnNormalizer
|
||||
|
||||
def normal_cut_sentence(text):
|
||||
# 先处理括号内的逗号,将其替换为特殊标记
|
||||
text = re.sub(r'([((][^))]*)([,,])([^))]*[))])', r'\1&&&\3', text)
|
||||
text = re.sub('([。!,?\?])([^’”])',r'\1\n\2',text)#普通断句符号且后面没有引号
|
||||
text = re.sub('(\.{6})([^’”])',r'\1\n\2',text)#英文省略号且后面没有引号
|
||||
text = re.sub('(\…{2})([^’”])',r'\1\n\2',text)#中文省略号且后面没有引号
|
||||
text = re.sub('([. ,。!;?\?\.{6}\…{2}][’”])([^’”])',r'\1\n\2',text)#断句号+引号且后面没有引号
|
||||
# 处理英文句子的分隔
|
||||
text = re.sub(r'([.,!?])([^’”\'"])', r'\1\n\2', text) # 句号、感叹号、问号后面没有引号
|
||||
text = re.sub(r'([.!?][’”\'"])([^’”\'"])', r'\1\n\2', text) # 句号、感叹号、问号加引号后面的部分
|
||||
text = re.sub(r'([((][^))]*)(&&&)([^))]*[))])', r'\1,\3', text)
|
||||
text = [t for t in text.split("\n") if t]
|
||||
return text
|
||||
|
||||
|
||||
def cut_sentence_with_fix_length(text : str, length : int):
|
||||
sentences = normal_cut_sentence(text)
|
||||
cur_length = 0
|
||||
res = ""
|
||||
for sentence in sentences:
|
||||
if not sentence:
|
||||
continue
|
||||
if cur_length > length or cur_length + len(sentence) > length:
|
||||
yield res
|
||||
res = ""
|
||||
cur_length = 0
|
||||
res += sentence
|
||||
cur_length += len(sentence)
|
||||
if res:
|
||||
yield res
|
||||
|
||||
|
||||
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
||||
|
||||
# whether contain chinese character
|
||||
def contains_chinese(text):
|
||||
return bool(chinese_char_pattern.search(text))
|
||||
|
||||
|
||||
# replace special symbol
|
||||
def replace_corner_mark(text):
|
||||
text = text.replace('²', '平方')
|
||||
text = text.replace('³', '立方')
|
||||
text = text.replace('√', '根号')
|
||||
text = text.replace('≈', '约等于')
|
||||
text = text.replace('<', '小于')
|
||||
return text
|
||||
|
||||
|
||||
# remove meaningless symbol
|
||||
def remove_bracket(text):
|
||||
text = text.replace('(', ' ').replace(')', ' ')
|
||||
text = text.replace('【', ' ').replace('】', ' ')
|
||||
text = text.replace('`', '').replace('`', '')
|
||||
text = text.replace("——", " ")
|
||||
return text
|
||||
|
||||
|
||||
# spell Arabic numerals
|
||||
def spell_out_number(text: str, inflect_parser):
|
||||
new_text = []
|
||||
st = None
|
||||
for i, c in enumerate(text):
|
||||
if not c.isdigit():
|
||||
if st is not None:
|
||||
num_str = inflect_parser.number_to_words(text[st: i])
|
||||
new_text.append(num_str)
|
||||
st = None
|
||||
new_text.append(c)
|
||||
else:
|
||||
if st is None:
|
||||
st = i
|
||||
if st is not None and st < len(text):
|
||||
num_str = inflect_parser.number_to_words(text[st:])
|
||||
new_text.append(num_str)
|
||||
return ''.join(new_text)
|
||||
|
||||
|
||||
# split paragrah logic:
|
||||
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
|
||||
# 2. cal sentence len according to lang
|
||||
# 3. split sentence according to puncatation
|
||||
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
|
||||
def calc_utt_length(_text: str):
|
||||
if lang == "zh":
|
||||
return len(_text)
|
||||
else:
|
||||
return len(tokenize(_text))
|
||||
|
||||
def should_merge(_text: str):
|
||||
if lang == "zh":
|
||||
return len(_text) < merge_len
|
||||
else:
|
||||
return len(tokenize(_text)) < merge_len
|
||||
|
||||
if lang == "zh":
|
||||
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
|
||||
else:
|
||||
pounc = ['.', '?', '!', ';', ':']
|
||||
if comma_split:
|
||||
pounc.extend([',', ','])
|
||||
st = 0
|
||||
utts = []
|
||||
for i, c in enumerate(text):
|
||||
if c in pounc:
|
||||
if len(text[st: i]) > 0:
|
||||
utts.append(text[st: i] + c)
|
||||
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
|
||||
tmp = utts.pop(-1)
|
||||
utts.append(tmp + text[i + 1])
|
||||
st = i + 2
|
||||
else:
|
||||
st = i + 1
|
||||
if len(utts) == 0:
|
||||
if lang == "zh":
|
||||
utts.append(text + '。')
|
||||
else:
|
||||
utts.append(text + '.')
|
||||
final_utts = []
|
||||
cur_utt = ""
|
||||
for utt in utts:
|
||||
if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
|
||||
final_utts.append(cur_utt)
|
||||
cur_utt = ""
|
||||
cur_utt = cur_utt + utt
|
||||
if len(cur_utt) > 0:
|
||||
if should_merge(cur_utt) and len(final_utts) != 0:
|
||||
final_utts[-1] = final_utts[-1] + cur_utt
|
||||
else:
|
||||
final_utts.append(cur_utt)
|
||||
|
||||
return final_utts
|
||||
|
||||
|
||||
# remove blank between chinese character
|
||||
def replace_blank(text: str):
|
||||
out_str = []
|
||||
for i, c in enumerate(text):
|
||||
if c == " ":
|
||||
if ((text[i + 1].isascii() and text[i + 1] != " ") and
|
||||
(text[i - 1].isascii() and text[i - 1] != " ")):
|
||||
out_str.append(c)
|
||||
else:
|
||||
out_str.append(c)
|
||||
return "".join(out_str)
|
||||
|
||||
def clean_markdown(md_text: str) -> str:
|
||||
# 去除代码块 ``` ```(包括多行)
|
||||
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
|
||||
|
||||
# 去除内联代码 `code`
|
||||
md_text = re.sub(r"`[^`]*`", "", md_text)
|
||||
|
||||
# 去除图片语法 
|
||||
md_text = re.sub(r"!\[[^\]]*\]\([^\)]+\)", "", md_text)
|
||||
|
||||
# 去除链接但保留文本 [text](url) -> text
|
||||
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
|
||||
|
||||
# 替换无序列表符号
|
||||
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
|
||||
|
||||
# 去除HTML标签
|
||||
md_text = re.sub(r"<[^>]+>", "", md_text)
|
||||
|
||||
# 去除标题符号(#)
|
||||
md_text = re.sub(r"^#{1,6}\s*", "", md_text, flags=re.MULTILINE)
|
||||
|
||||
# 去除多余空格和空行
|
||||
md_text = re.sub(r"\n\s*\n", "\n", md_text) # 多余空行
|
||||
md_text = md_text.strip()
|
||||
|
||||
return md_text
|
||||
|
||||
|
||||
def clean_text(text):
|
||||
# 去除 Markdown 语法
|
||||
text = clean_markdown(text)
|
||||
# 匹配并移除表情符号
|
||||
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
|
||||
# 去除换行符
|
||||
text = text.replace("\n", " ")
|
||||
text = text.replace("\t", " ")
|
||||
text = text.replace('"', "\“")
|
||||
return text
|
||||
|
||||
class TextNormalizer:
|
||||
def __init__(self, tokenizer=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.zh_tn_model = ZhNormalizer(remove_erhua=False, full_to_half=False, remove_interjections=False, overwrite_cache=True)
|
||||
self.en_tn_model = EnNormalizer()
|
||||
self.inflect_parser = inflect.engine()
|
||||
|
||||
def normalize(self, text, split=False):
|
||||
# 去除 Markdown 语法,去除表情符号,去除换行符
|
||||
lang = "zh" if contains_chinese(text) else "en"
|
||||
text = clean_text(text)
|
||||
if lang == "zh":
|
||||
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
|
||||
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
|
||||
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
|
||||
text = self.zh_tn_model.normalize(text)
|
||||
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
|
||||
text = self.zh_tn_model.normalize(text)
|
||||
text = replace_blank(text)
|
||||
text = replace_corner_mark(text)
|
||||
text = remove_bracket(text)
|
||||
text = re.sub(r'[,,]+$', '。', text)
|
||||
else:
|
||||
text = self.en_tn_model.normalize(text)
|
||||
text = spell_out_number(text, self.inflect_parser)
|
||||
if split is False:
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text_normalizer = TextNormalizer()
|
||||
text = r"""今天我们学习一元二次方程。一元二次方程的标准形式是:
|
||||
ax2+bx+c=0ax^2 + bx + c = 0ax2+bx+c=0
|
||||
其中,aaa、bbb 和 ccc 是常数,xxx 是变量。这个方程的解可以通过求根公式来找到。
|
||||
一元二次方程的解法有几种:
|
||||
- 因式分解法:通过将方程因式分解来求解。我们首先尝试将方程表达成两个括号的形式,解决方程的解。比如,方程x2−5x+6=0x^2 - 5x + 6 = 0x2−5x+6=0可以因式分解为(x−2)(x−3)=0(x - 2)(x - 3) = 0(x−2)(x−3)=0,因此根为2和3。
|
||||
- 配方法:通过配方将方程转化为完全平方的形式,从而解出。我们通过加上或减去适当的常数来完成这一过程,使得方程可以直接写成一个完全平方的形式。
|
||||
- 求根公式:我们可以使用求根公式直接求出方程的解。这个公式适用于所有的一元二次方程,即使我们无法通过因式分解或配方法来解决时,也能使用该公式。
|
||||
公式:x=−b±b2−4ac2ax = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}x=2a−b±b2−4ac这个公式可以帮助我们求解任何一元二次方程的根。
|
||||
对于一元二次方程,我们需要了解判别式。判别式的作用是帮助我们判断方程的解的个数和性质。判别式 Δ\DeltaΔ 由下式给出:Δ=b2−4ac\Delta = b^2 - 4acΔ=b2−4ac 根据判别式的值,我们可以知道:
|
||||
- 如果 Δ>0\Delta > 0Δ>0,方程有两个不相等的实数解。这是因为判别式大于0时,根号内的值是正数,所以我们可以得到两个不同的解。
|
||||
- 如果 Δ=0\Delta = 0Δ=0,方程有一个实数解。这是因为根号内的值为零,导致两个解相等,也就是说方程有一个解。
|
||||
- 如果 Δ<0\Delta < 0Δ<0,方程没有实数解。这意味着根号内的值是负数,无法进行实数运算,因此方程没有实数解,可能有复数解。"""
|
||||
texts = ["这是一个公式 (a+b)³=a³+3a²b+3ab²+b³ S=(a×b)÷2", "这样的发展为AI仅仅作为“工具”这一观点提出了新的挑战,", "550 + 320 = 870千卡。", "解一元二次方程:3x^2+x-2=0", "你好啊"]
|
||||
texts = [text]
|
||||
for text in texts:
|
||||
text = text_normalizer.normalize(text)
|
||||
print(text)
|
||||
for t in cut_sentence_with_fix_length(text, 15):
|
||||
print(t)
|
||||
Reference in New Issue
Block a user