Merge branch 'dev_1.5'

# Conflicts:
#	README.md
#	docs/finetune.md
#	scripts/test_voxcpm_ft_infer.py
#	scripts/test_voxcpm_lora_infer.py
#	src/voxcpm/core.py
.gitignore (vendored)
@@ -1,3 +1,4 @@
 launch.json
 __pycache__
 voxcpm.egg-info
+.DS_Store
README.md
@@ -98,8 +98,8 @@ wav = model.generate(
     prompt_text=None,           # optional: reference text
     cfg_value=2.0,              # LM guidance scale on LocDiT; higher values follow the prompt more closely but may hurt quality
     inference_timesteps=10,     # LocDiT inference timesteps; higher for better quality, lower for faster inference
-    normalize=True,             # enable the external text-normalization tool; disables native raw-text support
-    denoise=True,               # enable the external denoiser; may introduce distortion and restricts the sampling rate to 16 kHz
+    normalize=False,            # enable the external text-normalization tool; disables native raw-text support
+    denoise=False,              # enable the external denoiser; may introduce distortion and restricts the sampling rate to 16 kHz
     retry_badcase=True,         # enable retries for bad cases (e.g. unstoppable generation)
     retry_badcase_max_times=3,  # maximum number of retries
     retry_badcase_ratio_threshold=6.0,  # maximum length ratio for bad-case detection (simple but effective); raise it for slow-paced speech
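With these defaults flipped, raw text flows straight through unless normalization is requested. A minimal sketch of a call that restores the old behaviour (model id and output path follow the README's examples):

```python
import soundfile as sf

from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(hf_model_id="openbmb/VoxCPM1.5")

# normalize and denoise now default to False; opt back in explicitly if needed.
wav = model.generate(
    text="VoxCPM is an innovative end-to-end TTS model from ModelBest.",
    normalize=True,  # re-enable the external text-normalization tool
    denoise=True,    # re-enable prompt-audio denoising (only relevant with a prompt)
)
sf.write("out.wav", wav, model.tts_model.sample_rate)
```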
@@ -134,14 +134,14 @@ voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, desi
     --prompt-audio path/to/voice.wav \
     --prompt-text "reference transcript" \
     --output out.wav \
-    --denoise
+    # --denoise
 
 # (Optional) Voice cloning (reference audio + transcript file)
 voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \
     --prompt-audio path/to/voice.wav \
     --prompt-file "/path/to/text-file" \
     --output out.wav \
-    --denoise
+    # --denoise
 
 # 3) Batch processing (one text per line)
 voxcpm --input examples/input.txt --output-dir outs
@@ -149,7 +149,7 @@ voxcpm --input examples/input.txt --output-dir outs
 voxcpm --input examples/input.txt --output-dir outs \
     --prompt-audio path/to/voice.wav \
     --prompt-text "reference transcript" \
-    --denoise
+    # --denoise
 
 # 4) Inference parameters (quality/speed)
 voxcpm --text "..." --output out.wav \
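The CLI's batch mode maps one input line to one output file. A hedged API-level equivalent (the zero-padded file names under `outs/` are an assumption, not the CLI's documented naming):

```python
from pathlib import Path

import soundfile as sf

from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(hf_model_id="openbmb/VoxCPM1.5")
out_dir = Path("outs")
out_dir.mkdir(parents=True, exist_ok=True)

# One synthesis per non-empty line, mirroring `voxcpm --input ... --output-dir outs`.
for i, line in enumerate(Path("examples/input.txt").read_text().splitlines()):
    if not line.strip():
        continue
    wav = model.generate(text=line.strip())
    sf.write(str(out_dir / f"{i:04d}.wav"), wav, model.tts_model.sample_rate)
```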
docs/finetune.md
@@ -271,10 +271,10 @@ LoRA supports dynamic loading, unloading, and switching at inference time withou
 ### API Reference
 
 ```python
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
 
-# 1. Load model with LoRA structure
+# 1. Load model with LoRA structure and weights
 lora_cfg = LoRAConfig(
     enable_lm=True,
     enable_dit=True,
@@ -283,15 +283,20 @@ lora_cfg = LoRAConfig(
     target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
     target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
 )
-model = VoxCPMModel.from_local(
-    pretrained_path,
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",  # or local path
+    load_denoiser=False,  # Optional: disable denoiser for faster loading
     optimize=True,  # Enable torch.compile acceleration
-    lora_config=lora_cfg
+    lora_config=lora_cfg,
+    lora_weights_path="/path/to/lora_checkpoint",
 )
 
-# 2. Load LoRA weights (works after torch.compile)
-loaded, skipped = model.load_lora_weights("/path/to/lora_checkpoint")
-print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
+# 2. Generate audio
+audio = model.generate(
+    text="Hello, this is the LoRA fine-tuned result.",
+    prompt_wav_path="/path/to/reference.wav",  # Optional: for voice cloning
+    prompt_text="Reference audio transcript",  # Optional: for voice cloning
+)
 
 # 3. Disable LoRA (use base model only)
 model.set_lora_enabled(False)
@@ -300,23 +305,39 @@ model.set_lora_enabled(False)
 model.set_lora_enabled(True)
 
 # 5. Unload LoRA (reset weights to zero)
-model.reset_lora_weights()
+model.unload_lora()
 
 # 6. Hot-swap to another LoRA
-model.load_lora_weights("/path/to/another_lora_checkpoint")
+loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
+print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
 
 # 7. Get current LoRA weights
 lora_state = model.get_lora_state_dict()
 ```
 
+### Simplified Usage (Auto LoRA Config)
+
+If you only have LoRA weights and don't need a custom config, just provide the path:
+
+```python
+from voxcpm.core import VoxCPM
+
+# Auto-create default LoRAConfig when only lora_weights_path is provided
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",
+    lora_weights_path="/path/to/lora_checkpoint",  # Will auto-create LoRAConfig
+)
+```
+
 ### Method Reference
 
 | Method | Description | torch.compile Compatible |
 |--------|-------------|--------------------------|
-| `load_lora_weights(path)` | Load LoRA weights from file | ✅ |
+| `load_lora(path)` | Load LoRA weights from file | ✅ |
 | `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
-| `reset_lora_weights()` | Reset LoRA weights to initial values | ✅ |
+| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
 | `get_lora_state_dict()` | Get current LoRA weights | ✅ |
+| `lora_enabled` | Property: check if LoRA is configured | ✅ |
 
 ---
 
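For readers wondering what `unload_lora()` resets: in a standard LoRA layer the adapter adds `(alpha / r) * B(A(x))` on top of the frozen base projection, and `B` starts at zero, so a freshly reset adapter contributes nothing. A minimal sketch of that mechanism, using the `lora_A`/`lora_B` naming from the state dict above (this is generic LoRA, not VoxCPM's actual implementation):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Base linear layer plus a low-rank adapter: y = Wx + (alpha/r) * B(A(x))."""

    def __init__(self, base: nn.Linear, r: int = 32, alpha: int = 16):
        super().__init__()
        self.base = base
        self.scaling = alpha / r
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        self.enabled = True  # what set_lora_enabled toggles
        self.reset_lora_weights()

    def reset_lora_weights(self):
        # Zeroing B makes the adapter a no-op: this is the "unload" semantics.
        nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        if self.enabled:
            out = out + self.scaling * self.lora_B(self.lora_A(x))
        return out
```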
@@ -346,7 +367,7 @@ lora_state = model.get_lora_state_dict()
 ### 4. LoRA Not Taking Effect at Inference
 
 - Ensure the inference config matches the training config's LoRA parameters
-- Check the `load_lora_weights` return value - `skipped_keys` should be empty
+- Check the `load_lora()` return value - `skipped_keys` should be empty
 - Verify `set_lora_enabled(True)` is called
 
 ### 5. Checkpoint Loading Errors
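A quick way to act on the second bullet, as a sketch (the checkpoint path is an assumption):

```python
from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    lora_weights_path="/path/to/lora_checkpoint",
)

# skipped should be empty when the inference LoRAConfig matches the checkpoint.
loaded, skipped = model.load_lora("/path/to/lora_checkpoint")
assert not skipped, f"Mismatched LoRA keys (first 5): {skipped[:5]}"
model.set_lora_enabled(True)
```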
scripts/test_voxcpm_ft_infer.py
@@ -3,7 +3,7 @@
 Full finetune inference script (no LoRA).
 
 Checkpoint directory contains complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.),
-can be loaded directly via VoxCPMModel.from_local().
+can be loaded directly via VoxCPM.
 
 Usage:
 
@@ -26,9 +26,8 @@ import argparse
 from pathlib import Path
 
 import soundfile as sf
-import torch
 
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 
 
 def parse_args():
@@ -81,49 +80,52 @@ def parse_args():
         default=600,
         help="Max generation steps",
     )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
 
-    # Load model from checkpoint directory
+    # Load model from checkpoint directory (no denoiser)
     print(f"[FT Inference] Loading model: {args.ckpt_dir}")
-    model = VoxCPMModel.from_local(args.ckpt_dir, optimize=True, training=False)
+    model = VoxCPM.from_pretrained(
+        hf_model_id=args.ckpt_dir,
+        load_denoiser=False,
+        optimize=True,
+    )
 
     # Run inference
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
 
     print(f"[FT Inference] Synthesizing: text='{args.text}'")
     if prompt_wav_path:
         print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
         print(f"[FT Inference] Reference text: {prompt_text}")
 
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_len=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
 
-    # Squeeze and save audio
-    if isinstance(audio, torch.Tensor):
-        audio_np = audio.squeeze(0).cpu().numpy()
-    else:
-        raise TypeError(f"Unexpected return type from model.generate: {type(audio)}")
-
+    # Save audio
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
-    sf.write(str(out_path), audio_np, model.sample_rate)
+    sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
 
-    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
 
 if __name__ == "__main__":
     main()
 
 
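Worth noting in the hunk above: `hf_model_id` doubles as a local checkpoint directory, so a full-finetune checkpoint loads through the same entry point as the Hub model. A minimal sketch of that flow outside the script (paths are assumptions):

```python
import soundfile as sf

from voxcpm.core import VoxCPM

# A local finetune checkpoint directory works where a Hub id would.
model = VoxCPM.from_pretrained(
    hf_model_id="checkpoints/ft_run_01",
    load_denoiser=False,
    optimize=True,
)
audio_np = model.generate(text="Testing the finetuned checkpoint.")
sf.write("ft_test.wav", audio_np, model.tts_model.sample_rate)
```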
scripts/test_voxcpm_lora_infer.py
@@ -25,9 +25,8 @@ import argparse
 from pathlib import Path
 
 import soundfile as sf
-import torch
 
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
 from voxcpm.training.config import load_yaml_config
 
@@ -88,6 +87,11 @@ def parse_args():
         default=600,
         help="Max generation steps",
     )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
     return parser.parse_args()
 
 
@@ -100,123 +104,114 @@ def main():
     lora_cfg_dict = cfg.get("lora", {}) or {}
     lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
 
-    # 2. Load base model (with LoRA structure and torch.compile)
-    print(f"[1/3] Loading base model: {pretrained_path}")
-    model = VoxCPMModel.from_local(
-        pretrained_path,
-        optimize=True,  # compile first, load_lora_weights uses named_parameters for compatibility
-        training=False,
-        lora_config=lora_cfg,
-    )
-
-    # Debug: check DiT param paths after compile
-    dit_params = [n for n, _ in model.named_parameters() if 'feat_decoder' in n and 'lora' in n]
-    print(f"[DEBUG] DiT LoRA param paths after compile (first 3): {dit_params[:3]}")
-
-    # 3. Load LoRA weights (works after compile)
-    ckpt_dir = Path(args.lora_ckpt)
-    if not ckpt_dir.exists():
+    # 2. Check LoRA checkpoint
+    ckpt_dir = args.lora_ckpt
+    if not Path(ckpt_dir).exists():
         raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
 
-    print(f"[2/3] Loading LoRA weights: {ckpt_dir}")
-    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
-    print(f"   Loaded {len(loaded)} parameters")
-    if skipped:
-        print(f"[WARNING] Skipped {len(skipped)} parameters")
-        print(f"   Skipped keys (first 5): {skipped[:5]}")
+    # 3. Load model with LoRA (no denoiser)
+    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"   LoRA weights: {ckpt_dir}")
+    model = VoxCPM.from_pretrained(
+        hf_model_id=pretrained_path,
+        load_denoiser=False,
+        optimize=True,
+        lora_config=lora_cfg,
+        lora_weights_path=ckpt_dir,
+    )
 
     # 4. Synthesize audio
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
 
-    print(f"\n[3/3] Starting synthesis tests...")
+    print(f"\n[2/2] Starting synthesis tests...")
 
     # === Test 1: With LoRA ===
     print(f"\n   [Test 1] Synthesize with LoRA...")
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_len=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     lora_output = out_path.with_stem(out_path.stem + "_with_lora")
-    sf.write(str(lora_output), audio_np, model.sample_rate)
-    print(f"   Saved: {lora_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
+    print(f"   Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
     # === Test 2: Disable LoRA (via set_lora_enabled) ===
     print(f"\n   [Test 2] Disable LoRA (set_lora_enabled=False)...")
     model.set_lora_enabled(False)
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_len=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
-    sf.write(str(disabled_output), audio_np, model.sample_rate)
-    print(f"   Saved: {disabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
+    print(f"   Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
     # === Test 3: Re-enable LoRA ===
     print(f"\n   [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
     model.set_lora_enabled(True)
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_len=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
-    sf.write(str(reenabled_output), audio_np, model.sample_rate)
-    print(f"   Saved: {reenabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
+    print(f"   Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
     # === Test 4: Unload LoRA (reset_lora_weights) ===
-    print(f"\n   [Test 4] Unload LoRA (reset_lora_weights)...")
-    model.reset_lora_weights()
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    print(f"\n   [Test 4] Unload LoRA (unload_lora)...")
+    model.unload_lora()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_len=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
-    sf.write(str(reset_output), audio_np, model.sample_rate)
-    print(f"   Saved: {reset_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
+    print(f"   Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
-    # === Test 5: Hot-reload LoRA (load_lora_weights) ===
-    print(f"\n   [Test 5] Hot-reload LoRA (load_lora_weights)...")
-    loaded, _ = model.load_lora_weights(str(ckpt_dir))
+    # === Test 5: Hot-reload LoRA (load_lora) ===
+    print(f"\n   [Test 5] Hot-reload LoRA (load_lora)...")
+    loaded, skipped = model.load_lora(str(ckpt_dir))
     print(f"   Reloaded {len(loaded)} parameters")
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_len=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
-    sf.write(str(reload_output), audio_np, model.sample_rate)
-    print(f"   Saved: {reload_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
+    print(f"   Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
     print(f"\n[Done] All tests completed!")
     print(f"   - with_lora: {lora_output}")
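The five test blocks above repeat an identical `generate(...)` call; if this script is touched again, a small helper would remove the duplication. A sketch that assumes the script's existing imports and argparse fields:

```python
def synthesize(model, args, prompt_wav_path, prompt_text, out_path, suffix):
    """Run one generate call and save the result with a suffixed file name."""
    audio_np = model.generate(
        text=args.text,
        prompt_wav_path=prompt_wav_path,
        prompt_text=prompt_text,
        cfg_value=args.cfg_value,
        inference_timesteps=args.inference_timesteps,
        max_len=args.max_len,
        normalize=args.normalize,
        denoise=False,
    )
    target = out_path.with_stem(out_path.stem + suffix)
    sf.write(str(target), audio_np, model.tts_model.sample_rate)
    print(f"   Saved: {target}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
    return target
```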
@@ -228,5 +223,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
-
@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
         "ZIPENHANCER_MODEL_PATH", None
     )
 
+    # Build LoRA config if lora_path is provided
+    lora_config = None
+    lora_weights_path = getattr(args, "lora_path", None)
+    if lora_weights_path:
+        from voxcpm.model.voxcpm import LoRAConfig
+        lora_config = LoRAConfig(
+            enable_lm=getattr(args, "lora_enable_lm", True),
+            enable_dit=getattr(args, "lora_enable_dit", True),
+            enable_proj=getattr(args, "lora_enable_proj", False),
+            r=getattr(args, "lora_r", 32),
+            alpha=getattr(args, "lora_alpha", 16),
+            dropout=getattr(args, "lora_dropout", 0.0),
+        )
+        print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
+              f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
+
     # Load from local path if provided
     if getattr(args, "model_path", None):
         try:
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
                 voxcpm_model_path=args.model_path,
                 zipenhancer_model_path=zipenhancer_path,
                 enable_denoiser=not getattr(args, "no_denoiser", False),
+                lora_config=lora_config,
+                lora_weights_path=lora_weights_path,
             )
             print("Model loaded (local).")
             return model
@@ -74,6 +92,8 @@ def load_model(args) -> VoxCPM:
         zipenhancer_model_id=zipenhancer_path,
         cache_dir=getattr(args, "cache_dir", None),
         local_files_only=getattr(args, "local_files_only", False),
+        lora_config=lora_config,
+        lora_weights_path=lora_weights_path,
     )
     print("Model loaded (from_pretrained).")
     return model
@@ -256,6 +276,15 @@ Examples:
     parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
     parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
 
+    # LoRA parameters
+    parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
+    parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
+    parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
+    parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
+    parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
+    parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
+    parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
+
     return parser
 
 
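One caveat about the new flags: `--lora-enable-lm` and `--lora-enable-dit` combine `action="store_true"` with `default=True`, so passing or omitting them both yields True and they can never be switched off from the command line. A sketch of one possible fix, not part of this commit (requires Python 3.9+):

```python
import argparse

parser = argparse.ArgumentParser()
# BooleanOptionalAction generates paired --flag / --no-flag options,
# so a True default can still be overridden on the command line.
parser.add_argument(
    "--lora-enable-lm",
    action=argparse.BooleanOptionalAction,
    default=True,
    help="Apply LoRA to LM layers (default: True)",
)

args = parser.parse_args(["--no-lora-enable-lm"])
print(args.lora_enable_lm)  # False
```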
src/voxcpm/core.py
@@ -2,9 +2,9 @@ import os
 import re
 import tempfile
 import numpy as np
-from typing import Generator
+from typing import Generator, Optional
 from huggingface_hub import snapshot_download
-from .model.voxcpm import VoxCPMModel
+from .model.voxcpm import VoxCPMModel, LoRAConfig
 
 class VoxCPM:
     def __init__(self,
@@ -12,6 +12,8 @@ class VoxCPM:
             zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
             enable_denoiser : bool = True,
             optimize: bool = True,
+            lora_config: Optional[LoRAConfig] = None,
+            lora_weights_path: Optional[str] = None,
         ):
         """Initialize VoxCPM TTS pipeline.
 
@@ -23,9 +25,30 @@ class VoxCPM:
                 id or local path. If None, denoiser will not be initialized.
             enable_denoiser: Whether to initialize the denoiser pipeline.
             optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
+            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
+                provided without lora_config, a default config will be created.
+            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
+                containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
         """
         print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
-        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
+
+        # If lora_weights_path is provided but no lora_config, create a default one
+        if lora_weights_path is not None and lora_config is None:
+            lora_config = LoRAConfig(
+                enable_lm=True,
+                enable_dit=True,
+                enable_proj=False,
+            )
+            print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
+
+        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
+
+        # Load LoRA weights if path is provided
+        if lora_weights_path is not None:
+            print(f"Loading LoRA weights from: {lora_weights_path}")
+            loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
+            print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
+
         self.text_normalizer = None
         if enable_denoiser and zipenhancer_model_path is not None:
             from .zipenhancer import ZipEnhancer
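A quick sketch of the two constructor paths this enables (the local model path is an assumption):

```python
from voxcpm.core import VoxCPM
from voxcpm.model.voxcpm import LoRAConfig

# Explicit config plus weights:
model = VoxCPM(
    voxcpm_model_path="/path/to/VoxCPM1.5",
    lora_config=LoRAConfig(enable_lm=True, enable_dit=True, enable_proj=False),
    lora_weights_path="/path/to/lora_checkpoint",
)

# Weights only: a default LoRAConfig (enable_lm=True, enable_dit=True,
# enable_proj=False) is auto-created before the weights are loaded.
model = VoxCPM(
    voxcpm_model_path="/path/to/VoxCPM1.5",
    lora_weights_path="/path/to/lora_checkpoint",
)
```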
@@ -40,12 +63,14 @@ class VoxCPM:
 
     @classmethod
     def from_pretrained(cls,
-            hf_model_id: str = "openbmb/VoxCPM-0.5B",
+            hf_model_id: str = "openbmb/VoxCPM1.5",
             load_denoiser: bool = True,
             zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
             cache_dir: str = None,
             local_files_only: bool = False,
             optimize: bool = True,
+            lora_config: Optional[LoRAConfig] = None,
+            lora_weights_path: Optional[str] = None,
             **kwargs,
         ):
         """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -59,6 +84,12 @@ class VoxCPM:
             cache_dir: Custom cache directory for the snapshot.
             local_files_only: If True, only use local files and do not attempt
                 to download.
+            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
+                provided without lora_config, a default config will be created with
+                enable_lm=True and enable_dit=True.
+            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
+                containing lora_weights.ckpt). If provided, LoRA weights will be loaded
+                after model initialization.
         Kwargs:
             Additional keyword arguments passed to the ``VoxCPM`` constructor.
 
@@ -90,6 +121,8 @@ class VoxCPM:
             zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
             enable_denoiser=load_denoiser,
             optimize=optimize,
+            lora_config=lora_config,
+            lora_weights_path=lora_weights_path,
             **kwargs,
         )
 
@@ -105,9 +138,10 @@ class VoxCPM:
             prompt_text : str = None,
             cfg_value : float = 2.0,
             inference_timesteps : int = 10,
-            max_length : int = 4096,
-            normalize : bool = True,
-            denoise : bool = True,
+            min_len : int = 2,
+            max_len : int = 4096,
+            normalize : bool = False,
+            denoise : bool = False,
             retry_badcase : bool = True,
             retry_badcase_max_times : int = 3,
             retry_badcase_ratio_threshold : float = 6.0,
@@ -127,7 +161,7 @@ class VoxCPM:
             prompt_text: Text content corresponding to the prompt audio.
             cfg_value: Guidance scale for the generation model.
             inference_timesteps: Number of inference steps.
-            max_length: Maximum token length during generation.
+            max_len: Maximum token length during generation.
             normalize: Whether to run text normalization before generation.
             denoise: Whether to denoise the prompt audio if a denoiser is
                 available.
@@ -177,8 +211,8 @@ class VoxCPM:
         generate_result = self.tts_model._generate_with_prompt_cache(
             target_text=text,
             prompt_cache=fixed_prompt_cache,
-            min_len=2,
-            max_len=max_length,
+            min_len=min_len,
+            max_len=max_len,
             inference_timesteps=inference_timesteps,
             cfg_value=cfg_value,
             retry_badcase=retry_badcase,
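Both bounds now thread through from the public API into `_generate_with_prompt_cache`; the old `max_length` keyword is gone. A short call sketch (model id assumed):

```python
from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(hf_model_id="openbmb/VoxCPM1.5")

# min_len was previously hard-coded to 2; max_len replaces max_length.
wav = model.generate(
    text="A bound-controlled synthesis test.",
    min_len=2,
    max_len=2048,
)
```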
@@ -196,3 +230,51 @@ class VoxCPM:
                 os.unlink(temp_prompt_wav_path)
             except OSError:
                 pass
+
+    # ------------------------------------------------------------------ #
+    # LoRA Interface (delegated to VoxCPMModel)
+    # ------------------------------------------------------------------ #
+    def load_lora(self, lora_weights_path: str) -> tuple:
+        """Load LoRA weights from a checkpoint file.
+
+        Args:
+            lora_weights_path: Path to LoRA weights (.pth file or directory
+                containing lora_weights.ckpt).
+
+        Returns:
+            tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
+
+        Raises:
+            RuntimeError: If model was not initialized with LoRA config.
+        """
+        if self.tts_model.lora_config is None:
+            raise RuntimeError(
+                "Cannot load LoRA weights: model was not initialized with LoRA config. "
+                "Please reinitialize with lora_config or lora_weights_path parameter."
+            )
+        return self.tts_model.load_lora_weights(lora_weights_path)
+
+    def unload_lora(self):
+        """Unload LoRA by resetting all LoRA weights to their initial state (effectively disabling LoRA)."""
+        self.tts_model.reset_lora_weights()
+
+    def set_lora_enabled(self, enabled: bool):
+        """Enable or disable LoRA layers without unloading weights.
+
+        Args:
+            enabled: If True, LoRA layers are active; if False, only the base model is used.
+        """
+        self.tts_model.set_lora_enabled(enabled)
+
+    def get_lora_state_dict(self) -> dict:
+        """Get the current LoRA parameters as a state dict.
+
+        Returns:
+            dict: State dict containing all LoRA parameters (lora_A, lora_B).
+        """
+        return self.tts_model.get_lora_state_dict()
+
+    @property
+    def lora_enabled(self) -> bool:
+        """Check if LoRA is currently configured."""
+        return self.tts_model.lora_config is not None
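Taken together, the delegated interface supports a full load, toggle, and swap cycle from the wrapper, as in this sketch (adapter paths are assumptions):

```python
from voxcpm.core import VoxCPM

model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    lora_weights_path="/path/to/speaker_a_lora",
)
assert model.lora_enabled  # a default LoRAConfig was auto-created

wav_a = model.generate(text="Speaker A adapter active.")

model.set_lora_enabled(False)                 # bypass adapters, keep weights
wav_base = model.generate(text="Base model only.")
model.set_lora_enabled(True)

loaded, skipped = model.load_lora("/path/to/speaker_b_lora")  # hot-swap
wav_b = model.generate(text="Speaker B adapter active.")

state = model.get_lora_state_dict()           # snapshot current adapter weights
model.unload_lora()                           # reset adapters to the zero state
```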