Modify LoRA inference API
@@ -3,7 +3,7 @@
 Full finetune inference script (no LoRA).
 
 Checkpoint directory contains complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.),
-can be loaded directly via VoxCPMModel.from_local().
+can be loaded directly via VoxCPM.
 
 Usage:
 
@@ -26,9 +26,8 @@ import argparse
 from pathlib import Path
 
 import soundfile as sf
 import torch
-
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 
 
 def parse_args():
@@ -81,49 +80,52 @@ def parse_args():
         default=600,
         help="Max generation steps",
     )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
 
-    # Load model from checkpoint directory
+    # Load model from checkpoint directory (no denoiser)
     print(f"[FT Inference] Loading model: {args.ckpt_dir}")
-    model = VoxCPMModel.from_local(args.ckpt_dir, optimize=True, training=False)
+    model = VoxCPM.from_pretrained(
+        hf_model_id=args.ckpt_dir,
+        load_denoiser=False,
+        optimize=True,
+    )
 
     # Run inference
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
 
     print(f"[FT Inference] Synthesizing: text='{args.text}'")
     if prompt_wav_path:
         print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
         print(f"[FT Inference] Reference text: {prompt_text}")
 
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-
-    # Squeeze and save audio
-    if isinstance(audio, torch.Tensor):
-        audio_np = audio.squeeze(0).cpu().numpy()
-    else:
-        raise TypeError(f"Unexpected return type from model.generate: {type(audio)}")
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
 
     # Save audio
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
-    sf.write(str(out_path), audio_np, model.sample_rate)
+    sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
 
-    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
 
 if __name__ == "__main__":
     main()
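
For reference, a minimal standalone sketch of the call pattern this commit migrates to. The checkpoint path, text, and sampling values below are illustrative placeholders; the parameter names and the VoxCPM.from_pretrained / model.generate call shapes are taken from the diff above:

    import soundfile as sf

    from voxcpm.core import VoxCPM

    # Load a full-finetune checkpoint directory, skipping the denoiser,
    # exactly as the updated script does.
    model = VoxCPM.from_pretrained(
        hf_model_id="checkpoints/ft_run",  # illustrative path
        load_denoiser=False,
        optimize=True,
    )

    # With no reference audio, prompt_wav_path and prompt_text stay None,
    # matching the script's new fallback behavior.
    audio_np = model.generate(
        text="Hello from the finetuned model.",  # illustrative text
        prompt_wav_path=None,
        prompt_text=None,
        cfg_value=2.0,              # illustrative value
        inference_timesteps=10,     # illustrative value
        max_length=600,             # the script's default max generation steps
        normalize=False,
        denoise=False,
    )

    # generate returns a waveform array that soundfile can write directly;
    # the sample rate now lives on model.tts_model.
    sf.write("out.wav", audio_np, model.tts_model.sample_rate)

Note that generate now returns the waveform array directly (the removed isinstance/squeeze branch is no longer needed), and the sample rate is read via model.tts_model.sample_rate rather than model.sample_rate.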