Mirror of https://github.com/OpenBMB/VoxCPM
Modify LoRA inference API
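In short: the script drops the low-level VoxCPMModel.from_local / load_lora_weights flow and builds everything through the VoxCPM.from_pretrained wrapper, which now accepts lora_config and lora_weights_path directly; generate() is renamed from target_text/max_len to text/max_length, gains normalize and denoise flags, and returns a numpy array; the sample rate now lives at model.tts_model.sample_rate. A minimal sketch of the new flow, pieced together from the diff below (paths and numeric values are illustrative placeholders, not from the commit):

    import soundfile as sf

    from voxcpm.core import VoxCPM
    from voxcpm.model.voxcpm import LoRAConfig
    from voxcpm.training.config import load_yaml_config

    # Placeholder config path; the diff reads the "lora" section as a dict.
    cfg = load_yaml_config("configs/lora_finetune.yaml")
    lora_cfg_dict = cfg.get("lora", {}) or {}
    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None

    model = VoxCPM.from_pretrained(
        hf_model_id="path/to/pretrained",        # placeholder: local dir or HF id
        load_denoiser=False,                     # the LoRA tests skip the denoiser
        optimize=True,
        lora_config=lora_cfg,
        lora_weights_path="path/to/lora_ckpt",   # placeholder checkpoint dir
    )

    # generate() now takes text=/max_length= (was target_text=/max_len=) and
    # returns a numpy array directly, so no squeeze()/cpu() post-processing.
    audio_np = model.generate(
        text="Some target text.",
        prompt_wav_path=None,
        prompt_text=None,
        cfg_value=2.0,                           # illustrative value
        inference_timesteps=10,                  # illustrative value
        max_length=600,
        normalize=True,
        denoise=False,
    )
    sf.write("out.wav", audio_np, model.tts_model.sample_rate)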
@@ -25,9 +25,8 @@ import argparse
 from pathlib import Path

 import soundfile as sf
 import torch

-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
 from voxcpm.training.config import load_yaml_config

@@ -88,6 +87,11 @@ def parse_args():
         default=600,
         help="Max generation steps",
     )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
     return parser.parse_args()


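The new --normalize flag is a plain store_true switch that is threaded through to generate(normalize=...) below. A tiny standalone sketch of the same argparse pattern:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--normalize", action="store_true",
                        help="Enable text normalization")
    args = parser.parse_args([])   # empty list: defaults only
    print(args.normalize)          # False unless --normalize is passed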
@@ -100,123 +104,114 @@ def main():
     lora_cfg_dict = cfg.get("lora", {}) or {}
     lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None

-    # 2. Load base model (with LoRA structure and torch.compile)
-    print(f"[1/3] Loading base model: {pretrained_path}")
-    model = VoxCPMModel.from_local(
-        pretrained_path,
-        optimize=True,  # compile first, load_lora_weights uses named_parameters for compatibility
-        training=False,
-        lora_config=lora_cfg,
-    )
-
-    # Debug: check DiT param paths after compile
-    dit_params = [n for n, _ in model.named_parameters() if 'feat_decoder' in n and 'lora' in n]
-    print(f"[DEBUG] DiT LoRA param paths after compile (first 3): {dit_params[:3]}")
-
-    # 3. Load LoRA weights (works after compile)
-    ckpt_dir = Path(args.lora_ckpt)
-    if not ckpt_dir.exists():
+    # 2. Check LoRA checkpoint
+    ckpt_dir = args.lora_ckpt
+    if not Path(ckpt_dir).exists():
         raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")

-    print(f"[2/3] Loading LoRA weights: {ckpt_dir}")
-    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
-    print(f" Loaded {len(loaded)} parameters")
-    if skipped:
-        print(f"[WARNING] Skipped {len(skipped)} parameters")
-        print(f" Skipped keys (first 5): {skipped[:5]}")
+
+    # 3. Load model with LoRA (no denoiser)
+    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f" LoRA weights: {ckpt_dir}")
+    model = VoxCPM.from_pretrained(
+        hf_model_id=pretrained_path,
+        load_denoiser=False,
+        optimize=True,
+        lora_config=lora_cfg,
+        lora_weights_path=ckpt_dir,
+    )

     # 4. Synthesize audio
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)

-    print(f"\n[3/3] Starting synthesis tests...")
+    print(f"\n[2/2] Starting synthesis tests...")

     # === Test 1: With LoRA ===
     print(f"\n [Test 1] Synthesize with LoRA...")
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     lora_output = out_path.with_stem(out_path.stem + "_with_lora")
-    sf.write(str(lora_output), audio_np, model.sample_rate)
-    print(f" Saved: {lora_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
+    print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")

     # === Test 2: Disable LoRA (via set_lora_enabled) ===
     print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
     model.set_lora_enabled(False)
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
-    sf.write(str(disabled_output), audio_np, model.sample_rate)
-    print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
+    print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")

     # === Test 3: Re-enable LoRA ===
     print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
     model.set_lora_enabled(True)
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
-    sf.write(str(reenabled_output), audio_np, model.sample_rate)
-    print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
+    print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")

-    # === Test 4: Unload LoRA (reset_lora_weights) ===
-    print(f"\n [Test 4] Unload LoRA (reset_lora_weights)...")
-    model.reset_lora_weights()
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    print(f"\n [Test 4] Unload LoRA (unload_lora)...")
+    model.unload_lora()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
-    sf.write(str(reset_output), audio_np, model.sample_rate)
-    print(f" Saved: {reset_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
+    print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")

-    # === Test 5: Hot-reload LoRA (load_lora_weights) ===
-    print(f"\n [Test 5] Hot-reload LoRA (load_lora_weights)...")
-    loaded, _ = model.load_lora_weights(str(ckpt_dir))
+    # === Test 5: Hot-reload LoRA (load_lora) ===
+    print(f"\n [Test 5] Hot-reload LoRA (load_lora)...")
+    loaded, skipped = model.load_lora(str(ckpt_dir))
     print(f" Reloaded {len(loaded)} parameters")
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
     reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
-    sf.write(str(reload_output), audio_np, model.sample_rate)
-    print(f" Saved: {reload_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
+    print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")

     print(f"\n[Done] All tests completed!")
     print(f" - with_lora: {lora_output}")
@@ -228,5 +223,3 @@ def main():

 if __name__ == "__main__":
     main()
-
-
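The five tests above also exercise the wrapper's runtime LoRA controls: set_lora_enabled() toggles the adapters in place, unload_lora() drops the weights, and load_lora() hot-reloads them from disk. A condensed sketch of that sequence, assuming `model` was built via VoxCPM.from_pretrained(..., lora_config=..., lora_weights_path=...) as in the sketch near the top (the checkpoint path is a placeholder):

    # `model` built as above; path below is a placeholder, not from the commit.
    ckpt_dir = "path/to/lora_ckpt"

    # Toggle the adapters without touching the loaded weights.
    model.set_lora_enabled(False)   # synthesize with the base weights
    model.set_lora_enabled(True)    # back to the fine-tuned voice

    # Drop the LoRA weights entirely, then hot-reload them from disk.
    model.unload_lora()
    loaded, skipped = model.load_lora(str(ckpt_dir))
    print(f"Reloaded {len(loaded)} parameters, skipped {len(skipped)} keys")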