15 Commits
1.0.2 ... 1.0.4

Author          SHA1        Date                        Message
刘鑫            41752dc0fa  2025-09-22 21:16:23 +08:00  FX: Raising the Python version to avoid issues with Gradio failing to start.
xliucs          b0714adcaa  2025-09-22 20:47:07 +08:00  Merge pull request #26 from AbrahamSanders/main (Add a streaming API for VoxCPM)
AbrahamSanders  89f4d917a0  2025-09-19 17:09:30 -04:00  Update readme with streaming example
AbrahamSanders  5c5da0dbe6  2025-09-19 16:56:11 -04:00  Add a streaming API for VoxCPM
刘鑫            5f56d5ff5d  2025-09-19 13:44:33 +08:00  FX: update README
xliucs          169c17ddfd  2025-09-19 13:35:36 +08:00  Merge pull request #17 from MayDomine/main (add prompt-file option to set prompt text)
MayDomine       996c69a1a8  2025-09-19 12:53:23 +08:00  add prompt-file option to set prompt text
刘鑫            dc6b6d1d1c  2025-09-18 19:23:13 +08:00  Fx: capture compile error on Windows
刘鑫            cef6aefb3d  2025-09-18 14:57:45 +08:00  remove \n from input text
周逸轩          1a46c5d1ad  2025-09-18 14:53:37 +08:00  update README
周逸轩          5257ec3dc5  2025-09-18 14:50:01 +08:00  FX: noise point
刘鑫            bdd516b579  2025-09-18 13:07:43 +08:00  remove target text anotation
刘鑫            11568f0776  2025-09-18 12:58:27 +08:00  remove target text anotation
刘鑫            e5bcb735f0  2025-09-18 12:02:37 +08:00  Remove segment text logic
周逸轩          1fa9e2ca02  2025-09-18 01:21:45 +08:00  update README
7 changed files with 220 additions and 98 deletions

.gitignore (new file, 3 additions)

@@ -0,0 +1,3 @@
launch.json
__pycache__
voxcpm.egg-info

README.md

@@ -50,7 +50,7 @@ By default, when you first run the script, the model will be downloaded automati
- Download VoxCPM-0.5B
```
from huggingface_hub import snapshot_download
snapshot_download("openbmb/VoxCPM-0.5B",local_files_only=local_files_only)
snapshot_download("openbmb/VoxCPM-0.5B")
```
- Download ZipEnhancer and SenseVoice-Small. We use ZipEnhancer to enhance speech prompts and SenseVoice-Small for speech prompt ASR in the web demo.
```
@@ -62,10 +62,12 @@ By default, when you first run the script, the model will be downloaded automati
### 2. Basic Usage
```python
import soundfile as sf
import numpy as np
from voxcpm import VoxCPM
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")
# Non-streaming
wav = model.generate(
text="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech.",
prompt_wav_path=None, # optional: path to a prompt speech for voice cloning
@@ -81,6 +83,18 @@ wav = model.generate(
sf.write("output.wav", wav, 16000)
print("saved: output.wav")
# Streaming
chunks = []
for chunk in model.generate_streaming(
text = "Streaming text to speech is easy with VoxCPM!",
# supports same args as above
):
chunks.append(chunk)
wav = np.concatenate(chunks)
sf.write("output_streaming.wav", wav, 16000)
print("saved: output_streaming.wav")
```
### 3. CLI Usage
@@ -98,6 +112,13 @@ voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, desi
--output out.wav \
--denoise
# (Optional) Voice cloning (reference audio + transcript file)
voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
--prompt-audio path/to/voice.wav \
--prompt-file "/path/to/text-file" \
--output out.wav \
--denoise
# 3) Batch processing (one text per line)
voxcpm --input examples/input.txt --output-dir outs
# (optional) Batch + cloning
@@ -267,6 +288,8 @@ This project is developed by the following institutions:
- <img src="assets/thuhcsi_logo.png" width="28px"> [THUHCSI](https://github.com/thuhcsi)
## ⭐ Star History
[![Star History Chart](https://api.star-history.com/svg?repos=OpenBMB/VoxCPM&type=Date)](https://star-history.com/#OpenBMB/VoxCPM&Date)
## 📚 Citation

app.py

@@ -194,10 +194,6 @@ def create_demo_interface(demo: VoxCPMDemo):
**调低**:合成速度更快。
- **Higher** for better synthesis quality.
**调高**:合成质量更佳。
### Long Text (e.g., >5 min speech)|长文本 (如 >5分钟的合成语音)
While VoxCPM can handle long texts directly, we recommend using empty lines to break very long content into paragraphs; the model will then synthesize each paragraph individually.
虽然 VoxCPM 支持直接生成长文本,但如果目标文本过长,我们建议使用换行符将内容分段;模型将对每个段落分别合成。
""")
# Main controls
@@ -206,7 +202,7 @@ def create_demo_interface(demo: VoxCPMDemo):
prompt_wav = gr.Audio(
sources=["upload", 'microphone'],
type="filepath",
label="Prompt Speech",
label="Prompt Speech (Optional, or let VoxCPM improvise)",
value="./examples/example.wav",
)
DoDenoisePromptAudio = gr.Checkbox(
@@ -244,14 +240,13 @@ def create_demo_interface(demo: VoxCPMDemo):
text = gr.Textbox(
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
label="Target Text",
info="Default processing splits text on \\n into paragraphs; each is synthesized as a chunk and then concatenated into the final audio."
)
with gr.Row():
DoNormalizeText = gr.Checkbox(
value=False,
label="Text Normalization",
elem_id="chk_normalize",
info="We use WeTextPorcessing library to normalize the input text."
info="We use wetext library to normalize the input text."
)
audio_output = gr.Audio(label="Output Audio")

pyproject.toml

@@ -20,12 +20,10 @@ classifiers = [
"Intended Audience :: Developers",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
requires-python = ">=3.8"
requires-python = ">=3.10"
dependencies = [
"torch>=2.5.0",
"torchaudio>=2.5.0",
@@ -36,7 +34,7 @@ dependencies = [
"addict",
"wetext",
"modelscope>=1.22.0",
"datasets>=2,<4",
"datasets>=3,<4",
"huggingface-hub",
"pydantic",
"tqdm",
@@ -78,7 +76,7 @@ version_scheme = "post-release"
[tool.black]
line-length = 120
target-version = ['py38']
target-version = ['py310']
include = '\.pyi?$'
extend-exclude = '''
/(

src/voxcpm/cli.py

@@ -240,6 +240,7 @@ Examples:
# Prompt audio (for voice cloning)
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path")
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
parser.add_argument("--prompt-file", "-pf", help="Reference text file corresponding to the audio")
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement (denoising)")
# Generation parameters
@@ -279,6 +280,12 @@ def main():
# If prompt audio+text provided → voice cloning
if args.prompt_audio or args.prompt_text:
if not args.prompt_text and args.prompt_file:
assert os.path.isfile(args.prompt_file), "Prompt file does not exist or is not accessible."
with open(args.prompt_file, 'r', encoding='utf-8') as f:
args.prompt_text = f.read()
if not args.prompt_audio or not args.prompt_text:
print("Error: Voice cloning requires both --prompt-audio and --prompt-text")
sys.exit(1)
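
The new `--prompt-file` flag simply reads the reference transcript from disk before handing it to the model. A minimal sketch of the equivalent Python API call, using placeholder file paths (the `prompt_wav_path` and `prompt_text` arguments are the ones shown in the core diff below):

```python
# Illustrative only: what --prompt-file does at the Python API level.
# "path/to/voice.wav" and "path/to/transcript.txt" are placeholders.
import soundfile as sf
from voxcpm import VoxCPM

model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B")

# Read the reference transcript from a file instead of passing it inline,
# mirroring the --prompt-file option added to the CLI.
with open("path/to/transcript.txt", "r", encoding="utf-8") as f:
    prompt_text = f.read().strip()

wav = model.generate(
    text="VoxCPM is an innovative end-to-end TTS model from ModelBest.",
    prompt_wav_path="path/to/voice.wav",  # reference audio for voice cloning
    prompt_text=prompt_text,              # transcript of the reference audio
)
sf.write("out.wav", wav, 16000)
```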

src/voxcpm/core.py

@@ -1,7 +1,8 @@
import torch
import torchaudio
import os
import re
import tempfile
import numpy as np
from typing import Generator
from huggingface_hub import snapshot_download
from .model.voxcpm import VoxCPMModel
@@ -10,6 +11,7 @@ class VoxCPM:
voxcpm_model_path : str,
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser : bool = True,
optimize: bool = True,
):
"""Initialize VoxCPM TTS pipeline.
@@ -20,9 +22,10 @@ class VoxCPM:
zipenhancer_model_path: ModelScope acoustic noise suppression model
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
"""
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path)
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
self.text_normalizer = None
if enable_denoiser and zipenhancer_model_path is not None:
from .zipenhancer import ZipEnhancer
@@ -42,6 +45,7 @@ class VoxCPM:
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
**kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -53,6 +57,8 @@ class VoxCPM:
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
Kwargs:
Additional keyword arguments passed to the ``VoxCPM`` constructor.
Returns:
VoxCPM: Initialized instance whose ``voxcpm_model_path`` points to
@@ -81,9 +87,16 @@ class VoxCPM:
voxcpm_model_path=local_path,
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
**kwargs,
)
def generate(self,
def generate(self, *args, **kwargs) -> np.ndarray:
return next(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
return self._generate(*args, streaming=True, **kwargs)
def _generate(self,
text : str,
prompt_wav_path : str = None,
prompt_text : str = None,
@@ -95,7 +108,8 @@ class VoxCPM:
retry_badcase : bool = True,
retry_badcase_max_times : int = 3,
retry_badcase_ratio_threshold : float = 6.0,
):
streaming: bool = False,
) -> Generator[np.ndarray, None, None]:
"""Synthesize speech for the given text and return a single waveform.
This method optionally builds and reuses a prompt cache. If an external
@@ -117,12 +131,24 @@ class VoxCPM:
retry_badcase: Whether to retry badcase.
retry_badcase_max_times: Maximum number of times to retry badcase.
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
streaming: Whether to return a generator of audio chunks.
Returns:
numpy.ndarray: 1D waveform array (float32) on CPU.
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
Yields audio chunks for each generation step if ``streaming=True``,
otherwise yields a single array containing the final audio.
"""
texts = text.split("\n")
texts = [t.strip() for t in texts if t.strip()]
final_wav = []
if not isinstance(text, str) or not text.strip():
raise ValueError("target text must be a non-empty string")
if prompt_wav_path is not None:
if not os.path.exists(prompt_wav_path):
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
if (prompt_wav_path is None) != (prompt_text is None):
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
text = text.replace("\n", " ")
text = re.sub(r'\s+', ' ', text)
temp_prompt_wav_path = None
try:
@@ -139,35 +165,27 @@ class VoxCPM:
else:
fixed_prompt_cache = None # will be built from the first inference
for sub_text in texts:
if sub_text.strip() == "":
continue
print("sub_text:", sub_text)
if normalize:
if self.text_normalizer is None:
from .utils.text_normalize import TextNormalizer
self.text_normalizer = TextNormalizer()
sub_text = self.text_normalizer.normalize(sub_text)
wav, target_text_token, generated_audio_feat = self.tts_model.generate_with_prompt_cache(
target_text=sub_text,
prompt_cache=fixed_prompt_cache,
min_len=2,
max_len=max_length,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
)
if fixed_prompt_cache is None:
fixed_prompt_cache = self.tts_model.merge_prompt_cache(
original_cache=None,
new_text_token=target_text_token,
new_audio_feat=generated_audio_feat
)
final_wav.append(wav)
if normalize:
if self.text_normalizer is None:
from .utils.text_normalize import TextNormalizer
self.text_normalizer = TextNormalizer()
text = self.text_normalizer.normalize(text)
return torch.cat(final_wav, dim=1).squeeze(0).cpu().numpy()
generate_result = self.tts_model._generate_with_prompt_cache(
target_text=text,
prompt_cache=fixed_prompt_cache,
min_len=2,
max_len=max_length,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
retry_badcase=retry_badcase,
retry_badcase_max_times=retry_badcase_max_times,
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
streaming=streaming,
)
for wav, _, _ in generate_result:
yield wav.squeeze(0).cpu().numpy()
finally:
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
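
Since `generate_streaming` yields waveform chunks as they are produced, a caller can write (or play) audio incrementally instead of waiting for the full synthesis. A minimal sketch, assuming chunks are mono float32 arrays at 16 kHz as in the README example; `optimize=False` is forwarded through `from_pretrained`'s `**kwargs` to the constructor and skips `torch.compile`, which the diff notes can help with debugging:

```python
# Minimal sketch of incremental writing with the new streaming API.
import soundfile as sf
from voxcpm import VoxCPM

# optimize=False is passed through **kwargs to the VoxCPM constructor
# and disables torch.compile (handy for debugging).
model = VoxCPM.from_pretrained("openbmb/VoxCPM-0.5B", optimize=False)

with sf.SoundFile("output_streaming.wav", mode="w",
                  samplerate=16000, channels=1, subtype="PCM_16") as f:
    for chunk in model.generate_streaming(
        text="Streaming text to speech is easy with VoxCPM!",
    ):
        f.write(chunk)  # write each chunk as soon as it is generated
```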

src/voxcpm/model/voxcpm.py

@@ -19,11 +19,12 @@ limitations under the License.
"""
import os
from typing import Dict, Optional, Tuple, Union
from typing import Tuple, Union, Generator, List
import torch
import torch.nn as nn
import torchaudio
import warnings
from einops import rearrange
from pydantic import BaseModel
from tqdm import tqdm
@@ -147,16 +148,23 @@ class VoxCPMModel(nn.Module):
self.sample_rate = audio_vae.sample_rate
def optimize(self):
def optimize(self, disable: bool = False):
try:
if disable:
raise ValueError("Optimization disabled by user")
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
import triton
except:
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
except:
print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
except Exception as e:
print(f"Error: {e}")
print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
self.base_lm.forward_step = self.base_lm.forward_step
self.residual_lm.forward_step = self.residual_lm.forward_step
self.feat_encoder_step = self.feat_encoder
@@ -164,8 +172,14 @@ class VoxCPMModel(nn.Module):
return self
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
def generate_streaming(self, *args, **kwargs) -> Generator[torch.Tensor, None, None]:
return self._generate(*args, streaming=True, **kwargs)
@torch.inference_mode()
def generate(
def _generate(
self,
target_text: str,
prompt_text: str = "",
@@ -177,7 +191,11 @@ class VoxCPMModel(nn.Module):
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
):
streaming: bool = False,
) -> Generator[torch.Tensor, None, None]:
if retry_badcase and streaming:
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
retry_badcase = False
if len(prompt_wav_path) == 0:
text = target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
@@ -260,7 +278,7 @@ class VoxCPMModel(nn.Module):
retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times:
latent_pred, pred_audio_feat = self.inference(
inference_result = self._inference(
text_token,
text_mask,
audio_feat,
@@ -269,17 +287,31 @@ class VoxCPMModel(nn.Module):
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
if streaming:
patch_len = self.patch_size * self.chunk_size
for latent_pred, _ in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield decode_audio
break
else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
else:
break
else:
break
else:
break
return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
yield decode_audio
@torch.inference_mode()
def build_prompt_cache(
@@ -317,7 +349,7 @@ class VoxCPMModel(nn.Module):
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# extract audio features
audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
audio_feat = audio_feat.view(
self.audio_vae.latent_dim,
@@ -369,8 +401,16 @@ class VoxCPMModel(nn.Module):
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
@torch.inference_mode()
def generate_with_prompt_cache(
def _generate_with_prompt_cache(
self,
target_text: str,
prompt_cache: dict,
@@ -381,7 +421,8 @@ class VoxCPMModel(nn.Module):
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
):
streaming: bool = False,
) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""
Generate audio using pre-built prompt cache.
@@ -395,10 +436,17 @@ class VoxCPMModel(nn.Module):
retry_badcase: Whether to retry on bad cases
retry_badcase_max_times: Maximum retry attempts
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
streaming: Whether to return a generator of audio chunks
Returns:
tuple: (decoded audio tensor, new text tokens, new audio features)
Generator of Tuple containing:
- Decoded audio tensor for the current step if ``streaming=True``, else final decoded audio tensor
- Tensor of new text tokens
- New audio features up to the current step as a List if ``streaming=True``, else as a concatenated Tensor
"""
if retry_badcase and streaming:
warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
retry_badcase = False
# get prompt from cache
if prompt_cache is None:
prompt_text_token = torch.empty(0, dtype=torch.int32)
@@ -443,7 +491,7 @@ class VoxCPMModel(nn.Module):
target_text_length = len(self.text_tokenizer(target_text))
retry_badcase_times = 0
while retry_badcase_times < retry_badcase_max_times:
latent_pred, pred_audio_feat = self.inference(
inference_result = self._inference(
text_token,
text_mask,
audio_feat,
@@ -452,26 +500,48 @@ class VoxCPMModel(nn.Module):
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
if streaming:
patch_len = self.patch_size * self.chunk_size
for latent_pred, pred_audio_feat in inference_result:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
break
else:
latent_pred, pred_audio_feat = next(inference_result)
if retry_badcase:
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...")
retry_badcase_times += 1
continue
else:
break
else:
break
else:
break
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
return (
decode_audio,
target_text_token,
pred_audio_feat
)
yield (
decode_audio,
target_text_token,
pred_audio_feat
)
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
return next(self._inference(*args, streaming=False, **kwargs))
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
return self._inference(*args, streaming=True, **kwargs)
@torch.inference_mode()
def inference(
def _inference(
self,
text: torch.Tensor,
text_mask: torch.Tensor,
@@ -481,7 +551,8 @@ class VoxCPMModel(nn.Module):
max_len: int = 2000,
inference_timesteps: int = 10,
cfg_value: float = 2.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
streaming: bool = False,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""Core inference method for audio generation.
This is the main inference loop that generates audio features
@@ -496,11 +567,12 @@ class VoxCPMModel(nn.Module):
max_len: Maximum generation length
inference_timesteps: Number of diffusion steps
cfg_value: Classifier-free guidance value
streaming: Whether to yield each step latent feature or just the final result
Returns:
Tuple containing:
- Predicted latent features
- Predicted audio feature sequence
Generator of Tuple containing:
- Predicted latent feature at the current step if ``streaming=True``, else final latent features
- Predicted audio feature sequence so far as a List if ``streaming=True``, else as a concatenated Tensor
"""
B, T, P, D = feat.shape
@@ -558,6 +630,12 @@ class VoxCPMModel(nn.Module):
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
prefix_feat_cond = pred_feat
if streaming:
# return the last three predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
if i > min_len and stop_flag == 1:
break
@@ -572,14 +650,14 @@ class VoxCPMModel(nn.Module):
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
).clone()
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
feat_pred = feat_pred[..., 1:-1] # trick: remove the first and last token
return feat_pred, pred_feat_seq.squeeze(0).cpu()
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str):
def from_local(cls, path: str, optimize: bool = True):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
@@ -605,4 +683,4 @@ class VoxCPMModel(nn.Module):
for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val
model.load_state_dict(model_state_dict, strict=True)
return model.to(model.device).eval().optimize()
return model.to(model.device).eval().optimize(disable=not optimize)
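
For intuition on the streaming decode path above: each step re-decodes the last three latent patches so the VAE decoder has context, but only the newest `patch_len` samples are kept and yielded, so consecutive chunks stitch together into continuous audio. A toy numpy sketch of that overlap-and-trim scheme (the `decode` here is a stand-in for `audio_vae.decode`; with the real decoder the match near chunk boundaries is approximate rather than exact):

```python
# Toy illustration of the streaming chunking scheme (not the real decoder).
import numpy as np

patch_len = 4  # stands in for self.patch_size * self.chunk_size
latents = [np.full(patch_len, i, dtype=np.float32) for i in range(6)]  # fake per-step latents

def decode(lat):            # stand-in for audio_vae.decode: here plain concatenation
    return np.concatenate(lat)

chunks = []
for step in range(len(latents)):
    context = latents[max(0, step - 2): step + 1]   # last three patches for context
    audio = decode(context)
    chunks.append(audio[-patch_len:])               # keep only the newest patch worth of audio

streamed = np.concatenate(chunks)
full = decode(latents)
assert np.array_equal(streamed, full)               # chunks stitch into the full signal
```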