add lora finetune webUI; optimize lora save and load logic
@@ -210,6 +210,8 @@ We're excited to see the VoxCPM community growing! Here are some amazing project
- **[VoxCPM-NanoVLLM](https://github.com/a710128/nanovllm-voxcpm)** NanoVLLM integration for VoxCPM, for faster, high-throughput inference on GPU.
- **[VoxCPM-ONNX](https://github.com/bluryar/VoxCPM-ONNX)** ONNX export for VoxCPM, supporting faster CPU inference.
- **[VoxCPMANE](https://github.com/0seba/VoxCPMANE)** VoxCPM TTS with an Apple Neural Engine backend server.
- **[PR: LoRA finetune web UI (by Ayin1412)](https://github.com/OpenBMB/VoxCPM/pull/100)**
- **[voxcpm_rs](https://github.com/madushan1000/voxcpm_rs)** A re-implementation of VoxCPM-0.5B in Rust.

*Note: These projects are not officially maintained by OpenBMB.*
app.py (2 lines changed)
@@ -267,7 +267,7 @@ def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error
    demo = VoxCPMDemo()
    interface = create_demo_interface(demo)
    # Recommended to enable queue on Spaces for better throughput
-    interface.queue(max_size=10).launch(server_name=server_name, server_port=server_port, show_error=show_error)
+    interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)


if __name__ == "__main__":
@@ -19,6 +19,8 @@ tensorboard: /path/to/logs/finetune_lora
lambdas:
  loss/diff: 1.0
  loss/stop: 1.0

# LoRA configuration
lora:
  enable_lm: true
  enable_dit: true
@@ -26,3 +28,9 @@ lora:
  r: 32
  alpha: 16
  dropout: 0.0

# Distribution options (optional)
# - If distribute=false (default): save pretrained_path as base_model in lora_config.json
# - If distribute=true: save hf_model_id as base_model (hf_model_id is required)
# hf_model_id: "openbmb/VoxCPM1.5"
# distribute: true
@@ -19,6 +19,8 @@ tensorboard: /path/to/logs/finetune_lora
lambdas:
  loss/diff: 1.0
  loss/stop: 1.0

# LoRA configuration
lora:
  enable_lm: true
  enable_dit: true
@@ -26,3 +28,9 @@ lora:
  r: 32
  alpha: 16
  dropout: 0.0

# Distribution options (optional)
# - If distribute=false (default): save pretrained_path as base_model in lora_config.json
# - If distribute=true: save hf_model_id as base_model (hf_model_id is required)
# hf_model_id: "openbmb/VoxCPM-0.5B"
# distribute: true
docs/finetune.md (114 lines changed)
@@ -19,6 +19,7 @@ LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that:

## Table of Contents

- [Quick Start: WebUI](#quick-start-webui)
- [Data Preparation](#data-preparation)
- [Full Fine-tuning](#full-fine-tuning)
- [LoRA Fine-tuning](#lora-fine-tuning)
@@ -28,6 +29,31 @@ LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that:

---

## Quick Start: WebUI

For users who prefer a graphical interface, we provide `lora_ft_webui.py`, a comprehensive WebUI for training and inference:

### Launch WebUI

```bash
python lora_ft_webui.py
```

Then open `http://localhost:7860` in your browser.

### Features

- **🚀 Training Tab**: Configure and start LoRA training with an intuitive interface
  - Set training parameters (learning rate, batch size, LoRA rank, etc.)
  - Monitor training progress in real-time
  - Resume training from existing checkpoints

- **🎵 Inference Tab**: Generate audio with trained models
  - Automatic base model loading from the LoRA checkpoint config
  - Voice cloning with automatic ASR (reference text recognition)
  - Hot-swap between multiple LoRA models
  - Zero-shot TTS without reference audio

## Data Preparation

Training data should be prepared as a JSONL manifest file, with one sample per line:
@@ -177,6 +203,10 @@ lora:
  # Target modules
  target_modules_lm: ["q_proj", "v_proj", "k_proj", "o_proj"]
  target_modules_dit: ["q_proj", "v_proj", "k_proj", "o_proj"]

# Distribution options (optional)
# hf_model_id: "openbmb/VoxCPM1.5"  # HuggingFace ID
# distribute: true                  # If true, save hf_model_id in lora_config.json
```

### LoRA Parameters
@@ -189,6 +219,15 @@ lora:
| `alpha` | Scaling factor, `scaling = alpha / r` | Usually `r/2` or `r` |
| `target_modules_*` | Layer names to add LoRA | attention layers |
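
For example, with the sample values `r: 32` and `alpha: 16` used in the configs above, the effective scaling applied to the LoRA update is `16 / 32 = 0.5`.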

### Distribution Options (Optional)

| Parameter | Description | Default |
|-----------|-------------|---------|
| `hf_model_id` | HuggingFace model ID (e.g., `openbmb/VoxCPM1.5`) | `""` |
| `distribute` | If `true`, save `hf_model_id` as `base_model` in checkpoint; otherwise save local `pretrained_path` | `false` |

> **Note**: If `distribute: true`, `hf_model_id` is required.
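
For example, with `distribute: true` and `hf_model_id: "openbmb/VoxCPM1.5"`, the checkpoint's `lora_config.json` records `"base_model": "openbmb/VoxCPM1.5"` instead of the local training path.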

### Training

```bash
@@ -202,16 +241,37 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \

### Checkpoint Structure

-LoRA training saves only LoRA parameters:
+LoRA training saves LoRA parameters and configuration:

```
checkpoints/finetune_lora/
└── step_0002000/
    ├── lora_weights.safetensors   # Only lora_A, lora_B parameters
    ├── lora_config.json           # LoRA config + base model path
    ├── optimizer.pth
    └── scheduler.pth
```

The `lora_config.json` contains:
```json
{
  "base_model": "/path/to/VoxCPM1.5/",
  "lora_config": {
    "enable_lm": true,
    "enable_dit": true,
    "r": 32,
    "alpha": 16,
    ...
  }
}
```

The `base_model` field contains:
- Local path (default): when `distribute: false` or not set
- HuggingFace ID: when `distribute: true` (e.g., `"openbmb/VoxCPM1.5"`)

This allows loading LoRA checkpoints without the original training config file.

---

## Inference
@@ -240,11 +300,10 @@ python scripts/test_voxcpm_ft_infer.py \

### LoRA Inference

-LoRA inference requires the training config (for LoRA structure) and LoRA checkpoint:
+LoRA inference only requires the checkpoint directory (base model path and LoRA config are read from `lora_config.json`):

```bash
python scripts/test_voxcpm_lora_infer.py \
-    --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
    --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
    --text "Hello, this is LoRA fine-tuned result." \
    --output lora_output.wav
@@ -254,7 +313,6 @@ With voice cloning:

```bash
python scripts/test_voxcpm_lora_infer.py \
-    --config_path conf/voxcpm_v1.5/voxcpm_finetune_lora.yaml \
    --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
    --text "This is voice cloning with LoRA." \
    --prompt_audio /path/to/reference.wav \
@@ -262,6 +320,16 @@ python scripts/test_voxcpm_lora_infer.py \
    --output cloned_output.wav
```

Override base model path (optional):

```bash
python scripts/test_voxcpm_lora_infer.py \
    --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
    --base_model /path/to/another/VoxCPM1.5 \
    --text "Use different base model." \
    --output output.wav
```

---

## LoRA Hot-swapping
@@ -315,20 +383,39 @@ print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
lora_state = model.get_lora_state_dict()
```
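
Building on the snippet above, here is a minimal hot-swap sketch: it keeps one loaded base model and switches adapters between generations with `load_lora()`. The checkpoint paths are illustrative, and `model` is the VoxCPM instance loaded earlier in this section.

```python
# Both directories are assumed to be checkpoints produced by the LoRA training
# script (each containing lora_weights.safetensors); paths are illustrative.
for ckpt_dir in [
    "/path/to/checkpoints/speaker_a/step_0002000",
    "/path/to/checkpoints/speaker_b/step_0002000",
]:
    loaded, skipped = model.load_lora(ckpt_dir)  # hot-reload lora_A / lora_B tensors
    print(f"{ckpt_dir}: loaded {len(loaded)} params, skipped {len(skipped)}")
    audio = model.generate(text="Same text, different LoRA adapter.")
    # write or play `audio` as needed
```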

-### Simplified Usage (Auto LoRA Config)
+### Simplified Usage (Load from lora_config.json)

-If you only have LoRA weights and don't need custom config, just provide the path:
+If your checkpoint contains `lora_config.json` (saved by the training script), you can load everything automatically:

```python
import json
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig

-# Auto-create default LoRAConfig when only lora_weights_path is provided
+# Load config from checkpoint
lora_ckpt_dir = "/path/to/checkpoints/finetune_lora/step_0002000"
with open(f"{lora_ckpt_dir}/lora_config.json") as f:
    lora_info = json.load(f)

base_model = lora_info["base_model"]
lora_cfg = LoRAConfig(**lora_info["lora_config"])

# Load model with LoRA
model = VoxCPM.from_pretrained(
-    hf_model_id="openbmb/VoxCPM1.5",
-    lora_weights_path="/path/to/lora_checkpoint",  # Will auto-create LoRAConfig
+    hf_model_id=base_model,
+    lora_config=lora_cfg,
+    lora_weights_path=lora_ckpt_dir,
)
```

Or use the test script directly:

```bash
python scripts/test_voxcpm_lora_infer.py \
    --lora_ckpt /path/to/checkpoints/finetune_lora/step_0002000 \
    --text "Hello world"
```

### Method Reference

| Method | Description | torch.compile Compatible |
@@ -354,7 +441,6 @@ model = VoxCPM.from_pretrained(

- Increase `r` (LoRA rank)
- Adjust `alpha` (try `alpha = r/2` or `alpha = r`)
- Ensure `enable_dit: true` (required for voice cloning)
- Increase training steps
- Add more target modules

@@ -366,11 +452,13 @@ model = VoxCPM.from_pretrained(

### 4. LoRA Not Taking Effect at Inference

-- Ensure inference config matches training config LoRA parameters
+- Check that `lora_config.json` exists in the checkpoint directory
- Check the `load_lora()` return value: `skipped_keys` should be empty
- Verify `set_lora_enabled(True)` is called (see the quick check below)
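
A quick sanity check along these lines can confirm the adapter is actually applied (checkpoint path illustrative; `model`, `load_lora`, and `set_lora_enabled` are the objects and methods referenced above):

```python
# Reload the adapter and verify nothing was skipped before synthesizing.
loaded, skipped = model.load_lora("/path/to/checkpoints/finetune_lora/step_0002000")
assert not skipped, f"Unexpected skipped keys: {skipped}"
model.set_lora_enabled(True)
print(f"Loaded {len(loaded)} LoRA parameters")
```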

### 5. Checkpoint Loading Errors

-- Full fine-tuning: checkpoint directory should contain `model.safetensors`(or `pytorch_model.bin`), `config.json`, `audiovae.pth`
-- LoRA: checkpoint directory should contain `lora_weights.safetensors` (or `lora_weights.ckpt`)
+- Full fine-tuning: checkpoint directory should contain `model.safetensors` (or `pytorch_model.bin`), `config.json`, `audiovae.pth`
+- LoRA: checkpoint directory should contain:
+  - `lora_weights.safetensors` (or `lora_weights.ckpt`) - LoRA weights
+  - `lora_config.json` - LoRA config and base model path
lora_ft_webui.py (new file, 1253 lines): diff suppressed because it is too large
@@ -5,7 +5,6 @@ LoRA inference test script.
Usage:

    python scripts/test_voxcpm_lora_infer.py \
-        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --lora_ckpt checkpoints/step_0002000 \
        --text "Hello, this is LoRA finetuned result." \
        --output lora_test.wav
@@ -13,37 +12,39 @@ Usage:
With voice cloning:

    python scripts/test_voxcpm_lora_infer.py \
-        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --lora_ckpt checkpoints/step_0002000 \
        --text "This is voice cloning result." \
        --prompt_audio path/to/ref.wav \
        --prompt_text "Reference audio transcript" \
        --output lora_clone.wav

+Note: The script reads base_model path and lora_config from lora_config.json
+in the checkpoint directory (saved automatically during training).
"""

import argparse
import json
from pathlib import Path

import soundfile as sf

from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training.config import load_yaml_config


def parse_args():
    parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
-    parser.add_argument(
-        "--config_path",
-        type=str,
-        required=True,
-        help="Training YAML config path (contains pretrained_path and lora config)",
-    )
    parser.add_argument(
        "--lora_ckpt",
        type=str,
        required=True,
-        help="LoRA checkpoint directory (contains lora_weights.ckpt with lora_A/lora_B only)",
+        help="LoRA checkpoint directory (contains lora_weights.safetensors and lora_config.json)",
    )
+    parser.add_argument(
+        "--base_model",
+        type=str,
+        default="",
+        help="Optional: override base model path (default: read from lora_config.json)",
+    )
    parser.add_argument(
        "--text",
@@ -98,26 +99,44 @@ def parse_args():
def main():
    args = parse_args()

-    # 1. Load YAML config
-    cfg = load_yaml_config(args.config_path)
-    pretrained_path = cfg["pretrained_path"]
-    lora_cfg_dict = cfg.get("lora", {}) or {}
-    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
-
-    # 2. Check LoRA checkpoint
-    ckpt_dir = args.lora_ckpt
-    if not Path(ckpt_dir).exists():
+    # 1. Check LoRA checkpoint directory
+    ckpt_dir = Path(args.lora_ckpt)
+    if not ckpt_dir.exists():
        raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")

+    # 2. Load lora_config.json from checkpoint
+    lora_config_path = ckpt_dir / "lora_config.json"
+    if not lora_config_path.exists():
+        raise FileNotFoundError(
+            f"lora_config.json not found in {ckpt_dir}. "
+            "Make sure the checkpoint was saved with the updated training script."
+        )
+
+    with open(lora_config_path, "r", encoding="utf-8") as f:
+        lora_info = json.load(f)
+
+    # Get base model path (command line arg overrides config)
+    pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
+    if not pretrained_path:
+        raise ValueError("base_model not found in lora_config.json and --base_model not provided")
+
+    # Get LoRA config
+    lora_cfg_dict = lora_info.get("lora_config", {})
+    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
+
+    print(f"Loaded config from: {lora_config_path}")
+    print(f"  Base model: {pretrained_path}")
+    print(f"  LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else "  LoRA config: None")

    # 3. Load model with LoRA (no denoiser)
-    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"\n[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"  LoRA weights: {ckpt_dir}")
    model = VoxCPM.from_pretrained(
        hf_model_id=pretrained_path,
        load_denoiser=False,
        optimize=True,
        lora_config=lora_cfg,
-        lora_weights_path=ckpt_dir,
+        lora_weights_path=str(ckpt_dir),
    )

    # 4. Synthesize audio
@@ -197,7 +216,7 @@ def main():

    # === Test 5: Hot-reload LoRA (load_lora) ===
    print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
-    loaded, skipped = model.load_lora(str(ckpt_dir))
+    loaded, skipped = model.load_lora(ckpt_dir)
    print(f"  Reloaded {len(loaded)} parameters")
    audio_np = model.generate(
        text=args.text,
@@ -14,6 +14,8 @@ import torch
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
+import signal
+import os

try:
    from safetensors.torch import save_file
@@ -56,8 +58,16 @@ def train(
    lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
    lora: dict = None,
    config_path: str = "",
+    # Distribution options (for LoRA checkpoints)
+    hf_model_id: str = "",  # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
+    distribute: bool = False,  # If True, save hf_model_id as base_model; otherwise save pretrained_path
):
    _ = config_path

+    # Validate distribution options
+    if lora is not None and distribute and not hf_model_id:
+        raise ValueError("hf_model_id is required when distribute=True")
+
    accelerator = Accelerator(amp=True)

    save_dir = Path(save_path)
@@ -171,6 +181,39 @@ def train(
        num_training_steps=total_training_steps,
    )

+    # Try to load checkpoint and resume training
+    start_step = 0
+    if accelerator.rank == 0:
+        start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
+    # Broadcast start_step to all processes
+    if hasattr(accelerator, 'all_reduce'):
+        start_step_tensor = torch.tensor(start_step, device=accelerator.device)
+        accelerator.all_reduce(start_step_tensor)
+        start_step = int(start_step_tensor.item())
+
+    if start_step > 0 and accelerator.rank == 0:
+        tracker.print(f"Resuming training from step {start_step}")
+
+    # Resume tracker for signal handler to read current step
+    resume = {"step": start_step}
+
+    # Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
+    def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume):
+        try:
+            cur_step = int(_resume.get("step", start_step))
+        except Exception:
+            cur_step = start_step
+        print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
+        try:
+            save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
+            print("Checkpoint saved. Exiting.")
+        except Exception as e:
+            print(f"Error saving checkpoint on signal: {e}")
+        os._exit(0)
+
+    signal.signal(signal.SIGTERM, _signal_handler)
+    signal.signal(signal.SIGINT, _signal_handler)
+
    # Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
    grad_accum_steps = max(int(grad_accum_steps), 1)
    data_epoch = 0
@@ -191,7 +234,9 @@ def train(
        return next(train_iter)

    with tracker.live():
-        for step in range(num_iters):
+        for step in range(start_step, num_iters):
+            # update resume step so signal handler can save current progress
+            resume["step"] = step
            tracker.step = step
            optimizer.zero_grad(set_to_none=True)

@@ -255,10 +300,10 @@ def train(
                validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)

            if step % save_interval == 0 and accelerator.rank == 0:
-                save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path)
+                save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)

    if accelerator.rank == 0:
-        save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path)
+        save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path, hf_model_id, distribute)
    if writer:
        writer.close()
@@ -301,7 +346,77 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
    model.train()


-def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None):
+def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
+    """
+    Load the latest checkpoint if it exists.
+    Returns the step number to resume from, or 0 if no checkpoint found.
+    """
+    latest_folder = save_dir / "latest"
+    if not latest_folder.exists():
+        return 0
+
+    unwrapped = model.module if hasattr(model, "module") else model
+    lora_cfg = unwrapped.lora_config
+
+    # Load model weights
+    if lora_cfg is not None:
+        # LoRA: load lora_weights
+        lora_weights_path = latest_folder / "lora_weights.safetensors"
+        if not lora_weights_path.exists():
+            lora_weights_path = latest_folder / "lora_weights.ckpt"
+
+        if lora_weights_path.exists():
+            if lora_weights_path.suffix == ".safetensors":
+                from safetensors.torch import load_file
+                state_dict = load_file(str(lora_weights_path))
+            else:
+                ckpt = torch.load(lora_weights_path, map_location="cpu")
+                state_dict = ckpt.get("state_dict", ckpt)
+
+            # Load only lora weights
+            unwrapped.load_state_dict(state_dict, strict=False)
+            print(f"Loaded LoRA weights from {lora_weights_path}")
+    else:
+        # Full finetune: load model.safetensors or pytorch_model.bin
+        model_path = latest_folder / "model.safetensors"
+        if not model_path.exists():
+            model_path = latest_folder / "pytorch_model.bin"
+
+        if model_path.exists():
+            if model_path.suffix == ".safetensors":
+                from safetensors.torch import load_file
+                state_dict = load_file(str(model_path))
+            else:
+                ckpt = torch.load(model_path, map_location="cpu")
+                state_dict = ckpt.get("state_dict", ckpt)
+
+            unwrapped.load_state_dict(state_dict, strict=False)
+            print(f"Loaded model weights from {model_path}")
+
+    # Load optimizer state
+    optimizer_path = latest_folder / "optimizer.pth"
+    if optimizer_path.exists():
+        optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
+        print(f"Loaded optimizer state from {optimizer_path}")
+
+    # Load scheduler state
+    scheduler_path = latest_folder / "scheduler.pth"
+    if scheduler_path.exists():
+        scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
+        print(f"Loaded scheduler state from {scheduler_path}")
+
+    # Try to infer step from checkpoint folders
+    step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
+    if step_folders:
+        steps = [int(d.name.split("_")[1]) for d in step_folders]
+        resume_step = max(steps)
+        print(f"Resuming from step {resume_step}")
+        return resume_step
+
+    return 0


+def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
    """
    Save checkpoint with different strategies for full finetune vs LoRA:
    - Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
@@ -325,6 +440,17 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
            save_file(state_dict, folder / "lora_weights.safetensors")
        else:
            torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")

+        # Save LoRA config and base model path to a separate JSON file
+        # If distribute=True, save hf_model_id; otherwise save local pretrained_path
+        import json
+        base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
+        lora_info = {
+            "base_model": base_model_to_save,
+            "lora_config": lora_cfg.model_dump() if hasattr(lora_cfg, "model_dump") else vars(lora_cfg),
+        }
+        with open(folder / "lora_config.json", "w", encoding="utf-8") as f:
+            json.dump(lora_info, f, indent=2, ensure_ascii=False)
    else:
        # Full finetune: save non-vae weights to model.safetensors
        state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")}
@@ -345,6 +471,29 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
    torch.save(optimizer.state_dict(), folder / "optimizer.pth")
    torch.save(scheduler.state_dict(), folder / "scheduler.pth")

+    # Update (or create) a `latest` symlink pointing to the most recent checkpoint folder
+    latest_link = save_dir / "latest"
+    try:
+        if latest_link.exists() or latest_link.is_symlink():
+            # remove existing link or directory
+            if latest_link.is_dir() and not latest_link.is_symlink():
+                shutil.rmtree(latest_link)
+            else:
+                latest_link.unlink()
+        # Create a symlink pointing to the new folder
+        os.symlink(str(folder), str(latest_link))
+    except Exception:
+        # If symlink creation fails (e.g., on Windows or permission issues), fall back to copying
+        try:
+            if latest_link.exists():
+                if latest_link.is_dir():
+                    shutil.rmtree(latest_link)
+                else:
+                    latest_link.unlink()
+            shutil.copytree(folder, latest_link)
+        except Exception:
+            print(f"Warning: failed to update latest checkpoint link at {latest_link}")


if __name__ == "__main__":
    from voxcpm.training.config import load_yaml_config
@@ -359,4 +508,3 @@ if __name__ == "__main__":
    # Otherwise use command line args (parsed by argbind)
    with argbind.scope(args):
        train()
@@ -55,11 +55,12 @@ class VoxCPM:
            self.denoiser = ZipEnhancer(zipenhancer_model_path)
        else:
            self.denoiser = None
-        print("Warm up VoxCPMModel...")
-        self.tts_model.generate(
-            target_text="Hello, this is the first test sentence.",
-            max_len=10,
-        )
+        if optimize:
+            print("Warm up VoxCPMModel...")
+            self.tts_model.generate(
+                target_text="Hello, this is the first test sentence.",
+                max_len=10,
+            )

    @classmethod
    def from_pretrained(cls,