Update: VoxCPM1.5 and fine-tuning support

This commit is contained in:
Labmem-Zhouyx
2025-12-05 21:00:01 +08:00
parent d1bb6aaf41
commit 461ad7e506
29 changed files with 2928 additions and 228 deletions

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import contextlib
import time
from pathlib import Path
from typing import Dict, Optional
class TrainingTracker:
    """Lightweight, rank-aware progress tracker for training loops.

    Inspired by the minicpm-audio training workflow. The tracker keeps the
    current global step, prints messages on rank 0 only, optionally appends
    those messages to a log file, and optionally forwards numeric metrics to
    a writer object exposing ``add_scalar(tag, value, step)`` (e.g. a
    TensorBoard writer).

    Attributes:
        writer: Optional metrics writer; ``None`` disables scalar logging.
        log_file: ``Path`` of the log file, or ``None`` when file logging is
            disabled.
        rank: Process rank; only rank 0 produces any output.
        step: Current global step. Callers update it directly.
    """

    def __init__(
        self,
        *,
        writer=None,
        log_file: Optional[str] = None,
        rank: int = 0,
    ):
        """Create a tracker.

        Args:
            writer: Optional object with an ``add_scalar(tag, value, step)``
                method.
            log_file: Optional path of a text file that receives every
                printed message. Its parent directory is created if missing.
            rank: Rank of the current process.
        """
        self.writer = writer
        self.log_file = Path(log_file) if log_file else None
        if self.log_file:
            self.log_file.parent.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        self.step = 0
        # Monotonic timestamp of the previous log_metrics() call, used to
        # report the interval between two consecutive metric logs.
        self._last_log_time: Optional[float] = None

    # ------------------------------------------------------------------ #
    # Logging helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _format_metric(value) -> str:
        """Render one metric value for the textual log line.

        Values that accept the ``.6f`` format spec are shown with six
        decimals; anything else (a string, ``None``, ...) falls back to
        ``str(value)`` so a single non-numeric entry cannot abort logging.
        """
        try:
            return format(value, ".6f")
        except (TypeError, ValueError):
            return str(value)

    def print(self, message: str):
        """Print ``message`` on rank 0 and append it to the log file, if any.

        Other ranks do nothing.
        """
        if self.rank != 0:
            return
        print(message, flush=True)
        if self.log_file:
            with self.log_file.open("a", encoding="utf-8") as f:
                f.write(message + "\n")

    def log_metrics(self, metrics: Dict[str, float], split: str):
        """Log a dictionary of metrics for the current step (rank 0 only).

        Emits one line of the form
        ``[split] step N: key: value, ...[, log interval: X.XXs]`` through
        :meth:`print`, then forwards every ``int``/``float`` value to the
        writer under the tag ``"{split}/{key}"``.

        Args:
            metrics: Mapping from metric name to value.
            split: Label of the data split (used as line prefix and as tag
                namespace for the writer).
        """
        if self.rank != 0:
            return

        # A monotonic clock keeps the reported interval immune to wall-clock
        # adjustments (NTP corrections, manual changes).
        now = time.monotonic()
        dt_str = ""
        if self._last_log_time is not None:
            dt = now - self._last_log_time
            dt_str = f", log interval: {dt:.2f}s"
        self._last_log_time = now

        formatted = ", ".join(
            f"{k}: {self._format_metric(v)}" for k, v in metrics.items()
        )
        self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")

        if self.writer is not None:
            for key, value in metrics.items():
                # Only plain Python numbers are forwarded to the writer.
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"{split}/{key}", value, self.step)

    def done(self, split: str, message: str):
        """Print a closing ``message`` prefixed with ``[split]``."""
        self.print(f"[{split}] {message}")

    # ------------------------------------------------------------------ #
    # State dict
    # ------------------------------------------------------------------ #
    def state_dict(self):
        """Return the serializable tracker state (currently only the step)."""
        return {"step": self.step}

    def load_state_dict(self, state):
        """Restore the step from ``state``; a missing ``"step"`` resets to 0."""
        self.step = int(state.get("step", 0))

    # ------------------------------------------------------------------ #
    # Context manager compatibility (for parity with minicpm-audio code)
    # ------------------------------------------------------------------ #
    @contextlib.contextmanager
    def live(self):
        """No-op context manager; it simply yields once."""
        yield