add lora funetine webUI; optimize lora save and load logic

This commit is contained in:
刘鑫
2025-12-09 21:34:39 +08:00
parent 0779a93697
commit a266c0a88d
9 changed files with 1575 additions and 48 deletions

View File

@@ -14,6 +14,8 @@ import torch
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
import signal
import os
try:
from safetensors.torch import save_file
@@ -56,8 +58,16 @@ def train(
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
lora: dict = None,
config_path: str = "",
# Distribution options (for LoRA checkpoints)
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
):
_ = config_path
# Validate distribution options
if lora is not None and distribute and not hf_model_id:
raise ValueError("hf_model_id is required when distribute=True")
accelerator = Accelerator(amp=True)
save_dir = Path(save_path)
@@ -171,6 +181,39 @@ def train(
num_training_steps=total_training_steps,
)
# Try to load checkpoint and resume training
start_step = 0
if accelerator.rank == 0:
start_step = load_checkpoint(model, optimizer, scheduler, save_dir)
# Broadcast start_step to all processes
if hasattr(accelerator, 'all_reduce'):
start_step_tensor = torch.tensor(start_step, device=accelerator.device)
accelerator.all_reduce(start_step_tensor)
start_step = int(start_step_tensor.item())
if start_step > 0 and accelerator.rank == 0:
tracker.print(f"Resuming training from step {start_step}")
# Resume tracker for signal handler to read current step
resume = {"step": start_step}
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume):
try:
cur_step = int(_resume.get("step", start_step))
except Exception:
cur_step = start_step
print(f"Signal {signum} received. Saving checkpoint at step {cur_step} ...")
try:
save_checkpoint(_model, _optim, _sched, _save_dir, cur_step, _pretrained, _hf_id, _dist)
print("Checkpoint saved. Exiting.")
except Exception as e:
print(f"Error saving checkpoint on signal: {e}")
os._exit(0)
signal.signal(signal.SIGTERM, _signal_handler)
signal.signal(signal.SIGINT, _signal_handler)
# Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
grad_accum_steps = max(int(grad_accum_steps), 1)
data_epoch = 0
@@ -191,7 +234,9 @@ def train(
return next(train_iter)
with tracker.live():
for step in range(num_iters):
for step in range(start_step, num_iters):
# update resume step so signal handler can save current progress
resume["step"] = step
tracker.step = step
optimizer.zero_grad(set_to_none=True)
@@ -255,10 +300,10 @@ def train(
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)
if step % save_interval == 0 and accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path)
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
if accelerator.rank == 0:
save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path)
save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path, hf_model_id, distribute)
if writer:
writer.close()
@@ -301,7 +346,77 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
model.train()
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None):
def load_checkpoint(model, optimizer, scheduler, save_dir: Path):
"""
Load the latest checkpoint if it exists.
Returns the step number to resume from, or 0 if no checkpoint found.
"""
latest_folder = save_dir / "latest"
if not latest_folder.exists():
return 0
unwrapped = model.module if hasattr(model, "module") else model
lora_cfg = unwrapped.lora_config
# Load model weights
if lora_cfg is not None:
# LoRA: load lora_weights
lora_weights_path = latest_folder / "lora_weights.safetensors"
if not lora_weights_path.exists():
lora_weights_path = latest_folder / "lora_weights.ckpt"
if lora_weights_path.exists():
if lora_weights_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(lora_weights_path))
else:
ckpt = torch.load(lora_weights_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
# Load only lora weights
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded LoRA weights from {lora_weights_path}")
else:
# Full finetune: load model.safetensors or pytorch_model.bin
model_path = latest_folder / "model.safetensors"
if not model_path.exists():
model_path = latest_folder / "pytorch_model.bin"
if model_path.exists():
if model_path.suffix == ".safetensors":
from safetensors.torch import load_file
state_dict = load_file(str(model_path))
else:
ckpt = torch.load(model_path, map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)
unwrapped.load_state_dict(state_dict, strict=False)
print(f"Loaded model weights from {model_path}")
# Load optimizer state
optimizer_path = latest_folder / "optimizer.pth"
if optimizer_path.exists():
optimizer.load_state_dict(torch.load(optimizer_path, map_location="cpu"))
print(f"Loaded optimizer state from {optimizer_path}")
# Load scheduler state
scheduler_path = latest_folder / "scheduler.pth"
if scheduler_path.exists():
scheduler.load_state_dict(torch.load(scheduler_path, map_location="cpu"))
print(f"Loaded scheduler state from {scheduler_path}")
# Try to infer step from checkpoint folders
step_folders = [d for d in save_dir.iterdir() if d.is_dir() and d.name.startswith("step_")]
if step_folders:
steps = [int(d.name.split("_")[1]) for d in step_folders]
resume_step = max(steps)
print(f"Resuming from step {resume_step}")
return resume_step
return 0
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
"""
Save checkpoint with different strategies for full finetune vs LoRA:
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
@@ -325,6 +440,17 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
save_file(state_dict, folder / "lora_weights.safetensors")
else:
torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
# Save LoRA config and base model path to a separate JSON file
# If distribute=True, save hf_model_id; otherwise save local pretrained_path
import json
base_model_to_save = hf_model_id if distribute else (str(pretrained_path) if pretrained_path else None)
lora_info = {
"base_model": base_model_to_save,
"lora_config": lora_cfg.model_dump() if hasattr(lora_cfg, "model_dump") else vars(lora_cfg),
}
with open(folder / "lora_config.json", "w", encoding="utf-8") as f:
json.dump(lora_info, f, indent=2, ensure_ascii=False)
else:
# Full finetune: save non-vae weights to model.safetensors
state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")}
@@ -345,6 +471,29 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
torch.save(optimizer.state_dict(), folder / "optimizer.pth")
torch.save(scheduler.state_dict(), folder / "scheduler.pth")
# Update (or create) a `latest` symlink pointing to the most recent checkpoint folder
latest_link = save_dir / "latest"
try:
if latest_link.exists() or latest_link.is_symlink():
# remove existing link or directory
if latest_link.is_dir() and not latest_link.is_symlink():
shutil.rmtree(latest_link)
else:
latest_link.unlink()
# Create a symlink pointing to the new folder
os.symlink(str(folder), str(latest_link))
except Exception:
# If symlink creation fails (e.g., on Windows or permission issues), fall back to copying
try:
if latest_link.exists():
if latest_link.is_dir():
shutil.rmtree(latest_link)
else:
latest_link.unlink()
shutil.copytree(folder, latest_link)
except Exception:
print(f"Warning: failed to update latest checkpoint link at {latest_link}")
if __name__ == "__main__":
from voxcpm.training.config import load_yaml_config
@@ -358,5 +507,4 @@ if __name__ == "__main__":
else:
# Otherwise use command line args (parsed by argbind)
with argbind.scope(args):
train()
train()