mirror of
https://github.com/OpenBMB/VoxCPM
synced 2025-12-12 11:58:11 +00:00
79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Dict, Optional
|
|
|
|
|
|
class TrainingTracker:
    """
    Lightweight tracker inspired by the minicpm-audio training workflow.

    It keeps track of the current global step, prints rank-aware messages,
    optionally writes to TensorBoard via a provided writer, and stores progress
    in a logfile for later inspection.
    """

    def __init__(
        self,
        *,
        writer=None,
        log_file: Optional[str] = None,
        rank: int = 0,
    ):
        """
        Args:
            writer: Optional object exposing ``add_scalar(tag, value, step)``
                (presumably a TensorBoard ``SummaryWriter``). Only rank 0
                forwards scalars to it.
            log_file: Optional path of a text file to which every rank-0
                message is appended. Its parent directory is created eagerly.
            rank: Process rank; only rank 0 prints, appends to the log file,
                and emits scalars.
        """
        self.writer = writer
        self.log_file = Path(log_file) if log_file else None
        if self.log_file:
            # exist_ok: the directory may already exist (e.g. created by another
            # rank, since this runs on every rank).
            self.log_file.parent.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        self.step = 0
        # Monotonic timestamp of the previous log_metrics call, used to report
        # the interval between consecutive metric logs (shared across splits).
        self._last_log_time: float | None = None

    # ------------------------------------------------------------------ #
    # Logging helpers
    # ------------------------------------------------------------------ #
    def print(self, message: str):
        """Print ``message`` and append it to the log file, on rank 0 only."""
        if self.rank != 0:
            return
        # Inside a method, the bare name resolves to the builtin print.
        print(message, flush=True)
        if self.log_file:
            with self.log_file.open("a", encoding="utf-8") as f:
                f.write(message + "\n")

    @staticmethod
    def _format_metric_value(value) -> str:
        """Render a metric with six decimals, or via ``str()`` if that fails."""
        try:
            return f"{value:.6f}"
        except (TypeError, ValueError):
            # Non-numeric metrics (strings, None, multi-element arrays) do not
            # support the 'f' format spec; show them verbatim instead of crashing.
            return str(value)

    def log_metrics(self, metrics: Dict[str, float], split: str):
        """
        Log ``metrics`` for ``split`` at the current step (rank 0 only).

        Prints one summary line (also appended to the log file) and forwards
        ``int``/``float`` values to the writer under the tag ``"{split}/{key}"``.
        Values that cannot be formatted as floats are printed with ``str()``
        and are not sent to the writer.

        Args:
            metrics: Mapping of metric name to value.
            split: Label used as the line prefix and as the writer tag prefix.
        """
        if self.rank != 0:
            return

        # monotonic() is not affected by system clock adjustments, so the
        # reported interval cannot go negative or jump.
        now = time.monotonic()
        dt_str = ""
        if self._last_log_time is not None:
            dt = now - self._last_log_time
            dt_str = f", log interval: {dt:.2f}s"
        self._last_log_time = now

        formatted = ", ".join(
            f"{k}: {self._format_metric_value(v)}" for k, v in metrics.items()
        )
        self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")

        if self.writer is not None:
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"{split}/{key}", value, self.step)

    def done(self, split: str, message: str):
        """Emit a final ``[split] message`` line through :meth:`print`."""
        self.print(f"[{split}] {message}")

    # ------------------------------------------------------------------ #
    # State dict
    # ------------------------------------------------------------------ #
    def state_dict(self):
        """Return the resumable state (currently only the global step)."""
        return {"step": self.step}

    def load_state_dict(self, state):
        """Restore the global step from ``state``; a missing key resets it to 0."""
        self.step = int(state.get("step", 0))

    # ------------------------------------------------------------------ #
    # Context manager compatibility (for parity with minicpm-audio code)
    # ------------------------------------------------------------------ #
    @contextlib.contextmanager
    def live(self):
        """No-op context manager kept so callers can write ``with tracker.live():``."""
        yield