mirror of
https://github.com/OpenBMB/VoxCPM
synced 2025-12-12 11:58:11 +00:00
Update: VoxCPM1.5 and fine-tuning support
This commit is contained in:
78
src/voxcpm/training/tracker.py
Normal file
78
src/voxcpm/training/tracker.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class TrainingTracker:
    """
    Lightweight tracker inspired by the minicpm-audio training workflow.

    It keeps track of the current global step, prints rank-aware messages,
    optionally writes to TensorBoard via a provided writer, and stores progress
    in a logfile for later inspection.
    """

    def __init__(
        self,
        *,
        writer=None,
        log_file: Optional[str] = None,
        rank: int = 0,
    ):
        """
        Create a tracker.

        Args:
            writer: Optional object exposing ``add_scalar(tag, value, step)``
                (a TensorBoard-style writer). ``None`` disables scalar logging.
            log_file: Optional path of a text file that receives every printed
                message. Missing parent directories are created eagerly.
            rank: Process rank; only rank 0 prints and logs.
        """
        self.writer = writer
        self.log_file = Path(log_file) if log_file else None
        if self.log_file:
            self.log_file.parent.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        # Current global step; callers update this attribute directly.
        self.step = 0
        # Monotonic timestamp of the previous log_metrics call, used to report
        # the interval between two consecutive logs (None until the first one).
        self._last_log_time: Optional[float] = None

    # ------------------------------------------------------------------ #
    # Logging helpers
    # ------------------------------------------------------------------ #
    def print(self, message: str):
        """Print ``message`` on rank 0 and append it to the logfile, if any."""
        if self.rank != 0:
            return
        print(message, flush=True)
        if self.log_file:
            with self.log_file.open("a", encoding="utf-8") as f:
                f.write(message + "\n")

    @staticmethod
    def _format_value(value) -> str:
        """Render a metric with six decimals, falling back to ``str(value)``."""
        try:
            return f"{value:.6f}"
        except (TypeError, ValueError):
            # Values that do not support the "f" format code (strings, None,
            # ...) must not abort the training loop just because of logging.
            return str(value)

    def log_metrics(self, metrics: Dict[str, float], split: str):
        """
        Report ``metrics`` for ``split`` at the current step (rank 0 only).

        The metrics are printed on one line together with the time elapsed
        since the previous call. Values that are ``int`` or ``float`` are also
        forwarded to the writer under the tag ``"{split}/{key}"``.

        Args:
            metrics: Mapping from metric name to value.
            split: Name of the data split, e.g. the prefix shown in brackets.
        """
        if self.rank != 0:
            return

        # A monotonic clock keeps the reported interval correct even when the
        # system wall clock is adjusted during a long training run.
        now = time.monotonic()
        dt_str = ""
        if self._last_log_time is not None:
            dt = now - self._last_log_time
            dt_str = f", log interval: {dt:.2f}s"
        self._last_log_time = now

        formatted = ", ".join(
            f"{key}: {self._format_value(value)}" for key, value in metrics.items()
        )
        self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")

        if self.writer is None:
            return
        for key, value in metrics.items():
            # Only plain Python numbers are sent to the writer.
            if isinstance(value, (int, float)):
                self.writer.add_scalar(f"{split}/{key}", value, self.step)

    def done(self, split: str, message: str):
        """Print a closing ``message`` prefixed with the ``split`` name."""
        self.print(f"[{split}] {message}")

    # ------------------------------------------------------------------ #
    # State dict
    # ------------------------------------------------------------------ #
    def state_dict(self):
        """Return the serialisable tracker state (currently only the step)."""
        return {"step": self.step}

    def load_state_dict(self, state):
        """Restore the step from ``state``; a missing ``"step"`` key means 0."""
        self.step = int(state.get("step", 0))

    # ------------------------------------------------------------------ #
    # Context manager compatibility (for parity with minicpm-audio code)
    # ------------------------------------------------------------------ #
    @contextlib.contextmanager
    def live(self):
        """No-op context manager kept for API compatibility."""
        yield
|
||||
|
||||
Reference in New Issue
Block a user