1 Commit

Author: 刘鑫
SHA1: 81467f649f
Message: Modify lora inference api
Date: 2025-12-05 22:29:44 +08:00
12 changed files with 75 additions and 1616 deletions

View File

@@ -44,13 +44,13 @@ Unlike mainstream approaches that convert speech to discrete tokens, VoxCPM uses
 ### 📦 Model Versions
 See [Release Notes](docs/release_note.md) for details
 - **VoxCPM1.5** (Latest):
-  - Model Params: 800M
+  - Model Params: 750M
   - Sampling rate of AudioVAE: 44100
   - Token rate in LM Backbone: 6.25Hz (patch-size=4)
   - RTF in a single NVIDIA-RTX 4090 GPU: ~0.15
 - **VoxCPM-0.5B** (Original):
-  - Model Params: 640M
+  - Model Params: 600M
   - Sampling rate of AudioVAE: 16000
   - Token rate in LM Backbone: 12.5Hz (patch-size=2)
   - RTF in a single NVIDIA-RTX 4090 GPU: 0.17
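The token rates above translate directly into LM workload per second of audio. A small illustrative Python sketch (not part of the repository) makes the arithmetic concrete:

```python
# Illustrative only: how the two token rates above compare per second of audio.
def lm_tokens(audio_seconds: float, token_rate_hz: float) -> float:
    """LM-backbone tokens needed to cover a given audio duration."""
    return audio_seconds * token_rate_hz

print(lm_tokens(10.0, 6.25))   # VoxCPM1.5: 62.5 tokens for 10 s of audio (patch-size=4)
print(lm_tokens(10.0, 12.5))   # VoxCPM-0.5B: 125 tokens for 10 s of audio (patch-size=2)
# Half the LM steps per second of audio is why a lower it/s readout
# can still yield a similar or better RTF.
```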
@@ -210,8 +210,6 @@ We're excited to see the VoxCPM community growing! Here are some amazing project
 - **[VoxCPM-NanoVLLM](https://github.com/a710128/nanovllm-voxcpm)** NanoVLLM integration for VoxCPM for faster, high-throughput inference on GPU.
 - **[VoxCPM-ONNX](https://github.com/bluryar/VoxCPM-ONNX)** ONNX export for VoxCPM supports faster CPU inference.
 - **[VoxCPMANE](https://github.com/0seba/VoxCPMANE)** VoxCPM TTS with Apple Neural Engine backend server.
-- **[PR: LoRA finetune web UI (by Ayin1412)](https://github.com/OpenBMB/VoxCPM/pull/100)**
-- **[voxcpm_rs](https://github.com/madushan1000/voxcpm_rs)** A re-implementation of VoxCPM-0.5B in Rust.
 *Note: The projects are not officially maintained by OpenBMB.*

app.py (24 lines changed)
View File

@@ -172,22 +172,22 @@ def create_demo_interface(demo: VoxCPMDemo):
 with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"):
 gr.Markdown("""
 ### Prompt Speech Enhancement参考语音降噪
-- **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
+- **Enable** to remove background noise for a clean, studio-like voice, with an external ZipEnhancer component.
-**启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz限制克隆上限
+**启用**:通过 ZipEnhancer 组件消除背景噪音,获得更好的音质
-- **Disable** to preserve the original audio's all information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
+- **Disable** to preserve the original audio's background atmosphere.
-**禁用**:保留原始音频的全部信息包括背景环境声最高支持44.1kHz的音频复刻
+**禁用**:保留原始音频的背景环境声,如果想复刻相应声学环境
 ### Text Normalization文本正则化
 - **Enable** to process general text with an external WeTextProcessing component.
-**启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理
+**启用**:使用 WeTextProcessing 组件,可处理常见文本。
-- **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input (For Chinese, phonemes are converted using pinyin, {ni3}{hao3}; For English, phonemes are converted using CMUDict, {HH AH0 L OW1}), try it!
+- **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input ({HH AH0 L OW1}), try it!
-**禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3}英文转CMUDict{HH AH0 L OW1})和公式符号合成,尝试一下!
+**禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如 {da4}{jia1})和公式符号合成,尝试一下!
 ### CFG ValueCFG 值
-- **Lower CFG** if the voice prompt sounds strained or expressive, or instability occurs with long text input.
+- **Lower CFG** if the voice prompt sounds strained or expressive.
-**调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题
+**调低**:如果提示语音听起来不自然或过于夸张。
-- **Higher CFG** for better adherence to the prompt speech style or input text, or instability occurs with too short text input.
+- **Higher CFG** for better adherence to the prompt speech style or input text.
-**调高**:为更好地贴合提示音频的风格或输入文本 或者极短文本输入出现稳定性问题
+**调高**:为更好地贴合提示音频的风格或输入文本。
 ### Inference Timesteps推理时间步
 - **Lower** for faster synthesis speed.
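The same knobs are exposed through the Python API used by the test scripts later in this diff. A hedged sketch follows; the parameter values and the 44.1 kHz write rate are illustrative assumptions, not defaults taken from this commit:

```python
# Sketch of the tips above via the Python API used by the scripts in this diff.
# cfg_value / inference_timesteps values are illustrative, not recommended defaults.
import soundfile as sf
from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(hf_model_id="openbmb/VoxCPM1.5", load_denoiser=False)

audio = model.generate(
    text="{HH AH0 L OW1}, nice to meet you.",  # phoneme input, as described above
    cfg_value=2.0,              # lower if the prompt sounds strained, raise for closer adherence
    inference_timesteps=10,     # fewer timesteps = faster synthesis
    normalize=False,            # keep VoxCPM's native text understanding
    denoise=False,
)
sf.write("pro_tips_demo.wav", audio, 44100)  # 44.1 kHz assumed for VoxCPM1.5 (see README)
```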
@@ -267,7 +267,7 @@ def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error
     demo = VoxCPMDemo()
     interface = create_demo_interface(demo)
     # Recommended to enable queue on Spaces for better throughput
-    interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)
+    interface.queue(max_size=10).launch(server_name=server_name, server_port=server_port, show_error=show_error)
 if __name__ == "__main__":

View File

@@ -19,8 +19,6 @@ tensorboard: /path/to/logs/finetune_lora
 lambdas:
   loss/diff: 1.0
   loss/stop: 1.0
-# LoRA configuration
 lora:
   enable_lm: true
   enable_dit: true
@@ -28,9 +26,3 @@ lora:
   r: 32
   alpha: 16
   dropout: 0.0
-  # Distribution options (optional)
-  # - If distribute=false (default): save pretrained_path as base_model in lora_config.json
-  # - If distribute=true: save hf_model_id as base_model (hf_model_id is required)
-  # hf_model_id: "openbmb/VoxCPM1.5"
-  # distribute: true

View File

@@ -19,8 +19,6 @@ tensorboard: /path/to/logs/finetune_lora
 lambdas:
   loss/diff: 1.0
   loss/stop: 1.0
-# LoRA configuration
 lora:
   enable_lm: true
   enable_dit: true
@@ -28,9 +26,3 @@ lora:
   r: 32
   alpha: 16
   dropout: 0.0
-  # Distribution options (optional)
-  # - If distribute=false (default): save pretrained_path as base_model in lora_config.json
-  # - If distribute=true: save hf_model_id as base_model (hf_model_id is required)
-  # hf_model_id: "openbmb/VoxCPM-0.5B"
-  # distribute: true
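The `lora:` block above maps one-to-one onto the `LoRAConfig` object that the inference script in this commit builds from the YAML. A minimal sketch of the equivalent programmatic construction (fields not listed in the YAML keep their defaults):

```python
# Programmatic equivalent of the `lora:` block above, mirroring how
# scripts/test_voxcpm_lora_infer.py builds LoRAConfig(**cfg["lora"]).
from voxcpm.model.voxcpm import LoRAConfig

lora_cfg = LoRAConfig(
    enable_lm=True,
    enable_dit=True,
    r=32,         # LoRA rank
    alpha=16,     # scaling = alpha / r
    dropout=0.0,
)
```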

View File

@@ -19,7 +19,6 @@ LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that:
 ## Table of Contents
-- [Quick Start: WebUI](#quick-start-webui)
 - [Data Preparation](#data-preparation)
 - [Full Fine-tuning](#full-fine-tuning)
 - [LoRA Fine-tuning](#lora-fine-tuning)
@@ -29,31 +28,6 @@ LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that:
 ---
-## Quick Start: WebUI
-For users who prefer a graphical interface, we provide `lora_ft_webui.py` - a comprehensive WebUI for training and inference:
-### Launch WebUI
-```bash
-python lora_ft_webui.py
-```
-Then open `http://localhost:7860` in your browser.
-### Features
-- **🚀 Training Tab**: Configure and start LoRA training with an intuitive interface
-  - Set training parameters (learning rate, batch size, LoRA rank, etc.)
-  - Monitor training progress in real-time
-  - Resume training from existing checkpoints
-- **🎵 Inference Tab**: Generate audio with trained models
-  - Automatic base model loading from LoRA checkpoint config
-  - Voice cloning with automatic ASR (reference text recognition)
-  - Hot-swap between multiple LoRA models
-  - Zero-shot TTS without reference audio
 ## Data Preparation
 Training data should be prepared as a JSONL manifest file, with one sample per line:
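For illustration, a hypothetical manifest writer; the `audio` and `text` field names are assumptions made here, so check the repository's data-preparation docs for the exact schema:

```python
# Hypothetical sketch: write a JSONL manifest with one sample per line.
# Field names ("audio", "text") are assumed, not taken from this diff.
import json

samples = [
    {"audio": "/data/spk1/utt_0001.wav", "text": "Hello, this is the first test sentence."},
    {"audio": "/data/spk1/utt_0002.wav", "text": "Each line holds one audio path and its transcript."},
]
with open("train_manifest.jsonl", "w", encoding="utf-8") as f:
    for sample in samples:
        f.write(json.dumps(sample, ensure_ascii=False) + "\n")
```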
@@ -203,10 +177,6 @@ lora:
   # Target modules
   target_modules_lm: ["q_proj", "v_proj", "k_proj", "o_proj"]
   target_modules_dit: ["q_proj", "v_proj", "k_proj", "o_proj"]
-  # Distribution options (optional)
-  # hf_model_id: "openbmb/VoxCPM1.5"  # HuggingFace ID
-  # distribute: true  # If true, save hf_model_id in lora_config.json
 ```
 ### LoRA Parameters
@@ -219,15 +189,6 @@
 | `alpha` | Scaling factor, `scaling = alpha / r` | Usually `r/2` or `r` |
 | `target_modules_*` | Layer names to add LoRA | attention layers |
-### Distribution Options (Optional)
-| Parameter | Description | Default |
-|-----------|-------------|---------|
-| `hf_model_id` | HuggingFace model ID (e.g., `openbmb/VoxCPM1.5`) | `""` |
-| `distribute` | If `true`, save `hf_model_id` as `base_model` in checkpoint; otherwise save local `pretrained_path` | `false` |
-> **Note**: If `distribute: true`, `hf_model_id` is required.
 ### Training
 ```bash
@@ -241,37 +202,16 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
 ### Checkpoint Structure
-LoRA training saves LoRA parameters and configuration:
+LoRA training saves only LoRA parameters:
 ```
 checkpoints/finetune_lora/
 └── step_0002000/
     ├── lora_weights.safetensors   # Only lora_A, lora_B parameters
-    ├── lora_config.json           # LoRA config + base model path
     ├── optimizer.pth
     └── scheduler.pth
 ```
-The `lora_config.json` contains:
-```json
-{
-  "base_model": "/path/to/VoxCPM1.5/",
-  "lora_config": {
-    "enable_lm": true,
-    "enable_dit": true,
-    "r": 32,
-    "alpha": 16,
-    ...
-  }
-}
-```
-The `base_model` field contains:
-- Local path (default): when `distribute: false` or not set
-- HuggingFace ID: when `distribute: true` (e.g., `"openbmb/VoxCPM1.5"`)
-This allows loading LoRA checkpoints without the original training config file.
 ---
 ## Inference
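A quick way to check what a checkpoint saved with the structure above actually contains; the sketch below only assumes `safetensors` (which the training script in this commit already uses) and the `lora_A`/`lora_B` naming noted in the tree:

```python
# Sketch: inspect a LoRA checkpoint directory saved by the training script.
from safetensors.torch import load_file

state = load_file("/path/to/checkpoints/finetune_lora/step_0002000/lora_weights.safetensors")
print(f"{len(state)} LoRA tensors")
# Per the checkpoint structure above, only lora_A / lora_B adapter weights should be present.
unexpected = [name for name in state if "lora_A" not in name and "lora_B" not in name]
print("unexpected keys:", unexpected or "none")
```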
@@ -300,10 +240,11 @@ python scripts/test_voxcpm_ft_infer.py \
 ### LoRA Inference
-LoRA inference only requires the checkpoint directory (base model path and LoRA config are read from `lora_config.json`):
+LoRA inference requires the training config (for LoRA structure) and LoRA checkpoint:
 ```bash
 python scripts/test_voxcpm_lora_infer.py \
+    --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
     --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
     --text "Hello, this is LoRA fine-tuned result." \
     --output lora_output.wav
@@ -313,6 +254,7 @@ With voice cloning:
 ```bash
 python scripts/test_voxcpm_lora_infer.py \
+    --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
     --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
     --text "This is voice cloning with LoRA." \
     --prompt_audio /path/to/reference.wav \
@@ -320,16 +262,6 @@ python scripts/test_voxcpm_lora_infer.py \
     --output cloned_output.wav
 ```
-Override base model path (optional):
-```bash
-python scripts/test_voxcpm_lora_infer.py \
-    --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
-    --base_model /path/to/another/VoxCPM1.5 \
-    --text "Use different base model." \
-    --output output.wav
-```
 ---
 ## LoRA Hot-swapping
@@ -383,39 +315,20 @@ print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
 lora_state = model.get_lora_state_dict()
 ```
-### Simplified Usage (Load from lora_config.json)
+### Simplified Usage (Auto LoRA Config)
-If your checkpoint contains `lora_config.json` (saved by the training script), you can load everything automatically:
+If you only have LoRA weights and don't need custom config, just provide the path:
 ```python
-import json
 from voxcpm.core import VoxCPM
-from voxcpm.model.voxcpm import LoRAConfig
-# Load config from checkpoint
+# Auto-create default LoRAConfig when only lora_weights_path is provided
-lora_ckpt_dir = "/path/to/checkpoints/finetune_lora/step_0002000"
-with open(f"{lora_ckpt_dir}/lora_config.json") as f:
-    lora_info = json.load(f)
-base_model = lora_info["base_model"]
-lora_cfg = LoRAConfig(**lora_info["lora_config"])
-# Load model with LoRA
 model = VoxCPM.from_pretrained(
-    hf_model_id=base_model,
+    hf_model_id="openbmb/VoxCPM1.5",
-    lora_config=lora_cfg,
+    lora_weights_path="/path/to/lora_checkpoint",  # Will auto-create LoRAConfig
-    lora_weights_path=lora_ckpt_dir,
 )
 ```
-Or use the test script directly:
-```bash
-python scripts/test_voxcpm_lora_infer.py \
-    --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
-    --text "Hello world"
-```
 ### Method Reference
 | Method | Description | torch.compile Compatible |
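For reference, a short sketch of swapping adapters at inference time with the hot-swap methods documented in this section (`load_lora`, `set_lora_enabled`); the paths are placeholders and the loading call mirrors the simplified usage shown above:

```python
# Sketch: hot-swap LoRA adapters on one loaded model (paths are placeholders).
from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    lora_weights_path="/path/to/lora_ckpt_voice_a",
)

# Swap in a second adapter without reloading the base model.
loaded, skipped = model.load_lora("/path/to/lora_ckpt_voice_b")
assert not skipped, f"unexpected skipped keys: {skipped}"

# Temporarily disable the adapter to fall back to the base voice, then re-enable it.
model.set_lora_enabled(False)
model.set_lora_enabled(True)
```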
@@ -430,39 +343,33 @@ python scripts/test_voxcpm_lora_infer.py \
 ## FAQ
-### 1. How Much Data is Needed for LoRA Fine-tuning to Converge to a Single Voice?
+### 1. Out of Memory (OOM)
-We have tested with 5 minutes and 10 minutes of data (all audio clips are 3-6s in length). In our experiments, both datasets converged to a single voice after 2000 training steps with default configurations. You can adjust the data amount and training configurations based on your available data and computational resources.
-### 2. Out of Memory (OOM)
 - Increase `grad_accum_steps` (gradient accumulation)
 - Decrease `batch_size`
 - Use LoRA fine-tuning instead of full fine-tuning
 - Decrease `max_batch_tokens` to filter long samples
-### 3. Poor LoRA Performance
+### 2. Poor LoRA Performance
 - Increase `r` (LoRA rank)
 - Adjust `alpha` (try `alpha = r/2` or `alpha = r`)
 - Increase training steps
 - Add more target modules
-### 4. Training Not Converging
+### 3. Training Not Converging
 - Decrease `learning_rate`
 - Increase `warmup_steps`
 - Check data quality
-### 5. LoRA Not Taking Effect at Inference
+### 4. LoRA Not Taking Effect at Inference
-- Check that `lora_config.json` exists in the checkpoint directory
+- Ensure inference config matches training config LoRA parameters
 - Check `load_lora()` return value - `skipped_keys` should be empty
 - Verify `set_lora_enabled(True)` is called
-### 6. Checkpoint Loading Errors
+### 5. Checkpoint Loading Errors
-- Full fine-tuning: checkpoint directory should contain `model.safetensors` (or `pytorch_model.bin`), `config.json`, `audiovae.pth`
+- Full fine-tuning: checkpoint directory should contain `model.safetensors`(or `pytorch_model.bin`), `config.json`, `audiovae.pth`
-- LoRA: checkpoint directory should contain:
+- LoRA: checkpoint directory should contain `lora_weights.safetensors` (or `lora_weights.ckpt`)
-  - `lora_weights.safetensors` (or `lora_weights.ckpt`) - LoRA weights
-  - `lora_config.json` - LoRA config and base model path

View File

@@ -32,9 +32,6 @@ We reduced the token rate in LM backbone from 12.5Hz to 6.25Hz (LocEnc&LocDiT pa
 - 📈 Provides a foundation for longer audio generation
 - 🏗️ Paves the way for training larger models in the future
-**Model Architecture Clarification**: The core architecture of VoxCPM1.5 remains unchanged from the technical report. The key modification is adjusting the patch size of the local modules (LocEnc & LocDiT) from 2 to 4, which reduces the LM processing rate from 12.5Hz to 6.25Hz. Since the local modules now need to handle longer contexts, we expanded their network depth, resulting in a slightly larger overall model parameter count.
-**Generation Speed Clarification**: Although the model parameters have increased, VoxCPM1.5 only requires 6.25 tokens to generate 1 second of audio (compared to 12.5 tokens in the previous version). While the displayed generation speed (xx it/s) may appear slower, the actual Real-Time Factor (RTF = audio duration / processing time) shows no difference or may even be faster.
 ## 🔧 Fine-tuning Support
@@ -85,7 +82,7 @@ We're continuously improving VoxCPM and working on exciting new features:
 ### Q: Has the stability issue been resolved?
-**A:** We have made stability optimizations in VoxCPM1.5, including improvements to the inference code logic, training data, and model architecture. Based on community feedback, we collected some stability issues such as:
+**A:** We have made stability optimizations in VoxCPM1.5, including improvements to the training data and model architecture. Based on community feedback, we collected some stability issues such as:
 - Increased noise and reverberation
 - Audio artifacts (e.g., howling/squealing)
 - Unstable speaking rate (speeding up)
@@ -93,11 +90,7 @@ We're continuously improving VoxCPM and working on exciting new features:
 - Noise artifacts at the beginning and end of audio
 - Synthesis issues with very short texts (e.g., "hello")
-**What we've improved:**
+While we have made improvements to these issues, they have not been completely resolved and may still occasionally occur, especially with very long or highly expressive inputs. We continue to work on further stability improvements in future versions.
-- By adjusting inference code logic and optimizing training data, we have largely fixed the beginning/ending artifacts.
-- By reducing the LM processing rate (12.5Hz → 6.25Hz), we have improved stability on longer speech generation cases.
-**What remains:** We acknowledge that long speech stability issues have not been completely resolved. Particularly for highly expressive or complex reference speech, error accumulation during autoregressive generation can still occur. We will continue to analyze and optimize this in future versions.
 ### Q: Does VoxCPM plan to support multilingual TTS?

View File

@@ -23,10 +23,8 @@ This is the secret sauce that gives your audio its unique sound.
 ### 1. Cooking with a Prompt Speech (Following a Famous Recipe)
 - A prompt speech provides the desired acoustic characteristics for VoxCPM. The speaker's timbre, speaking style, and even the background sounds and ambiance will be replicated.
-- **For a Clean, Denoising Voice:**
+- **For a Clean, Studio-Quality Voice:**
-  - ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling.
+  - ✅ Enable "Prompt Speech Enhancement". This acts like a noise filter, removing background hiss and rumble to give you a pure, clean voice clone.
-- **For High-Quality Audio Cloning (Up to 44.1kHz):**
-  - ❌ Disable "Prompt Speech Enhancement" to preserve all original audio information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate.
 ### 2. Cooking au Naturel (Letting the Model Improvise)
 - If no reference is provided, VoxCPM becomes a creative chef! It will infer a fitting speaking style based on the text itself, thanks to the text-smartness of its foundation model, MiniCPM-4.

File diff suppressed because it is too large.

View File

@@ -114,7 +114,7 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_len=args.max_len,
+        max_length=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )

View File

@@ -5,6 +5,7 @@ LoRA inference test script.
 Usage:
     python scripts/test_voxcpm_lora_infer.py \
+        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
         --lora_ckpt checkpoints/step_0002000 \
         --text "Hello, this is LoRA finetuned result." \
         --output lora_test.wav
@@ -12,39 +13,37 @@ Usage:
 With voice cloning:
     python scripts/test_voxcpm_lora_infer.py \
+        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
         --lora_ckpt checkpoints/step_0002000 \
         --text "This is voice cloning result." \
         --prompt_audio path/to/ref.wav \
         --prompt_text "Reference audio transcript" \
         --output lora_clone.wav
-Note: The script reads base_model path and lora_config from lora_config.json
-in the checkpoint directory (saved automatically during training).
 """
 import argparse
-import json
 from pathlib import Path
 import soundfile as sf
 from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
+from voxcpm.training.config import load_yaml_config
 def parse_args():
     parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
+    parser.add_argument(
+        "--config_path",
+        type=str,
+        required=True,
+        help="Training YAML config path (contains pretrained_path and lora config)",
+    )
     parser.add_argument(
         "--lora_ckpt",
         type=str,
         required=True,
-        help="LoRA checkpoint directory (contains lora_weights.safetensors and lora_config.json)",
+        help="LoRA checkpoint directory (contains lora_weights.ckpt with lora_A/lora_B only)",
     )
-    parser.add_argument(
-        "--base_model",
-        type=str,
-        default="",
-        help="Optional: override base model path (default: read from lora_config.json)",
-    )
     parser.add_argument(
         "--text",
@@ -99,44 +98,26 @@ def parse_args():
 def main():
     args = parse_args()
-    # 1. Check LoRA checkpoint directory
-    ckpt_dir = Path(args.lora_ckpt)
-    if not ckpt_dir.exists():
-        raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
-    # 2. Load lora_config.json from checkpoint
-    lora_config_path = ckpt_dir / "lora_config.json"
-    if not lora_config_path.exists():
-        raise FileNotFoundError(
-            f"lora_config.json not found in {ckpt_dir}. "
-            "Make sure the checkpoint was saved with the updated training script."
-        )
-    with open(lora_config_path, "r", encoding="utf-8") as f:
-        lora_info = json.load(f)
-    # Get base model path (command line arg overrides config)
-    pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
-    if not pretrained_path:
-        raise ValueError("base_model not found in lora_config.json and --base_model not provided")
-    # Get LoRA config
-    lora_cfg_dict = lora_info.get("lora_config", {})
+    # 1. Load YAML config
+    cfg = load_yaml_config(args.config_path)
+    pretrained_path = cfg["pretrained_path"]
+    lora_cfg_dict = cfg.get("lora", {}) or {}
     lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
-    print(f"Loaded config from: {lora_config_path}")
-    print(f" Base model: {pretrained_path}")
-    print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None")
+    # 2. Check LoRA checkpoint
+    ckpt_dir = args.lora_ckpt
+    if not Path(ckpt_dir).exists():
+        raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
     # 3. Load model with LoRA (no denoiser)
-    print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
     print(f" LoRA weights: {ckpt_dir}")
     model = VoxCPM.from_pretrained(
         hf_model_id=pretrained_path,
         load_denoiser=False,
         optimize=True,
         lora_config=lora_cfg,
-        lora_weights_path=str(ckpt_dir),
+        lora_weights_path=ckpt_dir,
     )
     # 4. Synthesize audio
@@ -155,7 +136,7 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_len=args.max_len,
+        max_length=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
@@ -172,7 +153,7 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_len=args.max_len,
+        max_length=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
@@ -189,7 +170,7 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_len=args.max_len,
+        max_length=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
@@ -206,7 +187,7 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_len=args.max_len,
+        max_length=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )
@@ -216,7 +197,7 @@ def main():
     # === Test 5: Hot-reload LoRA (load_lora) ===
     print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
-    loaded, skipped = model.load_lora(ckpt_dir)
+    loaded, skipped = model.load_lora(str(ckpt_dir))
     print(f" Reloaded {len(loaded)} parameters")
     audio_np = model.generate(
         text=args.text,
@@ -224,7 +205,7 @@ def main():
         prompt_text=prompt_text,
         cfg_value=args.cfg_value,
         inference_timesteps=args.inference_timesteps,
-        max_len=args.max_len,
+        max_length=args.max_len,
         normalize=args.normalize,
         denoise=False,
     )

View File

@@ -14,8 +14,6 @@ import torch
 from tensorboardX import SummaryWriter
 from torch.optim import AdamW
 from transformers import get_cosine_schedule_with_warmup
-import signal
-import os
 try:
     from safetensors.torch import save_file
@@ -58,16 +56,8 @@ def train(
     lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
     lora: dict = None,
     config_path: str = "",
-    # Distribution options (for LoRA checkpoints)
-    hf_model_id: str = "",  # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
-    distribute: bool = False,  # If True, save hf_model_id as base_model; otherwise save pretrained_path
 ):
     _ = config_path
-    # Validate distribution options
-    if lora is not None and distribute and not hf_model_id:
-        raise ValueError("hf_model_id is required when distribute=True")
     accelerator = Accelerator(amp=True)
     save_dir = Path(save_path)
@@ -181,39 +171,6 @@ def train(
         num_training_steps=total_training_steps,
     )
-    # Try to load checkpoint and resume training
-    start_step = 0
-    if accelerator.rank == 0:
-        start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
-    # Broadcast start_step to all processes
-    if hasattr(accelerator, 'all_reduce'):
-        start_step_tensor = torch.tensor(start_step, device=accelerator.device)
-        accelerator.all_reduce(start_step_tensor)
-        start_step = int(start_step_tensor.item())
-    if start_step > 0 and accelerator.rank == 0:
-        tracker.print(f"Resuming training from step {start_step}")
-    # Resume tracker for signal handler to read current step
-    resume = {"step": start_step}
-    # Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
-    def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume):
-        try:
-            cur_step = int(_resume.get("step", start_step))
-        except Exception:
-            cur_step = start_step
-        print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
-        try:
-            save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
-            print("Checkpoint saved. Exiting.")
-        except Exception as e:
-            print(f"Error saving checkpoint on signal: {e}")
-        os._exit(0)
-    signal.signal(signal.SIGTERM, _signal_handler)
-    signal.signal(signal.SIGINT, _signal_handler)
     # Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
     grad_accum_steps = max(int(grad_accum_steps), 1)
     data_epoch = 0
@@ -234,9 +191,7 @@ def train(
         return next(train_iter)
     with tracker.live():
-        for step in range(start_step, num_iters):
+        for step in range(num_iters):
-            # update resume step so signal handler can save current progress
-            resume["step"] = step
             tracker.step = step
             optimizer.zero_grad(set_to_none=True)
@@ -300,10 +255,10 @@ def train(
                 validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)
             if step % save_interval == 0 and accelerator.rank == 0:
-                save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
+                save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path)
     if accelerator.rank == 0:
-        save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path, hf_model_id, distribute)
+        save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path)
     if writer:
         writer.close()
@@ -346,77 +301,7 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
     model.train()
-def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
+def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None):
-    """
-    Load the latest checkpoint if it exists.
-    Returns the step number to resume from, or 0 if no checkpoint found.
-    """
-    latest_folder = save_dir / "latest"
-    if not latest_folder.exists():
-        return 0
-    unwrapped = model.module if hasattr(model, "module") else model
-    lora_cfg = unwrapped.lora_config
-    # Load model weights
-    if lora_cfg is not None:
-        # LoRA: load lora_weights
-        lora_weights_path = latest_folder / "lora_weights.safetensors"
-        if not lora_weights_path.exists():
-            lora_weights_path = latest_folder / "lora_weights.ckpt"
-        if lora_weights_path.exists():
-            if lora_weights_path.suffix == ".safetensors":
-                from safetensors.torch import load_file
-                state_dict = load_file(str(lora_weights_path))
-            else:
-                ckpt = torch.load(lora_weights_path, map_location="cpu")
-                state_dict = ckpt.get("state_dict", ckpt)
-            # Load only lora weights
-            unwrapped.load_state_dict(state_dict, strict=False)
-            print(f"Loaded LoRA weights from {lora_weights_path}")
-    else:
-        # Full finetune: load model.safetensors or pytorch_model.bin
-        model_path = latest_folder / "model.safetensors"
-        if not model_path.exists():
-            model_path = latest_folder / "pytorch_model.bin"
-        if model_path.exists():
-            if model_path.suffix == ".safetensors":
-                from safetensors.torch import load_file
-                state_dict = load_file(str(model_path))
-            else:
-                ckpt = torch.load(model_path, map_location="cpu")
-                state_dict = ckpt.get("state_dict", ckpt)
-            unwrapped.load_state_dict(state_dict, strict=False)
-            print(f"Loaded model weights from {model_path}")
-    # Load optimizer state
-    optimizer_path = latest_folder / "optimizer.pth"
-    if optimizer_path.exists():
-        optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
-        print(f"Loaded optimizer state from {optimizer_path}")
-    # Load scheduler state
-    scheduler_path = latest_folder / "scheduler.pth"
-    if scheduler_path.exists():
-        scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
-        print(f"Loaded scheduler state from {scheduler_path}")
-    # Try to infer step from checkpoint folders
-    step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
-    if step_folders:
-        steps = [int(d.name.split("_")[1]) for d in step_folders]
-        resume_step = max(steps)
-        print(f"Resuming from step {resume_step}")
-        return resume_step
-    return 0
-def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
     """
     Save checkpoint with different strategies for full finetune vs LoRA:
     - Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
@@ -440,17 +325,6 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
             save_file(state_dict, folder / "lora_weights.safetensors")
         else:
             torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
-        # Save LoRA config and base model path to a separate JSON file
-        # If distribute=True, save hf_model_id; otherwise save local pretrained_path
-        import json
-        base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
-        lora_info = {
-            "base_model": base_model_to_save,
-            "lora_config": lora_cfg.model_dump() if hasattr(lora_cfg, "model_dump") else vars(lora_cfg),
-        }
-        with open(folder / "lora_config.json", "w", encoding="utf-8") as f:
-            json.dump(lora_info, f, indent=2, ensure_ascii=False)
     else:
         # Full finetune: save non-vae weights to model.safetensors
         state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")}
@@ -471,29 +345,6 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
     torch.save(optimizer.state_dict(), folder / "optimizer.pth")
     torch.save(scheduler.state_dict(), folder / "scheduler.pth")
-    # Update (or create) a `latest` symlink pointing to the most recent checkpoint folder
-    latest_link = save_dir / "latest"
-    try:
-        if latest_link.exists() or latest_link.is_symlink():
-            # remove existing link or directory
-            if latest_link.is_dir() and not latest_link.is_symlink():
-                shutil.rmtree(latest_link)
-            else:
-                latest_link.unlink()
-        # Create a symlink pointing to the new folder
-        os.symlink(str(folder), str(latest_link))
-    except Exception:
-        # If symlink creation fails (e.g., on Windows or permission issues), fall back to copying
-        try:
-            if latest_link.exists():
-                if latest_link.is_dir():
-                    shutil.rmtree(latest_link)
-                else:
-                    latest_link.unlink()
-            shutil.copytree(folder, latest_link)
-        except Exception:
-            print(f"Warning: failed to update latest checkpoint link at {latest_link}")
 if __name__ == "__main__":
     from voxcpm.training.config import load_yaml_config
@@ -508,3 +359,4 @@ if __name__ == "__main__":
     # Otherwise use command line args (parsed by argbind)
     with argbind.scope(args):
         train()

View File

@@ -55,12 +55,11 @@ class VoxCPM:
             self.denoiser = ZipEnhancer(zipenhancer_model_path)
         else:
             self.denoiser = None
-        if optimize:
-            print("Warm up VoxCPMModel...")
-            self.tts_model.generate(
-                target_text="Hello, this is the first test sentence.",
-                max_len=10,
-            )
+        print("Warm up VoxCPMModel...")
+        self.tts_model.generate(
+            target_text="Hello, this is the first test sentence.",
+            max_len=10,
+        )
@classmethod @classmethod
def from_pretrained(cls, def from_pretrained(cls,