mirror of https://github.com/OpenBMB/VoxCPM, synced 2025-12-12 03:48:12 +00:00
Update: VoxCPM1.5 and fine-tuning support
scripts/test_voxcpm_ft_infer.py (new file, 129 lines)
@@ -0,0 +1,129 @@
#!/usr/bin/env python3
"""
Full finetune inference script (no LoRA).

The checkpoint directory contains the complete model files (pytorch_model.bin, config.json, audiovae.pth, etc.)
and can be loaded directly via VoxCPMModel.from_local().

Usage:

    python scripts/test_voxcpm_ft_infer.py \
        --ckpt_dir /path/to/checkpoints/step_0001000 \
        --text "Hello, I am the finetuned VoxCPM." \
        --output ft_test.wav

With voice cloning:

    python scripts/test_voxcpm_ft_infer.py \
        --ckpt_dir /path/to/checkpoints/step_0001000 \
        --text "Hello, this is the voice cloning result." \
        --prompt_audio path/to/ref.wav \
        --prompt_text "Reference audio transcript" \
        --output ft_clone.wav
"""

import argparse
from pathlib import Path

import soundfile as sf
import torch

from voxcpm.model import VoxCPMModel


def parse_args():
    parser = argparse.ArgumentParser("VoxCPM full-finetune inference test (no LoRA)")
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        required=True,
        help="Checkpoint directory (contains pytorch_model.bin, config.json, audiovae.pth, etc.)",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Target text to synthesize",
    )
    parser.add_argument(
        "--prompt_audio",
        type=str,
        default="",
        help="Optional: reference audio path for voice cloning",
    )
    parser.add_argument(
        "--prompt_text",
        type=str,
        default="",
        help="Optional: transcript of reference audio",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="ft_test.wav",
        help="Output wav file path",
    )
    parser.add_argument(
        "--cfg_value",
        type=float,
        default=2.0,
        help="CFG scale (default: 2.0)",
    )
    parser.add_argument(
        "--inference_timesteps",
        type=int,
        default=10,
        help="Diffusion inference steps (default: 10)",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=600,
        help="Max generation steps",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Load model from checkpoint directory
    print(f"[FT Inference] Loading model: {args.ckpt_dir}")
    model = VoxCPMModel.from_local(args.ckpt_dir, optimize=True, training=False)
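    # Assumption, based on the LoRA script in this commit: optimize=True loads the
    # compiled/optimized model graph, while training=False builds it for inference
    # rather than finetuning.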

    # Run inference
    prompt_wav_path = args.prompt_audio or ""
    prompt_text = args.prompt_text or ""

    print(f"[FT Inference] Synthesizing: text='{args.text}'")
    if prompt_wav_path:
        print(f"[FT Inference] Using reference audio: {prompt_wav_path}")
        print(f"[FT Inference] Reference text: {prompt_text}")

    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )

    # Squeeze and save audio
    if isinstance(audio, torch.Tensor):
        audio_np = audio.squeeze(0).cpu().numpy()
    else:
        raise TypeError(f"Unexpected return type from model.generate: {type(audio)}")

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out_path), audio_np, model.sample_rate)

    print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.sample_rate:.2f}s")


if __name__ == "__main__":
    main()
scripts/test_voxcpm_lora_infer.py (new file, 232 lines)
@@ -0,0 +1,232 @@
#!/usr/bin/env python3
"""
LoRA inference test script.

Usage:

    python scripts/test_voxcpm_lora_infer.py \
        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --lora_ckpt checkpoints/step_0002000 \
        --text "Hello, this is the LoRA finetuned result." \
        --output lora_test.wav

With voice cloning:

    python scripts/test_voxcpm_lora_infer.py \
        --config_path conf/voxcpm/voxcpm_finetune_test.yaml \
        --lora_ckpt checkpoints/step_0002000 \
        --text "This is the voice cloning result." \
        --prompt_audio path/to/ref.wav \
        --prompt_text "Reference audio transcript" \
        --output lora_clone.wav
"""

import argparse
from pathlib import Path

import soundfile as sf
import torch

from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training.config import load_yaml_config


def parse_args():
    parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Training YAML config path (contains pretrained_path and lora config)",
    )
    parser.add_argument(
        "--lora_ckpt",
        type=str,
        required=True,
        help="LoRA checkpoint directory (contains lora_weights.ckpt with lora_A/lora_B only)",
    )
    parser.add_argument(
        "--text",
        type=str,
        required=True,
        help="Target text to synthesize",
    )
    parser.add_argument(
        "--prompt_audio",
        type=str,
        default="",
        help="Optional: reference audio path for voice cloning",
    )
    parser.add_argument(
        "--prompt_text",
        type=str,
        default="",
        help="Optional: transcript of reference audio",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="lora_test.wav",
        help="Output wav file path",
    )
    parser.add_argument(
        "--cfg_value",
        type=float,
        default=2.0,
        help="CFG scale (default: 2.0)",
    )
    parser.add_argument(
        "--inference_timesteps",
        type=int,
        default=10,
        help="Diffusion inference steps (default: 10)",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=600,
        help="Max generation steps",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # 1. Load YAML config
    cfg = load_yaml_config(args.config_path)
    pretrained_path = cfg["pretrained_path"]
    lora_cfg_dict = cfg.get("lora", {}) or {}
    lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None

    # 2. Load base model (with LoRA structure and torch.compile)
    print(f"[1/3] Loading base model: {pretrained_path}")
    model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=True,  # compile first, load_lora_weights uses named_parameters for compatibility
        training=False,
        lora_config=lora_cfg,
    )

    # Debug: check DiT param paths after compile
    dit_params = [n for n, _ in model.named_parameters() if 'feat_decoder' in n and 'lora' in n]
    print(f"[DEBUG] DiT LoRA param paths after compile (first 3): {dit_params[:3]}")

    # 3. Load LoRA weights (works after compile)
    ckpt_dir = Path(args.lora_ckpt)
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")

    print(f"[2/3] Loading LoRA weights: {ckpt_dir}")
    loaded, skipped = model.load_lora_weights(str(ckpt_dir))
    print(f"   Loaded {len(loaded)} parameters")
    if skipped:
        print(f"[WARNING] Skipped {len(skipped)} parameters")
        print(f"   Skipped keys (first 5): {skipped[:5]}")

    # 4. Synthesize audio
    prompt_wav_path = args.prompt_audio or ""
    prompt_text = args.prompt_text or ""
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"\n[3/3] Starting synthesis tests...")

    # === Test 1: With LoRA ===
    print(f"\n [Test 1] Synthesize with LoRA...")
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )
    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
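    # Note: Path.with_stem, used below to derive per-test output names, requires Python 3.9+.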
    lora_output = out_path.with_stem(out_path.stem + "_with_lora")
    sf.write(str(lora_output), audio_np, model.sample_rate)
    print(f"   Saved: {lora_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")

    # === Test 2: Disable LoRA (via set_lora_enabled) ===
    print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...")
    model.set_lora_enabled(False)
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )
    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
    disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
    sf.write(str(disabled_output), audio_np, model.sample_rate)
    print(f"   Saved: {disabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")

    # === Test 3: Re-enable LoRA ===
    print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...")
    model.set_lora_enabled(True)
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )
    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
    reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
    sf.write(str(reenabled_output), audio_np, model.sample_rate)
    print(f"   Saved: {reenabled_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")

    # === Test 4: Unload LoRA (reset_lora_weights) ===
    print(f"\n [Test 4] Unload LoRA (reset_lora_weights)...")
    model.reset_lora_weights()
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )
    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
    reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
    sf.write(str(reset_output), audio_np, model.sample_rate)
    print(f"   Saved: {reset_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")

    # === Test 5: Hot-reload LoRA (load_lora_weights) ===
    print(f"\n [Test 5] Hot-reload LoRA (load_lora_weights)...")
    loaded, _ = model.load_lora_weights(str(ckpt_dir))
    print(f"   Reloaded {len(loaded)} parameters")
    with torch.inference_mode():
        audio = model.generate(
            target_text=args.text,
            prompt_text=prompt_text,
            prompt_wav_path=prompt_wav_path,
            max_len=args.max_len,
            inference_timesteps=args.inference_timesteps,
            cfg_value=args.cfg_value,
        )
    audio_np = audio.squeeze(0).cpu().numpy() if audio.dim() > 1 else audio.cpu().numpy()
    reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
    sf.write(str(reload_output), audio_np, model.sample_rate)
    print(f"   Saved: {reload_output}, duration: {len(audio_np) / model.sample_rate:.2f}s")

    print(f"\n[Done] All tests completed!")
    print(f"  - with_lora: {lora_output}")
    print(f"  - lora_disabled: {disabled_output}")
    print(f"  - lora_reenabled: {reenabled_output}")
    print(f"  - lora_reset: {reset_output}")
    print(f"  - lora_reloaded: {reload_output}")


if __name__ == "__main__":
    main()
scripts/train_voxcpm_finetune.py (new file, 362 lines)
@@ -0,0 +1,362 @@
#!/usr/bin/env python3

import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root / "src"))

import contextlib
from typing import Dict, Optional

import argbind
import torch
from tensorboardX import SummaryWriter
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup

try:
    from safetensors.torch import save_file
    SAFETENSORS_AVAILABLE = True
except ImportError:
    SAFETENSORS_AVAILABLE = False
    print("Warning: safetensors not available, will use pytorch format")

from voxcpm.model import VoxCPMModel
from voxcpm.model.voxcpm import LoRAConfig
from voxcpm.training import (
    Accelerator,
    BatchProcessor,
    TrainingTracker,
    build_dataloader,
    load_audio_text_datasets,
)


@argbind.bind(without_prefix=True)
def train(
    pretrained_path: str,
    train_manifest: str,
    val_manifest: str = "",
    sample_rate: int = 16_000,
    batch_size: int = 1,
    grad_accum_steps: int = 1,
    num_workers: int = 2,
    num_iters: int = 100_000,
    log_interval: int = 100,
    valid_interval: int = 1_000,
    save_interval: int = 10_000,
    learning_rate: float = 1e-4,
    weight_decay: float = 1e-2,
    warmup_steps: int = 1_000,
    max_steps: int = 100_000,
    max_batch_tokens: int = 0,
    save_path: str = "checkpoints",
    tensorboard: str = "",
    lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
    lora: dict = None,
    config_path: str = "",
):
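    """Finetune VoxCPM from a local pretrained checkpoint.

    Arguments mirror the keys of the YAML config. `lambdas` weights the individual
    "loss/*" terms, and `lora` (if given) is forwarded to LoRAConfig, in which case
    only the lora_A/lora_B weights are saved by save_checkpoint.
    """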
    _ = config_path
    accelerator = Accelerator(amp=True)

    save_dir = Path(save_path)
    tb_dir = Path(tensorboard) if tensorboard else save_dir / "logs"

    # Only create directories on rank 0 to avoid race conditions
    if accelerator.rank == 0:
        save_dir.mkdir(parents=True, exist_ok=True)
        tb_dir.mkdir(parents=True, exist_ok=True)
    accelerator.barrier()  # Wait for directory creation

    writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
    tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)

    base_model = VoxCPMModel.from_local(
        pretrained_path,
        optimize=False,
        training=True,
        lora_config=LoRAConfig(**lora) if lora else None,
    )
    tokenizer = base_model.text_tokenizer

    train_ds, val_ds = load_audio_text_datasets(
        train_manifest=train_manifest,
        val_manifest=val_manifest,
        sample_rate=sample_rate,
    )
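
    # The manifest schema is defined by load_audio_text_datasets (not part of this
    # commit); judging from its use below it yields at least a "text" column, the
    # audio consumed by BatchProcessor, and optionally a "dataset_id" column.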

    def tokenize(batch):
        text_list = batch["text"]
        text_ids = [tokenizer(text) for text in text_list]
        return {"text_ids": text_ids}

    train_ds = train_ds.map(tokenize, batched=True, remove_columns=["text"])
    if val_ds is not None:
        val_ds = val_ds.map(tokenize, batched=True, remove_columns=["text"])

    dataset_cnt = int(max(train_ds["dataset_id"])) + 1 if "dataset_id" in train_ds.column_names else 1
    num_train_samples = len(train_ds)

    # ------------------------------------------------------------------ #
    # Optional: filter samples by estimated token count to avoid OOM
    # Enabled when max_batch_tokens > 0:
    #   max_sample_len = max_batch_tokens // batch_size
    # Samples exceeding this length will be dropped
    # ------------------------------------------------------------------ #
    if max_batch_tokens and max_batch_tokens > 0:
        from voxcpm.training.data import compute_sample_lengths

        audio_vae_fps = base_model.audio_vae.sample_rate / base_model.audio_vae.hop_length
        est_lengths = compute_sample_lengths(
            train_ds,
            audio_vae_fps=audio_vae_fps,
            patch_size=base_model.config.patch_size,
        )
        max_sample_len = max_batch_tokens // batch_size if batch_size > 0 else max(est_lengths)
        keep_indices = [i for i, L in enumerate(est_lengths) if L <= max_sample_len]

        if len(keep_indices) < len(train_ds) and accelerator.rank == 0:
            tracker.print(
                f"Filtering {len(train_ds) - len(keep_indices)} / {len(train_ds)} "
                f"training samples longer than {max_sample_len} tokens "
                f"(max_batch_tokens={max_batch_tokens})."
            )
        train_ds = train_ds.select(keep_indices)

    train_loader = build_dataloader(
        train_ds,
        accelerator=accelerator,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
    )
    val_loader = (
        build_dataloader(
            val_ds,
            accelerator=accelerator,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=False,
        )
        if val_ds is not None
        else None
    )

    batch_processor = BatchProcessor(
        config=base_model.config,
        audio_vae=base_model.audio_vae,
        dataset_cnt=dataset_cnt,
        device=accelerator.device,
    )
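    # The audio VAE now lives inside the BatchProcessor, so it is dropped from the
    # model here, presumably to keep it out of the trainable parameters and out of
    # the saved checkpoints (which skip "audio_vae." keys in any case).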
    del base_model.audio_vae
    model = accelerator.prepare_model(base_model)
    unwrapped_model = accelerator.unwrap(model)
    unwrapped_model.train()

    # Only print param info on rank 0 to avoid cluttered output
    if accelerator.rank == 0:
        for name, param in model.named_parameters():
            print(name, param.requires_grad)

    optimizer = AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    # Cosine + warmup scheduler from transformers:
    #   - num_warmup_steps: warmup steps
    #   - num_training_steps: total training steps (outer step count)
    total_training_steps = max_steps if max_steps > 0 else num_iters
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_training_steps,
    )

    # Manual epoch management instead of itertools.cycle to support DistributedSampler.set_epoch()
    grad_accum_steps = max(int(grad_accum_steps), 1)
    data_epoch = 0
    train_iter = iter(train_loader)

    def get_next_batch():
        """Get the next batch, handling epoch boundaries and the DistributedSampler."""
        nonlocal train_iter, data_epoch
        try:
            return next(train_iter)
        except StopIteration:
            data_epoch += 1
            # Key: set the DistributedSampler epoch so each epoch sees a different data order
            sampler = getattr(train_loader, 'sampler', None)
            if hasattr(sampler, 'set_epoch'):
                sampler.set_epoch(data_epoch)
            train_iter = iter(train_loader)
            return next(train_iter)

    with tracker.live():
        for step in range(num_iters):
            tracker.step = step
            optimizer.zero_grad(set_to_none=True)

            # Gradient accumulation: accumulate gradients over micro-batches before the optimizer step
            loss_dict = {}
            for micro_step in range(grad_accum_steps):
                batch = get_next_batch()
                processed = batch_processor(batch)

                # Only sync gradients on the last micro-batch;
                # use no_sync() for intermediate steps to reduce communication overhead
                is_last_micro_step = (micro_step == grad_accum_steps - 1)
                sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()

                with sync_context:
                    with accelerator.autocast(dtype=torch.bfloat16):
                        outputs = model(
                            processed["text_tokens"],
                            processed["text_mask"],
                            processed["audio_feats"],
                            processed["audio_mask"],
                            processed["loss_mask"],
                            processed["position_ids"],
                            processed["labels"],
                            progress=step / max(1, num_iters),
                        )

                    total_loss = 0.0
                    for key, value in outputs.items():
                        if key.startswith("loss/"):
                            weight = lambdas.get(key, 1.0)
                            loss_value = value * weight / grad_accum_steps
                            total_loss = total_loss + loss_value
                            # Record the raw loss from the last micro-batch for logging
                            loss_dict[key] = value.detach()

                    # Accumulate gradients (normalized by grad_accum_steps)
                    accelerator.backward(total_loss)

            # After all micro-batches, do unscale / grad_norm / step
            scaler = getattr(accelerator, "scaler", None)
            if scaler is not None:
                scaler.unscale_(optimizer)
            # Use a large max_norm to only compute grad_norm without actually clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)

            accelerator.step(optimizer)
            accelerator.update()
            scheduler.step()

            if step % log_interval == 0:
                loss_values = {k: v.item() if isinstance(v, torch.Tensor) else float(v) for k, v in loss_dict.items()}
                loss_values["lr"] = float(optimizer.param_groups[0]["lr"])
                # Approximate epoch: seen samples / total samples (accounting for grad_accum and batch_size)
                epoch = (step * grad_accum_steps * batch_size) / max(1, num_train_samples)
                loss_values["epoch"] = float(epoch)
                loss_values["grad_norm"] = float(grad_norm)
                tracker.log_metrics(loss_values, split="train")

            if val_loader is not None and step % valid_interval == 0 and step != 0:
                validate(model, val_loader, batch_processor, accelerator, tracker, lambdas)

            if step % save_interval == 0 and accelerator.rank == 0:
                save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path)

    if accelerator.rank == 0:
        save_checkpoint(model, optimizer, scheduler, save_dir, num_iters, pretrained_path)
    if writer:
        writer.close()


def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas):
    model.eval()
    losses = []
    num_batches = 0
    max_val_batches = 10

    with torch.no_grad():
        for batch in val_loader:
            if num_batches >= max_val_batches:
                break
            processed = batch_processor(batch)
            with accelerator.autocast(dtype=torch.bfloat16):
                outputs = model(
                    processed["text_tokens"],
                    processed["text_mask"],
                    processed["audio_feats"],
                    processed["audio_mask"],
                    processed["loss_mask"],
                    processed["position_ids"],
                    processed["labels"],
                    progress=0.0,
                    sample_generate=False,
                )
            total = 0.0
            for key, value in outputs.items():
                if key.startswith("loss/"):
                    total += lambdas.get(key, 1.0) * value
            losses.append(total.detach())
            num_batches += 1

    if losses:
        mean_loss = torch.stack(losses).mean()
        # All-reduce the validation loss across processes for a global average
        accelerator.all_reduce(mean_loss)
        tracker.log_metrics({"loss": mean_loss.item()}, split="val")
    model.train()


def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None):
    """
    Save a checkpoint, with different strategies for full finetuning vs LoRA:
    - Full finetune: save non-VAE weights to model.safetensors (or pytorch_model.bin if safetensors is unavailable)
    - LoRA: save only the LoRA weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors is unavailable)
    """
    import shutil

    save_dir.mkdir(parents=True, exist_ok=True)
    tag = "latest" if step == 0 else f"step_{step:07d}"
    folder = save_dir / tag
    folder.mkdir(parents=True, exist_ok=True)

    unwrapped = model.module if hasattr(model, "module") else model
    full_state = unwrapped.state_dict()
    lora_cfg = unwrapped.lora_config

    if lora_cfg is not None:
        # LoRA finetune: save only lora_A/lora_B weights
        state_dict = {k: v for k, v in full_state.items() if "lora_" in k}
        if SAFETENSORS_AVAILABLE:
            save_file(state_dict, folder / "lora_weights.safetensors")
        else:
            torch.save({"state_dict": state_dict}, folder / "lora_weights.ckpt")
    else:
        # Full finetune: save non-VAE weights to model.safetensors
        state_dict = {k: v for k, v in full_state.items() if not k.startswith("audio_vae.")}
        if SAFETENSORS_AVAILABLE:
            save_file(state_dict, folder / "model.safetensors")
        else:
            torch.save({"state_dict": state_dict}, folder / "pytorch_model.bin")

    # Copy config files from the pretrained path
    if pretrained_path:
        pretrained_dir = Path(pretrained_path)
        files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
        for fname in files_to_copy:
            src = pretrained_dir / fname
            if src.exists():
                shutil.copy2(src, folder / fname)

    torch.save(optimizer.state_dict(), folder / "optimizer.pth")
    torch.save(scheduler.state_dict(), folder / "scheduler.pth")
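

# Usage sketch (the YAML path mirrors the example in scripts/test_voxcpm_lora_infer.py;
# all other paths are placeholders):
#   python scripts/train_voxcpm_finetune.py --config_path conf/voxcpm/voxcpm_finetune_test.yaml
# or pass the train() arguments directly on the command line via argbind, e.g.:
#   python scripts/train_voxcpm_finetune.py --pretrained_path /path/to/VoxCPM --train_manifest /path/to/train_manifest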


if __name__ == "__main__":
    from voxcpm.training.config import load_yaml_config

    args = argbind.parse_args()
    config_file = args.get("config_path")
    # If a YAML config is provided, call train() with the YAML arguments
    if config_file:
        yaml_args = load_yaml_config(config_file)
        train(**yaml_args)
    else:
        # Otherwise use command-line arguments (parsed by argbind)
        with argbind.scope(args):
            train()