Mirror of https://github.com/OpenBMB/VoxCPM, synced 2025-12-12 11:58:11 +00:00

Compare commits
15 Commits

41752dc0fa
b0714adcaa
89f4d917a0
5c5da0dbe6
5f56d5ff5d
169c17ddfd
996c69a1a8
dc6b6d1d1c
cef6aefb3d
1a46c5d1ad
5257ec3dc5
bdd516b579
11568f0776
e5bcb735f0
1fa9e2ca02
.gitignore (vendored, new file, +3)

```diff
@@ -0,0 +1,3 @@
+launch.json
+__pycache__
+voxcpm.egg-info
```
README.md (25 changed lines)

````diff
@@ -50,7 +50,7 @@ By default, when you first run the script, the model will be downloaded automatically.
 - Download VoxCPM-0.5B
 ```
 from huggingface_hub import snapshot_download
-snapshot_download("openbmb/VoxCPM-0.5B",local_files_only=local_files_only)
+snapshot_download("openbmb/VoxCPM-0.5B")
 ```
 - Download ZipEnhancer and SenseVoice-Small. We use ZipEnhancer to enhance speech prompts and SenseVoice-Small for speech prompt ASR in the web demo.
 ```
@@ -62,10 +62,12 @@ By default, when you first run the script, the model will be downloaded automatically.
 ### 2. Basic Usage
 ```python
 import soundfile as sf
+import numpy as np
 from voxcpm import VoxCPM
 
 model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
 
+# Non-streaming
 wav = model.generate(
     text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
     prompt_wav_path=None, # optional: path to a prompt speech for voice cloning
@@ -81,6 +83,18 @@ wav = model.generate(
 
 sf.write("output.wav", wav, 16000)
 print("saved: output.wav")
+
+# Streaming
+chunks = []
+for chunk in model.generate_streaming(
+    text = "Streaming text to speech is easy with VoxCPM!",
+    # supports same args as above
+):
+    chunks.append(chunk)
+wav = np.concatenate(chunks)
+
+sf.write("output_streaming.wav", wav, 16000)
+print("saved: output_streaming.wav")
 ```
 
 ### 3. CLI Usage
@@ -98,6 +112,13 @@ voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
     --output out.wav \
     --denoise
 
+# (Optional) Voice cloning (reference audio + transcript file)
+voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
+    --prompt-audio path/to/voice.wav \
+    --prompt-file "/path/to/text-file" \
+    --output out.wav \
+    --denoise
+
 # 3) Batch processing (one text per line)
 voxcpm --input examples/input.txt --output-dir outs
 # (optional) Batch + cloning
@@ -267,6 +288,8 @@ This project is developed by the following institutions:
 - <img src="assets/thuhcsi_logo.png" width="28px"> [THUHCSI](https://github.com/thuhcsi)
 
+## ⭐ Star History
+[![Star History Chart](https://api.star-history.com/svg?repos=OpenBMB/VoxCPM&type=Date)](https://star-history.com/#OpenBMB/VoxCPM&Date)
 
 ## 📚 Citation
````
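Since `generate_streaming` yields waveform chunks as they are synthesized, callers do not have to buffer the whole utterance the way the README snippet does. A minimal sketch (assuming 16 kHz mono float32 chunks, as above) that appends each chunk to disk as it arrives:

```python
import soundfile as sf
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")

# Open the target file once and append chunks as they are produced,
# so consumers can start reading audio before synthesis finishes.
with sf.SoundFile("output_streaming.wav", mode="w",
                  samplerate=16000, channels=1, subtype="PCM_16") as f:
    for chunk in model.generate_streaming(
        text="Streaming text to speech is easy with VoxCPM!",
    ):
        f.write(chunk)  # chunk is a 1D float32 numpy array
```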
app.py (9 changed lines)

```diff
@@ -194,10 +194,6 @@ def create_demo_interface(demo: VoxCPMDemo):
     **调低**:合成速度更快。
     - **Higher** for better synthesis quality.
     **调高**:合成质量更佳。
 
-    ### Long Text (e.g., >5 min speech)|长文本 (如 >5分钟的合成语音)
-    While VoxCPM can handle long texts directly, we recommend using empty lines to break very long content into paragraphs; the model will then synthesize each paragraph individually.
-    虽然 VoxCPM 支持直接生成长文本,但如果目标文本过长,我们建议使用换行符将内容分段;模型将对每个段落分别合成。
-
     """)
 
     # Main controls
@@ -206,7 +202,7 @@ def create_demo_interface(demo: VoxCPMDemo):
     prompt_wav = gr.Audio(
         sources=["upload", 'microphone'],
         type="filepath",
-        label="Prompt Speech",
+        label="Prompt Speech (Optional, or let VoxCPM improvise)",
         value="./examples/example.wav",
     )
     DoDenoisePromptAudio = gr.Checkbox(
@@ -244,14 +240,13 @@ def create_demo_interface(demo: VoxCPMDemo):
     text = gr.Textbox(
         value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
         label="Target Text",
-        info="Default processing splits text on \\n into paragraphs; each is synthesized as a chunk and then concatenated into the final audio."
     )
     with gr.Row():
         DoNormalizeText = gr.Checkbox(
             value=False,
             label="Text Normalization",
             elem_id="chk_normalize",
-            info="We use WeTextPorcessing library to normalize the input text."
+            info="We use wetext library to normalize the input text."
         )
     audio_output = gr.Audio(label="Output Audio")
 
```
pyproject.toml

```diff
@@ -20,12 +20,10 @@ classifiers = [
     "Intended Audience :: Developers",
     "Operating System :: OS Independent",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
 ]
-requires-python = ">=3.8"
+requires-python = ">=3.10"
 dependencies = [
     "torch>=2.5.0",
     "torchaudio>=2.5.0",
@@ -36,7 +34,7 @@ dependencies = [
     "addict",
     "wetext",
     "modelscope>=1.22.0",
-    "datasets>=2,<4",
+    "datasets>=3,<4",
     "huggingface-hub",
     "pydantic",
     "tqdm",
@@ -78,7 +76,7 @@ version_scheme = "post-release"
 
 [tool.black]
 line-length = 120
-target-version = ['py38']
+target-version = ['py310']
 include = '\.pyi?$'
 extend-exclude = '''
 /(
```
voxcpm/cli.py

```diff
@@ -240,6 +240,7 @@ Examples:
     # Prompt audio (for voice cloning)
     parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
     parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
+    parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio")
     parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
 
     # Generation parameters
@@ -279,6 +280,12 @@ def main():
 
     # If prompt audio+text provided → voice cloning
     if args.prompt_audio or args.prompt_text:
+        if not args.prompt_text and args.prompt_file:
+            assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."
+
+            with open(args.prompt_file, 'r', encoding='utf-8') as f:
+                args.prompt_text = f.read()
+
         if not args.prompt_audio or not args.prompt_text:
             print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
             sys.exit(1)
```
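The new `--prompt-file` flag is a convenience fallback: when `--prompt-text` is absent, the transcript is read from the file before the usual both-or-neither validation runs. A standalone sketch of equivalent logic (`resolve_prompt_text` is a hypothetical helper, not part of the CLI):

```python
from pathlib import Path
from typing import Optional

def resolve_prompt_text(prompt_text: Optional[str], prompt_file: Optional[str]) -> Optional[str]:
    """Prefer the inline transcript; otherwise read it from the prompt file."""
    if prompt_text:
        return prompt_text
    if prompt_file:
        path = Path(prompt_file)
        if not path.is_file():
            raise FileNotFoundError(f"Prompt file does not exist or is not accessible: {prompt_file}")
        return path.read_text(encoding="utf-8")
    return None
```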
voxcpm/core.py

```diff
@@ -1,7 +1,8 @@
-import torch
-import torchaudio
 import os
+import re
 import tempfile
+import numpy as np
+from typing import Generator
 from huggingface_hub import snapshot_download
 from .model.voxcpm import VoxCPMModel
 
@@ -10,6 +11,7 @@ class VoxCPM:
         voxcpm_model_path : str,
         zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
         enable_denoiser : bool = True,
+        optimize: bool = True,
     ):
         """Initialize VoxCPM TTS pipeline.
 
@@ -20,9 +22,10 @@ class VoxCPM:
             zipenhancer_model_path: ModelScope acoustic noise suppression model
                 id or local path. If None, denoiser will not be initialized.
             enable_denoiser: Whether to initialize the denoiser pipeline.
+            optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
         """
         print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
-        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
+        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
         self.text_normalizer = None
         if enable_denoiser and zipenhancer_model_path is not None:
             from .zipenhancer import ZipEnhancer
@@ -42,6 +45,7 @@ class VoxCPM:
         zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
         cache_dir: str = None,
         local_files_only: bool = False,
+        **kwargs,
     ):
         """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
 
@@ -53,6 +57,8 @@ class VoxCPM:
             cache_dir: Custom cache directory for the snapshot.
             local_files_only: If True, only use local files and do not attempt
                 to download.
+            Kwargs:
+                Additional keyword arguments passed to the ``VoxCPM`` constructor.
 
         Returns:
             VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
@@ -81,9 +87,16 @@ class VoxCPM:
             voxcpm_model_path=local_path,
             zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
             enable_denoiser=load_denoiser,
+            **kwargs,
         )
 
-    def generate(self,
+    def generate(self, *args, **kwargs) -> np.ndarray:
+        return next(self._generate(*args, streaming=False, **kwargs))
+
+    def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
+        return self._generate(*args, streaming=True, **kwargs)
+
+    def _generate(self,
         text : str,
         prompt_wav_path : str = None,
         prompt_text : str = None,
@@ -95,7 +108,8 @@ class VoxCPM:
         retry_badcase : bool = True,
         retry_badcase_max_times : int = 3,
         retry_badcase_ratio_threshold : float = 6.0,
-    ):
+        streaming: bool = False,
+    ) -> Generator[np.ndarray, None, None]:
         """Synthesize speech for the given text and return a single waveform.
 
         This method optionally builds and reuses a prompt cache. If an external
@@ -117,13 +131,25 @@ class VoxCPM:
             retry_badcase: Whether to retry badcase.
             retry_badcase_max_times: Maximum number of times to retry badcase.
             retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
+            streaming: Whether to return a generator of audio chunks.
 
         Returns:
-            numpy.ndarray: 1D waveform array (float32) on CPU.
+            Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
+                Yields audio chunks for each generation step if ``streaming=True``,
+                otherwise yields a single array containing the final audio.
         """
-        texts = text.split("\n")
-        texts = [t.strip() for t in texts if t.strip()]
-        final_wav = []
-        temp_prompt_wav_path = None
+        if not text.strip() or not isinstance(text, str):
+            raise ValueError("target text must be a non-empty string")
+
+        if prompt_wav_path is not None:
+            if not os.path.exists(prompt_wav_path):
+                raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
+
+        if (prompt_wav_path is None) != (prompt_text is None):
+            raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
+
+        text = text.replace("\n", " ")
+        text = re.sub(r'\s+', ' ', text)
+        temp_prompt_wav_path = None
+
         try:
             if prompt_wav_path is not None and prompt_text is not None:
@@ -139,35 +165,27 @@ class VoxCPM:
             else:
                 fixed_prompt_cache = None # will be built from the first inference
 
-            for sub_text in texts:
-                if sub_text.strip() == "":
-                    continue
-                print("sub_text:", sub_text)
-                if normalize:
-                    if self.text_normalizer is None:
-                        from .utils.text_normalize import TextNormalizer
-                        self.text_normalizer = TextNormalizer()
-                    sub_text = self.text_normalizer.normalize(sub_text)
-                wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
-                    target_text=sub_text,
-                    prompt_cache=fixed_prompt_cache,
-                    min_len=2,
-                    max_len=max_length,
-                    inference_timesteps=inference_timesteps,
-                    cfg_value=cfg_value,
-                    retry_badcase=retry_badcase,
-                    retry_badcase_max_times=retry_badcase_max_times,
-                    retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
-                )
-                if fixed_prompt_cache is None:
-                    fixed_prompt_cache = self.tts_model.merge_prompt_cache(
-                        original_cache=None,
-                        new_text_token=target_text_token,
-                        new_audio_feat=generated_audio_feat
-                    )
-                final_wav.append(wav)
-
-            return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
+            if normalize:
+                if self.text_normalizer is None:
+                    from .utils.text_normalize import TextNormalizer
+                    self.text_normalizer = TextNormalizer()
+                text = self.text_normalizer.normalize(text)
+
+            generate_result = self.tts_model._generate_with_prompt_cache(
+                target_text=text,
+                prompt_cache=fixed_prompt_cache,
+                min_len=2,
+                max_len=max_length,
+                inference_timesteps=inference_timesteps,
+                cfg_value=cfg_value,
+                retry_badcase=retry_badcase,
+                retry_badcase_max_times=retry_badcase_max_times,
+                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+                streaming=streaming,
+            )
+
+            for wav, _, _ in generate_result:
+                yield wav.squeeze(0).cpu().numpy()
 
         finally:
             if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
```
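Both public entry points now funnel into the single `_generate` generator: `generate` drains exactly one item with `next`, while `generate_streaming` hands the generator back to the caller. A minimal self-contained sketch of this dispatch pattern (toy functions, not VoxCPM code):

```python
from typing import Generator
import numpy as np

def _generate(n_chunks: int, streaming: bool = False) -> Generator[np.ndarray, None, None]:
    """Yield per-step chunks when streaming, else one final concatenated array."""
    chunks = [np.full(4, i, dtype=np.float32) for i in range(n_chunks)]
    if streaming:
        yield from chunks
    else:
        yield np.concatenate(chunks)

def generate(n_chunks: int) -> np.ndarray:
    # Non-streaming: the generator yields exactly one item; next() unwraps it.
    return next(_generate(n_chunks, streaming=False))

def generate_streaming(n_chunks: int):
    return _generate(n_chunks, streaming=True)

assert generate(3).shape == (12,)
assert sum(c.shape[0] for c in generate_streaming(3)) == 12
```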
voxcpm/model/voxcpm.py

```diff
@@ -19,11 +19,12 @@ limitations under the License.
 """
 
 import os
-from typing import Dict, Optional, Tuple, Union
+from typing import Tuple, Union, Generator, List
 
 import torch
 import torch.nn as nn
 import torchaudio
+import warnings
 from einops import rearrange
 from pydantic import BaseModel
 from tqdm import tqdm
@@ -147,16 +148,23 @@ class VoxCPMModel(nn.Module):
         self.sample_rate = audio_vae.sample_rate
 
 
-    def optimize(self):
+    def optimize(self, disable: bool = False):
         try:
+            if disable:
+                raise ValueError("Optimization disabled by user")
             if self.device != "cuda":
                 raise ValueError("VoxCPMModel can only be optimized on CUDA device")
+            try:
+                import triton
+            except:
+                raise ValueError("triton is not installed")
             self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
             self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
             self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
             self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
-        except:
-            print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
+        except Exception as e:
+            print(f"Error: {e}")
+            print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
             self.base_lm.forward_step = self.base_lm.forward_step
             self.residual_lm.forward_step = self.residual_lm.forward_step
             self.feat_encoder_step = self.feat_encoder
```
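The reworked `optimize` treats `disable=True`, a non-CUDA device, and a missing Triton install uniformly: each raises into the same fallback path that keeps the eager-mode functions. A standalone sketch of that guarded-compile idea (`maybe_compile` is illustrative, not part of the model):

```python
import torch

def maybe_compile(fn, disable: bool = False):
    """Return a compiled version of fn, or fn unchanged when compilation
    is disabled or unsupported in the current environment."""
    if disable or not torch.cuda.is_available():
        return fn
    try:
        import triton  # noqa: F401 -- torch.compile's CUDA backend requires Triton
        return torch.compile(fn, mode="reduce-overhead", fullgraph=True)
    except Exception as e:
        print(f"Warning: falling back to eager mode ({e})")
        return fn

# Usage sketch: step_fn = maybe_compile(model.forward_step, disable=not optimize)
```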
```diff
@@ -164,8 +172,14 @@ class VoxCPMModel(nn.Module):
         return self
 
 
+    def generate(self, *args, **kwargs) -> torch.Tensor:
+        return next(self._generate(*args, streaming=False, **kwargs))
+
+    def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
+        return self._generate(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def generate(
+    def _generate(
         self,
         target_text: str,
         prompt_text: str = "",
@@ -177,7 +191,11 @@ class VoxCPMModel(nn.Module):
         retry_badcase: bool = False,
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
-    ):
+        streaming: bool = False,
+    ) -> Generator[torch.Tensor, None, None]:
+        if retry_badcase and streaming:
+            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+            retry_badcase = False
         if len(prompt_wav_path) == 0:
             text = target_text
             text_token = torch.LongTensor(self.text_tokenizer(text))
@@ -260,7 +278,7 @@ class VoxCPMModel(nn.Module):
 
         retry_badcase_times = 0
         while retry_badcase_times < retry_badcase_max_times:
-            latent_pred, pred_audio_feat = self.inference(
+            inference_result = self._inference(
                 text_token,
                 text_mask,
                 audio_feat,
@@ -269,17 +287,31 @@ class VoxCPMModel(nn.Module):
                 max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                 inference_timesteps=inference_timesteps,
                 cfg_value=cfg_value,
+                streaming=streaming,
             )
-            if retry_badcase:
-                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
-                    retry_badcase_times += 1
-                    continue
-                else:
-                    break
+            if streaming:
+                patch_len = self.patch_size * self.chunk_size
+                for latent_pred, _ in inference_result:
+                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                    yield decode_audio
+                break
             else:
-                break
-        return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+                latent_pred, pred_audio_feat = next(inference_result)
+                if retry_badcase:
+                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
+                        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                        retry_badcase_times += 1
+                        continue
+                    else:
+                        break
+                else:
+                    break
+
+        if not streaming:
+            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+            decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
+            yield decode_audio
 
     @torch.inference_mode()
     def build_prompt_cache(
```
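In the streaming branch above, each step decodes a short window of recent latents but keeps only the newest `patch_len` samples, so chunks line up without duplicated audio while the decoder still sees context. A toy illustration of that windowing (`fake_decode` stands in for `audio_vae.decode`; the constants are illustrative):

```python
import numpy as np

PATCH_LEN = 640  # samples produced per latent patch (illustrative value)

def fake_decode(latents: np.ndarray) -> np.ndarray:
    """Stand-in for audio_vae.decode: maps each latent to PATCH_LEN samples."""
    return np.repeat(latents, PATCH_LEN)

def stream_chunks(latents: np.ndarray, context: int = 3):
    """Decode with up to `context` latents of history, emit only the new tail."""
    for i in range(len(latents)):
        window = latents[max(0, i - context + 1): i + 1]
        audio = fake_decode(window)
        yield audio[-PATCH_LEN:]  # only the samples belonging to latent i

latents = np.arange(10, dtype=np.float32)
audio = np.concatenate(list(stream_chunks(latents)))
assert audio.shape[0] == len(latents) * PATCH_LEN
```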
```diff
@@ -317,7 +349,7 @@ class VoxCPMModel(nn.Module):
             audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
 
         # extract audio features
-        audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
+        audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
 
         audio_feat = audio_feat.view(
             self.audio_vae.latent_dim,
@@ -368,9 +400,17 @@ class VoxCPMModel(nn.Module):
         }
 
         return merged_cache
 
+    def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
+
+    def generate_with_prompt_cache_streaming(
+        self, *args, **kwargs
+    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
+        return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def generate_with_prompt_cache(
+    def _generate_with_prompt_cache(
         self,
         target_text: str,
         prompt_cache: dict,
```
```diff
@@ -381,7 +421,8 @@ class VoxCPMModel(nn.Module):
         retry_badcase: bool = False,
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0,
-    ):
+        streaming: bool = False,
+    ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """
         Generate audio using pre-built prompt cache.
 
@@ -395,10 +436,17 @@ class VoxCPMModel(nn.Module):
             retry_badcase: Whether to retry on bad cases
             retry_badcase_max_times: Maximum retry attempts
             retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
+            streaming: Whether to return a generator of audio chunks
 
         Returns:
-            tuple: (decoded audio tensor, new text tokens, new audio features)
+            Generator of Tuple containing:
+                - Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
+                - Tensor of new text tokens
+                - New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
         """
+        if retry_badcase and streaming:
+            warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
+            retry_badcase = False
         # get prompt from cache
         if prompt_cache is None:
             prompt_text_token = torch.empty(0, dtype=torch.int32)
@@ -443,7 +491,7 @@ class VoxCPMModel(nn.Module):
         target_text_length = len(self.text_tokenizer(target_text))
         retry_badcase_times = 0
         while retry_badcase_times < retry_badcase_max_times:
-            latent_pred, pred_audio_feat = self.inference(
+            inference_result = self._inference(
                 text_token,
                 text_mask,
                 audio_feat,
@@ -452,26 +500,48 @@ class VoxCPMModel(nn.Module):
                 max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
                 inference_timesteps=inference_timesteps,
                 cfg_value=cfg_value,
+                streaming=streaming,
             )
-            if retry_badcase:
-                if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
-                    print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
-                    retry_badcase_times += 1
-                    continue
+            if streaming:
+                patch_len = self.patch_size * self.chunk_size
+                for latent_pred, pred_audio_feat in inference_result:
+                    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                    decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
+                    yield (
+                        decode_audio,
+                        target_text_token,
+                        pred_audio_feat
+                    )
+                break
+            else:
+                latent_pred, pred_audio_feat = next(inference_result)
+                if retry_badcase:
+                    if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
+                        print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
+                        retry_badcase_times += 1
+                        continue
+                    else:
+                        break
                 else:
                     break
-            else:
-                break
-        decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
 
-        return (
-            decode_audio,
-            target_text_token,
-            pred_audio_feat
-        )
+        if not streaming:
+            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+            decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
+            yield (
+                decode_audio,
+                target_text_token,
+                pred_audio_feat
+            )
+
+    def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+        return next(self._inference(*args, streaming=False, **kwargs))
+
+    def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
+        return self._inference(*args, streaming=True, **kwargs)
+
     @torch.inference_mode()
-    def inference(
+    def _inference(
         self,
         text: torch.Tensor,
         text_mask: torch.Tensor,
```
```diff
@@ -481,7 +551,8 @@ class VoxCPMModel(nn.Module):
         max_len: int = 2000,
         inference_timesteps: int = 10,
         cfg_value: float = 2.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        streaming: bool = False,
+    ) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """Core inference method for audio generation.
 
         This is the main inference loop that generates audio features
@@ -496,11 +567,12 @@ class VoxCPMModel(nn.Module):
             max_len: Maximum generation length
             inference_timesteps: Number of diffusion steps
             cfg_value: Classifier-free guidance value
+            streaming: Whether to yield each step's latent feature or just the final result
 
         Returns:
-            Tuple containing:
-                - Predicted latent features
-                - Predicted audio feature sequence
+            Generator of Tuple containing:
+                - Predicted latent feature at the current step if ``streaming=True``, else final latent features
+                - Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
         """
         B, T, P, D = feat.shape
 
@@ -557,6 +629,12 @@ class VoxCPMModel(nn.Module):
 
             pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
             prefix_feat_cond = pred_feat
 
+            if streaming:
+                # return the last three predicted latent features to provide enough context for smooth decoding
+                pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
+                feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+                yield feat_pred, pred_feat_seq
+
             stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
             if i > min_len and stop_flag == 1:
```
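Both the per-step chunk and the final result are flattened with the same einops pattern, merging the step and patch axes into one time axis. A minimal demonstration of that reshape (shapes are illustrative):

```python
import torch
from einops import rearrange

B, T, P, D = 1, 4, 2, 3           # batch, steps, patch size, latent dim
seq = torch.randn(B, T, P, D)     # stacked per-step feature patches

# Flatten (T, P) into a single time axis while moving features to dim 1,
# matching the "b t p d -> b d (t p)" pattern used in _inference.
flat = rearrange(seq, "b t p d -> b d (t p)", b=B, p=P)
assert flat.shape == (B, D, T * P)
```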
```diff
@@ -572,14 +650,14 @@ class VoxCPMModel(nn.Module):
                 lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
             ).clone()
 
-        pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
-        feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
-        feat_pred = feat_pred[..., 1:-1] # trick: remove the first and last token
-        return feat_pred, pred_feat_seq.squeeze(0).cpu()
+        if not streaming:
+            pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
+            feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
+            yield feat_pred, pred_feat_seq.squeeze(0).cpu()
 
     @classmethod
-    def from_local(cls, path: str):
+    def from_local(cls, path: str, optimize: bool = True):
         config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
 
         tokenizer = LlamaTokenizerFast.from_pretrained(path)
@@ -605,4 +683,4 @@ class VoxCPMModel(nn.Module):
         for kw, val in vae_state_dict.items():
             model_state_dict[f"audio_vae.{kw}"] = val
         model.load_state_dict(model_state_dict, strict=True)
-        return model.to(model.device).eval().optimize()
+        return model.to(model.device).eval().optimize(disable=not optimize)
```