from __future__ import annotations

import contextlib
import time
from pathlib import Path
from typing import Dict, Optional


class TrainingTracker:
    """
    Lightweight tracker inspired by the minicpm-audio training workflow.

    It keeps track of the current global step, prints rank-aware messages,
    optionally writes to TensorBoard via a provided writer, and stores
    progress in a logfile for later inspection.
    """

    def __init__(
        self,
        *,
        writer=None,
        log_file: Optional[str] = None,
        rank: int = 0,
    ):
        """Create a tracker.

        Args:
            writer: Optional object exposing ``add_scalar(tag, value, step)``;
                when given, numeric metrics are forwarded to it.
            log_file: Optional path of a text file that printed messages are
                appended to. Its parent directory is created if missing.
            rank: Process rank; only rank 0 prints and logs messages.
        """
        self.writer = writer
        self.log_file = Path(log_file) if log_file else None
        if self.log_file:
            self.log_file.parent.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        self.step = 0
        # Monotonic timestamp of the last metrics log, used to report the
        # interval between consecutive logs (None until the first log).
        self._last_log_time: float | None = None

    # ------------------------------------------------------------------ #
    # Logging helpers
    # ------------------------------------------------------------------ #
    def print(self, message: str):
        """Print ``message`` on rank 0 and append it to the log file, if any."""
        if self.rank == 0:
            print(message, flush=True)
            if self.log_file:
                with self.log_file.open("a", encoding="utf-8") as f:
                    f.write(message + "\n")

    @staticmethod
    def _format_metric(value) -> str:
        """Render a metric with six decimals, or via ``str`` if that fails."""
        try:
            return format(value, ".6f")
        except (TypeError, ValueError):
            # Non-numeric values (e.g. strings, None) cannot take the ``f``
            # format code; show them verbatim instead of aborting the log.
            return str(value)

    def log_metrics(self, metrics: Dict[str, float], split: str):
        """Log ``metrics`` for ``split`` at the current step.

        On rank 0 this prints one line with every metric, plus the time
        elapsed since the previous call once there has been one, and forwards
        ``int``/``float`` values to the writer when a writer was provided.

        Args:
            metrics: Mapping of metric name to value.
            split: Label used as the line prefix and as the scalar tag prefix
                (``"{split}/{key}"``).
        """
        if self.rank == 0:
            # A monotonic clock keeps the interval correct even if the
            # wall clock is adjusted between two logs.
            now = time.monotonic()
            dt_str = ""
            if self._last_log_time is not None:
                dt = now - self._last_log_time
                dt_str = f", log interval: {dt:.2f}s"
            self._last_log_time = now

            formatted = ", ".join(
                f"{k}: {self._format_metric(v)}" for k, v in metrics.items()
            )
            self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")

            if self.writer is not None:
                for key, value in metrics.items():
                    # Only plain numbers are forwarded as scalars.
                    if isinstance(value, (int, float)):
                        self.writer.add_scalar(f"{split}/{key}", value, self.step)

    def done(self, split: str, message: str):
        """Print ``message`` prefixed with the ``split`` label."""
        self.print(f"[{split}] {message}")

    # ------------------------------------------------------------------ #
    # State dict
    # ------------------------------------------------------------------ #
    def state_dict(self):
        """Return the tracker state (the current step) as a dict."""
        return {"step": self.step}

    def load_state_dict(self, state):
        """Restore the step from ``state``; a missing ``"step"`` key means 0."""
        self.step = int(state.get("step", 0))

    # ------------------------------------------------------------------ #
    # Context manager compatibility (for parity with minicpm-audio code)
    # ------------------------------------------------------------------ #
    @contextlib.contextmanager
    def live(self):
        """No-op context manager: yields once without a value."""
        yield