From 400f47a516f8869b4c0a04a4857d9fe7057a04b8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E9=91=AB?=
Date: Fri, 5 Dec 2025 22:22:13 +0800
Subject: [PATCH] Modify LoRA inference API

---
 docs/finetune.md                  |  49 +++++---
 scripts/test_voxcpm_ft_infer.py   |  54 ++++-----
 scripts/test_voxcpm_lora_infer.py | 183 ++++++++++++++----------------
 src/voxcpm/cli.py                 |  29 +++++
 src/voxcpm/core.py                |  89 ++++++++++++++-
 5 files changed, 265 insertions(+), 139 deletions(-)

diff --git a/docs/finetune.md b/docs/finetune.md
index 92b5e03..70a24a3 100644
--- a/docs/finetune.md
+++ b/docs/finetune.md
@@ -271,10 +271,10 @@ LoRA supports dynamic loading, unloading, and switching at inference time withou
 ### API Reference
 
 ```python
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
 
-# 1. Load model with LoRA structure
+# 1. Load model with LoRA structure and weights
 lora_cfg = LoRAConfig(
     enable_lm=True,
     enable_dit=True,
@@ -283,15 +283,20 @@ lora_cfg = LoRAConfig(
     target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
     target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
 )
-model = VoxCPMModel.from_local(
-    pretrained_path,
-    optimize=True,  # Enable torch.compile acceleration
-    lora_config=lora_cfg
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",  # or local path
+    load_denoiser=False,  # Optional: disable denoiser for faster loading
+    optimize=True,  # Enable torch.compile acceleration
+    lora_config=lora_cfg,
+    lora_weights_path="/path/to/lora_checkpoint",
 )
 
-# 2. Load LoRA weights (works after torch.compile)
-loaded, skipped = model.load_lora_weights("/path/to/lora_checkpoint")
-print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
+# 2. Generate audio
+audio = model.generate(
+    text="Hello, this is the LoRA fine-tuned result.",
+    prompt_wav_path="/path/to/reference.wav",  # Optional: for voice cloning
+    prompt_text="Reference audio transcript",  # Optional: for voice cloning
+)
 
 # 3. Disable LoRA (use base model only)
 model.set_lora_enabled(False)
@@ -300,23 +305,39 @@ model.set_lora_enabled(False)
 model.set_lora_enabled(True)
 
 # 5. Unload LoRA (reset weights to zero)
-model.reset_lora_weights()
+model.unload_lora()
 
 # 6. Hot-swap to another LoRA
-model.load_lora_weights("/path/to/another_lora_checkpoint")
+loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")
+print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")
 
 # 7. Get current LoRA weights
 lora_state = model.get_lora_state_dict()
 ```
 
+### Simplified Usage (Auto LoRA Config)
+
+If you only have LoRA weights and don't need a custom config, just provide the path:
+
+```python
+from voxcpm.core import VoxCPM
+
+# Auto-create default LoRAConfig when only lora_weights_path is provided
+model = VoxCPM.from_pretrained(
+    hf_model_id="openbmb/VoxCPM1.5",
+    lora_weights_path="/path/to/lora_checkpoint",  # Will auto-create LoRAConfig
+)
+```
+
 ### Method Reference
 
 | Method | Description | torch.compile Compatible |
 |--------|-------------|--------------------------|
-| `load_lora_weights(path)` | Load LoRA weights from file | ✅ |
+| `load_lora(path)` | Load LoRA weights from file | ✅ |
 | `set_lora_enabled(bool)` | Enable/disable LoRA | ✅ |
-| `reset_lora_weights()` | Reset LoRA weights to initial values | ✅ |
+| `unload_lora()` | Reset LoRA weights to initial values | ✅ |
 | `get_lora_state_dict()` | Get current LoRA weights | ✅ |
+| `lora_enabled` | Property: check if LoRA is configured | ✅ |
 
 ---
 
@@ -346,7 +367,7 @@ lora_state = model.get_lora_state_dict()
 ### 4. LoRA Not Taking Effect at Inference
 
 - Ensure inference config matches training config LoRA parameters
-- Check `load_lora_weights` return value - `skipped_keys` should be empty
+- Check the `load_lora()` return value: `skipped_keys` should be empty
 - Verify `set_lora_enabled(True)` is called
 
 ### 5. Checkpoint Loading Errors
diff --git a/scripts/test_voxcpm_ft_infer.py b/scripts/test_voxcpm_ft_infer.py
index 3484a55..9b2c9e1 100644
--- a/scripts/test_voxcpm_ft_infer.py
+++ b/scripts/test_voxcpm_ft_infer.py
@@ -3,7 +3,7 @@ Full finetune inference script (no LoRA).
 
 Checkpoint directory contains complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.),
-can be loaded directly via VoxCPMModel.from_local().
+can be loaded directly via VoxCPM.from_pretrained().
 
 Usage:
@@ -26,9 +26,8 @@
 import argparse
 from pathlib import Path
 
 import soundfile as sf
-import torch
 
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 
 
 def parse_args():
@@ -81,49 +80,52 @@
         default=600,
         help="Max generation steps",
     )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
 
-    # Load model from checkpoint directory
+    # Load model from checkpoint directory (no denoiser)
     print(f"[FT Inference] Loading model: {args.ckpt_dir}")
-    model = VoxCPMModel.from_local(args.ckpt_dir, optimize=True, training=False)
+    model = VoxCPM.from_pretrained(
+        hf_model_id=args.ckpt_dir,
+        load_denoiser=False,
+        optimize=True,
+    )
 
     # Run inference
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
 
     print(f"[FT Inference] Synthesizing: text='{args.text}'")
     if prompt_wav_path:
         print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
         print(f"[FT Inference] Reference text: {prompt_text}")
 
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-
-    # Squeeze and save audio
-    if isinstance(audio, torch.Tensor):
-        audio_np = audio.squeeze(0).cpu().numpy()
-    else:
-        raise TypeError(f"Unexpected return type from model.generate: {type(audio)}")
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
 
+    # Save audio
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
-    sf.write(str(out_path), audio_np, model.sample_rate)
+    sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
 
-    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
 
 if __name__ == "__main__":
     main()
-
-
diff --git a/scripts/test_voxcpm_lora_infer.py b/scripts/test_voxcpm_lora_infer.py
index 7f0bf55..47a5a86 100644
--- a/scripts/test_voxcpm_lora_infer.py
+++ b/scripts/test_voxcpm_lora_infer.py
@@ -25,9 +25,8 @@
 import argparse
 from pathlib import Path
 
 import soundfile as sf
-import torch
 
-from voxcpm.model import VoxCPMModel
+from voxcpm.core import VoxCPM
 from voxcpm.model.voxcpm import LoRAConfig
 from voxcpm.training.config import load_yaml_config
 
@@ -88,6 +87,11 @@ def parse_args():
         default=600,
         help="Max generation steps",
     )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        help="Enable text normalization",
+    )
     return parser.parse_args()
 
 
@@ -100,123 +104,114 @@ def main():
     lora_cfg_dict = cfg.get("lora", {}) or {}
     lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
 
-    # 2. Load base model (with LoRA structure and torch.compile)
-    print(f"[1/3] Loading base model: {pretrained_path}")
-    model = VoxCPMModel.from_local(
-        pretrained_path,
-        optimize=True,  # compile first, load_lora_weights uses named_parameters for compatibility
-        training=False,
-        lora_config=lora_cfg,
-    )
-
-    # Debug: check DiT param paths after compile
-    dit_params = [n for n, _ in model.named_parameters() if 'feat_decoder' in n and 'lora' in n]
-    print(f"[DEBUG] DiT LoRA param paths after compile (first 3): {dit_params[:3]}")
-
-    # 3. Load LoRA weights (works after compile)
-    ckpt_dir = Path(args.lora_ckpt)
-    if not ckpt_dir.exists():
+    # 2. Check LoRA checkpoint
+    ckpt_dir = args.lora_ckpt
+    if not Path(ckpt_dir).exists():
         raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
-
-    print(f"[2/3] Loading LoRA weights: {ckpt_dir}")
-    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
-    print(f"    Loaded {len(loaded)} parameters")
-    if skipped:
-        print(f"[WARNING] Skipped {len(skipped)} parameters")
-        print(f"    Skipped keys (first 5): {skipped[:5]}")
+
+    # 3. Load model with LoRA (no denoiser)
+    print(f"[1/2] Loading model with LoRA: {pretrained_path}")
+    print(f"    LoRA weights: {ckpt_dir}")
+    model = VoxCPM.from_pretrained(
+        hf_model_id=pretrained_path,
+        load_denoiser=False,
+        optimize=True,
+        lora_config=lora_cfg,
+        lora_weights_path=ckpt_dir,
+    )
 
     # 4. Synthesize audio
-    prompt_wav_path = args.prompt_audio or ""
-    prompt_text = args.prompt_text or ""
+    prompt_wav_path = args.prompt_audio if args.prompt_audio else None
+    prompt_text = args.prompt_text if args.prompt_text else None
 
     out_path = Path(args.output)
     out_path.parent.mkdir(parents=True, exist_ok=True)
 
-    print(f"\n[3/3] Starting synthesis tests...")
+    print(f"\n[2/2] Starting synthesis tests...")
 
     # === Test 1: With LoRA ===
     print(f"\n  [Test 1] Synthesize with LoRA...")
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
 
     lora_output = out_path.with_stem(out_path.stem + "_with_lora")
-    sf.write(str(lora_output), audio_np, model.sample_rate)
-    print(f"    Saved: {lora_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
     # === Test 2: Disable LoRA (via set_lora_enabled) ===
     print(f"\n  [Test 2] Disable LoRA (set_lora_enabled=False)...")
     model.set_lora_enabled(False)
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
 
     disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
-    sf.write(str(disabled_output), audio_np, model.sample_rate)
-    print(f"    Saved: {disabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
     # === Test 3: Re-enable LoRA ===
     print(f"\n  [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
     model.set_lora_enabled(True)
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
 
     reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
-    sf.write(str(reenabled_output), audio_np, model.sample_rate)
-    print(f"    Saved: {reenabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
     # === Test 4: Unload LoRA (reset_lora_weights) ===
-    print(f"\n  [Test 4] Unload LoRA (reset_lora_weights)...")
-    model.reset_lora_weights()
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    print(f"\n  [Test 4] Unload LoRA (unload_lora)...")
+    model.unload_lora()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
 
     reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
-    sf.write(str(reset_output), audio_np, model.sample_rate)
-    print(f"    Saved: {reset_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
-    # === Test 5: Hot-reload LoRA (load_lora_weights) ===
-    print(f"\n  [Test 5] Hot-reload LoRA (load_lora_weights)...")
-    loaded, _ = model.load_lora_weights(str(ckpt_dir))
+    # === Test 5: Hot-reload LoRA (load_lora) ===
+    print(f"\n  [Test 5] Hot-reload LoRA (load_lora)...")
+    loaded, skipped = model.load_lora(str(ckpt_dir))
     print(f"    Reloaded {len(loaded)} parameters")
-    with torch.inference_mode():
-        audio = model.generate(
-            target_text=args.text,
-            prompt_text=prompt_text,
-            prompt_wav_path=prompt_wav_path,
-            max_len=args.max_len,
-            inference_timesteps=args.inference_timesteps,
-            cfg_value=args.cfg_value,
-        )
-    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
+    audio_np = model.generate(
+        text=args.text,
+        prompt_wav_path=prompt_wav_path,
+        prompt_text=prompt_text,
+        cfg_value=args.cfg_value,
+        inference_timesteps=args.inference_timesteps,
+        max_length=args.max_len,
+        normalize=args.normalize,
+        denoise=False,
+    )
 
     reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
-    sf.write(str(reload_output), audio_np, model.sample_rate)
-    print(f"    Saved: {reload_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")
+    sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
+    print(f"    Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s")
 
     print(f"\n[Done] All tests completed!")
     print(f"  - with_lora: {lora_output}")
@@ -228,5 +223,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-
-
diff --git a/src/voxcpm/cli.py b/src/voxcpm/cli.py
index 2d5f1a4..c724664 100644
--- a/src/voxcpm/cli.py
+++ b/src/voxcpm/cli.py
@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
         "ZIPENHANCER_MODEL_PATH", None
     )
 
+    # Build LoRA config if lora_path is provided
+    lora_config = None
+    lora_weights_path = getattr(args, "lora_path", None)
+    if lora_weights_path:
+        from voxcpm.model.voxcpm import LoRAConfig
+        lora_config = LoRAConfig(
+            enable_lm=getattr(args, "lora_enable_lm", True),
+            enable_dit=getattr(args, "lora_enable_dit", True),
+            enable_proj=getattr(args, "lora_enable_proj", False),
+            r=getattr(args, "lora_r", 32),
+            alpha=getattr(args, "lora_alpha", 16),
+            dropout=getattr(args, "lora_dropout", 0.0),
+        )
+        print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
+              f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
+
     # Load from local path if provided
     if getattr(args, "model_path", None):
         try:
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
                 voxcpm_model_path=args.model_path,
                 zipenhancer_model_path=zipenhancer_path,
                 enable_denoiser=not getattr(args, "no_denoiser", False),
+                lora_config=lora_config,
+                lora_weights_path=lora_weights_path,
             )
             print("Model loaded (local).")
             return model
@@ -74,6 +92,8 @@ def load_model(args) -> VoxCPM:
             zipenhancer_model_id=zipenhancer_path,
             cache_dir=getattr(args, "cache_dir", None),
             local_files_only=getattr(args, "local_files_only", False),
+            lora_config=lora_config,
+            lora_weights_path=lora_weights_path,
         )
         print("Model loaded (from_pretrained).")
         return model
@@ -256,6 +276,15 @@ Examples:
     parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
     parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
+    # LoRA parameters
+    parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
+    parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
+    parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
+    parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
+    parser.add_argument("--lora-enable-lm", action=argparse.BooleanOptionalAction, default=True, help="Apply LoRA to LM layers (default: True; disable with --no-lora-enable-lm)")
+    parser.add_argument("--lora-enable-dit", action=argparse.BooleanOptionalAction, default=True, help="Apply LoRA to DiT layers (default: True; disable with --no-lora-enable-dit)")
+    parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
+
     return parser
 
 
diff --git a/src/voxcpm/core.py b/src/voxcpm/core.py
index 514e50e..a2c4290 100644
--- a/src/voxcpm/core.py
+++ b/src/voxcpm/core.py
@@ -2,9 +2,9 @@ import os
 import re
 import tempfile
 import numpy as np
-from typing import Generator
+from typing import Generator, Optional
 from huggingface_hub import snapshot_download
-from .model.voxcpm import VoxCPMModel
+from .model.voxcpm import VoxCPMModel, LoRAConfig
 
 
 class VoxCPM:
     def __init__(self,
@@ -12,6 +12,8 @@ class VoxCPM:
             zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
             enable_denoiser : bool = True,
             optimize: bool = True,
+            lora_config: Optional[LoRAConfig] = None,
+            lora_weights_path: Optional[str] = None,
         ):
         """Initialize VoxCPM TTS pipeline.
 
@@ -23,9 +25,30 @@ class VoxCPM:
                 id or local path. If None, denoiser will not be initialized.
             enable_denoiser: Whether to initialize the denoiser pipeline.
             optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
+            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
+                provided without lora_config, a default config will be created.
+            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
+                containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
         """
         print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}")
-        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize)
+
+        # If lora_weights_path is provided but no lora_config, create a default one
+        if lora_weights_path is not None and lora_config is None:
+            lora_config = LoRAConfig(
+                enable_lm=True,
+                enable_dit=True,
+                enable_proj=False,
+            )
+            print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}")
+
+        self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
+
+        # Load LoRA weights if path is provided
+        if lora_weights_path is not None:
+            print(f"Loading LoRA weights from: {lora_weights_path}")
+            loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
+            print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}")
+
         self.text_normalizer = None
         if enable_denoiser and zipenhancer_model_path is not None:
             from .zipenhancer import ZipEnhancer
@@ -46,6 +69,8 @@ class VoxCPM:
         cache_dir: str = None,
         local_files_only: bool = False,
         optimize: bool = True,
+        lora_config: Optional[LoRAConfig] = None,
+        lora_weights_path: Optional[str] = None,
         **kwargs,
     ):
         """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -59,6 +84,12 @@ class VoxCPM:
             cache_dir: Custom cache directory for the snapshot.
             local_files_only: If True, only use local files and do not
                 attempt to download.
+            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
+                provided without lora_config, a default config will be created with
+                enable_lm=True and enable_dit=True.
+            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
+                containing lora_weights.ckpt). If provided, LoRA weights will be loaded
+                after model initialization.
 
         Kwargs:
             Additional keyword arguments passed to the ``VoxCPM`` constructor.
@@ -90,6 +121,8 @@ class VoxCPM:
             zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
             enable_denoiser=load_denoiser,
             optimize=optimize,
+            lora_config=lora_config,
+            lora_weights_path=lora_weights_path,
             **kwargs,
         )
 
@@ -196,4 +229,52 @@ class VoxCPM:
             try:
                 os.unlink(temp_prompt_wav_path)
             except OSError:
-                pass
\ No newline at end of file
+                pass
+
+    # ------------------------------------------------------------------ #
+    # LoRA Interface (delegated to VoxCPMModel)
+    # ------------------------------------------------------------------ #
+    def load_lora(self, lora_weights_path: str) -> tuple:
+        """Load LoRA weights from a checkpoint file.
+
+        Args:
+            lora_weights_path: Path to LoRA weights (.pth file or directory
+                containing lora_weights.ckpt).
+
+        Returns:
+            tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.
+
+        Raises:
+            RuntimeError: If model was not initialized with LoRA config.
+        """
+        if self.tts_model.lora_config is None:
+            raise RuntimeError(
+                "Cannot load LoRA weights: model was not initialized with LoRA config. "
+                "Please reinitialize with lora_config or lora_weights_path parameter."
+            )
+        return self.tts_model.load_lora_weights(lora_weights_path)
+
+    def unload_lora(self):
+        """Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
+        self.tts_model.reset_lora_weights()
+
+    def set_lora_enabled(self, enabled: bool):
+        """Enable or disable LoRA layers without unloading weights.
+
+        Args:
+            enabled: If True, LoRA layers are active; if False, only base model is used.
+        """
+        self.tts_model.set_lora_enabled(enabled)
+
+    def get_lora_state_dict(self) -> dict:
+        """Get current LoRA parameters state dict.
+
+        Returns:
+            dict: State dict containing all LoRA parameters (lora_A, lora_B).
+        """
+        return self.tts_model.get_lora_state_dict()
+
+    @property
+    def lora_enabled(self) -> bool:
+        """Check if LoRA is currently configured."""
+        return self.tts_model.lora_config is not None
\ No newline at end of file
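
For reference, a minimal end-to-end sketch of the interface this patch introduces, using only the calls added above. The model id comes from the docs example; the checkpoint paths, output filenames, and the synthesized sentence are hypothetical placeholders.

```python
from voxcpm.core import VoxCPM
import soundfile as sf

# Load the base model plus a LoRA adapter in one call. A default LoRAConfig
# is auto-created because only lora_weights_path is given (see core.py above).
model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    load_denoiser=False,
    lora_weights_path="/path/to/lora_checkpoint",  # hypothetical path
)

text = "Testing the LoRA adapter."

# Synthesize with the adapter active, then with the base model only.
with_lora = model.generate(text=text, denoise=False)
model.set_lora_enabled(False)
base_only = model.generate(text=text, denoise=False)
model.set_lora_enabled(True)

# Hot-swap to a second adapter without reloading the base model.
loaded, skipped = model.load_lora("/path/to/another_lora_checkpoint")  # hypothetical path
print(f"Loaded {len(loaded)} params, skipped {len(skipped)}")  # skipped should be empty

sf.write("with_lora.wav", with_lora, model.tts_model.sample_rate)
sf.write("base_only.wav", base_only, model.tts_model.sample_rate)
```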