Update: VoxCPM1.5 and fine-tuning support

This commit is contained in:
Labmem-Zhouyx
2025-12-05 21:00:01 +08:00
parent d1bb6aaf41
commit 3443dbb212
29 changed files with 2928 additions and 228 deletions

View File

@@ -69,7 +69,7 @@ def load_model(args) -> VoxCPM:
# Otherwise, try from_pretrained (Hub); exit on failure
try:
model = VoxCPM.from_pretrained(
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM-0.5B"),
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM1.5"),
load_denoiser=not getattr(args, "no_denoiser", False),
zipenhancer_model_id=zipenhancer_path,
cache_dir=getattr(args, "cache_dir", None),
@@ -120,11 +120,11 @@ def cmd_clone(args):
)
# Save audio
sf.write(str(output_path), audio_array, 16000)
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}")
# Stats
duration = len(audio_array) / 16000
duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s")
@@ -152,11 +152,11 @@ def cmd_synthesize(args):
)
# Save audio
sf.write(str(output_path), audio_array, 16000)
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}")
# Stats
duration = len(audio_array) / 16000
duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s")
@@ -198,9 +198,9 @@ def cmd_batch(args):
denoise=args.denoise and prompt_audio_path is not None
)
output_file = output_dir / f"output_{i:03d}.wav"
sf.write(str(output_file), audio_array, 16000)
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)
duration = len(audio_array) / 16000
duration = len(audio_array) / model.tts_model.sample_rate
print(f" Saved: {output_file} ({duration:.2f}s)")
success_count += 1
@@ -250,7 +250,7 @@ Examples:
# Model loading parameters
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM-0.5B", help="Hugging Face repo id (e.g., openbmb/VoxCPM-0.5B)")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (e.g., openbmb/VoxCPM1.5 or openbmb/VoxCPM-0.5B)")
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")

View File

@@ -45,6 +45,7 @@ class VoxCPM:
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
**kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -52,6 +53,7 @@ class VoxCPM:
Args:
hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path.
load_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
zipenhancer_model_id: Denoiser model id or path for ModelScope
acoustic noise suppression.
cache_dir: Custom cache directory for the snapshot.
@@ -87,6 +89,7 @@ class VoxCPM:
voxcpm_model_path=local_path,
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
optimize=optimize,
**kwargs,
)

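Taken together, the hunks above switch the default Hub id to openbmb/VoxCPM1.5, save audio at model.tts_model.sample_rate instead of a hard-coded 16000, and expose an optimize switch on from_pretrained. A minimal usage sketch, assuming the top-level VoxCPM export and a high-level generate(text=...) call like the CLI uses (output path and text are placeholders, not from this diff):

import soundfile as sf
from voxcpm import VoxCPM  # assumed top-level export

model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",
    load_denoiser=False,   # skip the ZipEnhancer denoiser when no prompt audio needs cleaning
    optimize=False,        # new flag: disable torch.compile, e.g. while debugging
)
audio_array = model.generate(text="Hello from VoxCPM.")  # assumed high-level API
sf.write("output.wav", audio_array, model.tts_model.sample_rate)
print(f"Duration: {len(audio_array) / model.tts_model.sample_rate:.2f}s")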
View File

@@ -19,19 +19,27 @@ limitations under the License.
"""
import os
from typing import Tuple, Union, Generator, List
from typing import Tuple, Union, Generator, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import warnings
from einops import rearrange
from pydantic import BaseModel
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
from tqdm import tqdm
from transformers import LlamaTokenizerFast
from ..modules.audiovae import AudioVAE
from ..modules.audiovae import AudioVAE, AudioVAEConfig
from ..modules.layers import ScalarQuantizationLayer
from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
@@ -66,10 +74,31 @@ class VoxCPMConfig(BaseModel):
encoder_config: VoxCPMEncoderConfig
dit_config: VoxCPMDitConfig
audio_vae_config: Optional[AudioVAEConfig] = None
max_length: int = 4096
device: str = "cuda"
dtype: str = "bfloat16"
dit_mean_mode: bool = False
class LoRAConfig(BaseModel):
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
enable_proj: bool = False # Apply LoRA to projection Linear layers
r: int = 8
alpha: int = 16
dropout: float = 0.0
# Target linear layer names for LM & DiT (matched by attribute name)
target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
# Projection layer attribute names to find on VoxCPMModel
target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
VoxCPMConfig.model_rebuild()
class VoxCPMModel(nn.Module):
@@ -78,9 +107,11 @@ class VoxCPMModel(nn.Module):
config: VoxCPMConfig,
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE,
lora_config: LoRAConfig = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.feat_dim = config.feat_dim
self.patch_size = config.patch_size
self.device = config.device
@@ -128,6 +159,7 @@ class VoxCPMModel(nn.Module):
in_channels=config.feat_dim,
cfm_params=config.dit_config.cfm_config,
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
mean_mode=config.dit_mean_mode,
)
# Projection layers
@@ -145,17 +177,46 @@ class VoxCPMModel(nn.Module):
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
self.stop_actn = nn.SiLU()
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
self.stop_loss = nn.CrossEntropyLoss(reduction="none")
# Audio VAE
self.audio_vae = audio_vae
self.chunk_size = audio_vae.chunk_size
self.sample_rate = audio_vae.sample_rate
if self.lora_config is not None:
self._apply_lora()
def _apply_lora(self):
"""Inject LoRA into the LM / DiT / projection layers."""
cfg = self.lora_config
lora_kwargs = dict(r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)
# LM: base_lm + residual_lm
if cfg.enable_lm:
for lm in [self.base_lm, self.residual_lm]:
apply_lora_to_named_linear_modules(
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
)
# DiT: feat_decoder.estimator
if cfg.enable_dit:
apply_lora_to_named_linear_modules(
self.feat_decoder.estimator, target_submodule_names=cfg.target_modules_dit, **lora_kwargs
)
# Projection layers
if cfg.enable_proj:
from ..modules.layers.lora import LoRALinear
for attr_name in cfg.target_proj_modules:
module = getattr(self, attr_name, None)
if isinstance(module, nn.Linear):
setattr(self, attr_name, LoRALinear(base=module, **lora_kwargs))
def optimize(self, disable: bool = False):
if disable:
return self
try:
if disable:
raise ValueError("Optimization disabled by user")
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
@@ -164,17 +225,111 @@ class VoxCPMModel(nn.Module):
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
except Exception as e:
print(f"Error: {e}")
print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
self.base_lm.forward_step = self.base_lm.forward_step
self.residual_lm.forward_step = self.residual_lm.forward_step
self.feat_encoder_step = self.feat_encoder
self.feat_decoder.estimator = self.feat_decoder.estimator
print(f"Warning: torch.compile disabled - {e}")
return self
def forward(
self,
text_tokens: torch.Tensor,
text_mask: torch.Tensor,
audio_feats: torch.Tensor,
audio_mask: torch.Tensor,
loss_mask: torch.Tensor,
position_ids: torch.Tensor,
labels: torch.Tensor,
*,
progress: float = 0.0,
sample_generate: bool = False,
):
del position_ids # not used yet
text_tokens = text_tokens.to(self.device, dtype=torch.long)
text_mask = text_mask.to(self.device, dtype=self._dtype())
audio_feats = audio_feats.to(self.device, dtype=self._dtype())
audio_mask = audio_mask.to(self.device, dtype=self._dtype())
loss_mask = loss_mask.to(self.device, dtype=self._dtype())
labels = labels.to(self.device, dtype=torch.long)
B, T, P, D = audio_feats.shape
feat_embed = self.feat_encoder(audio_feats)
feat_embed = self.enc_to_lm_proj(feat_embed)
scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
if not getattr(self.config.lm_config, "use_mup", False):
scale_emb = 1.0
text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
enc_outputs = enc_outputs.to(self._dtype())
enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)
residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
residual_outputs = residual_outputs.to(self._dtype())
residual_hidden = torch.cat(
(torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
dim=1,
)
dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")
# Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
target_dtype = self._dtype()
feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
feat_cond = torch.cat(
(torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
dim=1,
)
feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")
loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)
diff_loss = self.feat_decoder.compute_loss(
feat_gt.transpose(1, 2).contiguous(),
dit_hidden,
cond=feat_cond.transpose(1, 2).contiguous(),
tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
progress=progress,
)
stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
denom = torch.clamp(loss_mask.sum(), min=1.0)
stop_loss = (stop_losses * loss_mask).sum() / denom
feat_pred = None
if sample_generate:
feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
feat_pred_seq = self.feat_decoder(
mu=dit_hidden,
patch_size=self.patch_size,
cond=feat_cond_for_sample,
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
else 10,
)
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)
return {
"loss/diff": diff_loss,
"loss/stop": stop_loss,
"feat_gt": feat_gt_tensor,
"feat_pred": feat_pred,
}
def _dtype(self):
return get_dtype(self.config.dtype)
def generate(self, *args, **kwargs) -> torch.Tensor:
return next(self._generate(*args, streaming=False, **kwargs))
@@ -238,25 +393,25 @@ class VoxCPMModel(nn.Module):
audio, sr = torchaudio.load(prompt_wav_path)
if audio.size(0) > 1:
audio = audio.mean(dim=0, keepdim=True)
audio = audio.mean(dim=0, keepdim=True)
if sr != self.sample_rate:
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
padding_size = patch_len - audio.size(1) % patch_len
audio = torch.nn.functional.pad(audio, (padding_size, 0))
# (B, D, T)
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
audio_feat = audio_feat.view(
self.audio_vae.latent_dim,
-1,
self.patch_size,
).permute(1, 2, 0)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
audio_length = audio_feat.size(0)
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
text_token = torch.cat([text_token, text_pad_token])
@@ -288,7 +443,7 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -314,7 +469,6 @@ class VoxCPMModel(nn.Module):
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
yield decode_audio
@torch.inference_mode()
@@ -331,13 +485,11 @@ class VoxCPMModel(nn.Module):
prompt_wav_path: prompt audio path (required)
Returns:
prompt_cache: dict with text tokens and audio features
prompt_cache: dict with prompt_text (raw text) and audio features.
Text tokenization will be done during generation for consistency.
"""
if not prompt_text or not prompt_wav_path:
raise ValueError("prompt_text and prompt_wav_path are required")
# build text tokens
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
# load audio
audio, sr = torchaudio.load(prompt_wav_path)
@@ -350,7 +502,9 @@ class VoxCPMModel(nn.Module):
patch_len = self.patch_size * self.chunk_size
if audio.size(1) % patch_len != 0:
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
padding_size = patch_len - audio.size(1) % patch_len
audio = torch.nn.functional.pad(audio, (padding_size, 0))
# extract audio features
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
@@ -360,10 +514,9 @@ class VoxCPMModel(nn.Module):
-1,
self.patch_size,
).permute(1, 2, 0) # (D, T, P)
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
# build prompt cache
# build prompt cache - only save raw text and audio features
prompt_cache = {
"text_token": text_token,
"prompt_text": prompt_text,
"audio_feat": audio_feat,
}
@@ -373,7 +526,7 @@ class VoxCPMModel(nn.Module):
def merge_prompt_cache(
self,
original_cache: dict,
new_text_token: torch.Tensor,
new_text: str,
new_audio_feat: torch.Tensor,
):
"""
@@ -381,38 +534,42 @@ class VoxCPMModel(nn.Module):
Args:
original_cache: original prompt cache
new_text_token: newly generated text tokens
new_text: newly generated text
new_audio_feat: newly generated audio features
Returns:
merged_cache: merged cache
merged_cache: merged cache with prompt_text and audio_feat
"""
if original_cache is None:
return {
"text_token": new_text_token,
"prompt_text": new_text,
"audio_feat": new_audio_feat,
}
original_text_token = original_cache["text_token"]
original_prompt_text = original_cache["prompt_text"]
original_audio_feat = original_cache["audio_feat"]
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
# Merge text by concatenation
merged_prompt_text = original_prompt_text + new_text
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
# build new cache
merged_cache = {
"text_token": merged_text_token,
"prompt_text": merged_prompt_text,
"audio_feat": merged_audio_feat,
}
return merged_cache
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
def generate_with_prompt_cache_streaming(
self, *args, **kwargs
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
@torch.inference_mode()
def _generate_with_prompt_cache(
self,
@@ -453,14 +610,14 @@ class VoxCPMModel(nn.Module):
retry_badcase = False
# get prompt from cache
if prompt_cache is None:
prompt_text_token = torch.empty(0, dtype=torch.int32)
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
text = target_text
else:
prompt_text_token = prompt_cache["text_token"]
prompt_audio_feat = prompt_cache["audio_feat"]
# build target text tokens
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
prompt_text = prompt_cache["prompt_text"]
text = prompt_text + target_text
text_token = torch.LongTensor(self.text_tokenizer(text))
text_token = torch.cat(
[
text_token,
@@ -472,6 +629,8 @@ class VoxCPMModel(nn.Module):
],
dim=-1,
)
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
audio_length = prompt_audio_feat.size(0)
text_length = text_token.shape[0]
@@ -501,7 +660,7 @@ class VoxCPMModel(nn.Module):
audio_feat,
audio_mask,
min_len=min_len,
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
inference_timesteps=inference_timesteps,
cfg_value=cfg_value,
streaming=streaming,
@@ -530,7 +689,6 @@ class VoxCPMModel(nn.Module):
break
if not streaming:
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
yield (
decode_audio,
@@ -556,6 +714,7 @@ class VoxCPMModel(nn.Module):
inference_timesteps: int = 10,
cfg_value: float = 2.0,
streaming: bool = False,
streaming_prefix_len: int = 3,
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
"""Core inference method for audio generation.
@@ -628,7 +787,7 @@ class VoxCPMModel(nn.Module):
1, 2
) # [b, p, d]
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
curr_embed = self.enc_to_lm_proj(curr_embed)
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
@@ -636,8 +795,9 @@ class VoxCPMModel(nn.Module):
if streaming:
# return the last streaming_prefix_len predicted latent features to provide enough context for smooth decoding
pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
@@ -656,35 +816,138 @@ class VoxCPMModel(nn.Module):
if not streaming:
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True):
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae = AudioVAE()
audio_vae_config = getattr(config, 'audio_vae_config', None)
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
vae_state_dict = torch.load(
os.path.join(path, "audiovae.pth"),
map_location="cpu",
weights_only=True,
)["state_dict"]
model = cls(config, tokenizer, audio_vae)
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
model = cls(config, tokenizer, audio_vae, lora_config)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)
else: # training mode
for name, param in model.named_parameters():
if "audio_vae" in name: # freeze VAE weights
param.requires_grad = False
continue
if lora_config is not None:
if "lora" not in name: # freeze non-LoRA weights
param.requires_grad = False
model.audio_vae = model.audio_vae.to(torch.float32)
model_state_dict = torch.load(
os.path.join(path, "pytorch_model.bin"),
map_location="cpu",
weights_only=True,
)["state_dict"]
# Try to load from safetensors first, fallback to pytorch_model.bin
safetensors_path = os.path.join(path, "model.safetensors")
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
print(f"Loading model from safetensors: {safetensors_path}")
model_state_dict = load_file(safetensors_path)
elif os.path.exists(pytorch_model_path):
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
checkpoint = torch.load(
pytorch_model_path,
map_location="cpu",
weights_only=True,
)
model_state_dict = checkpoint.get("state_dict", checkpoint)
else:
raise FileNotFoundError(
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
)
for kw, val in vae_state_dict.items():
model_state_dict[f"audio_vae.{kw}"] = val
model.load_state_dict(model_state_dict, strict=True)
# LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
model.load_state_dict(model_state_dict, strict=False)
if training:
return model
return model.to(model.device).eval().optimize(disable=not optimize)
# ------------------------------------------------------------------ #
# LoRA Weight Management
# ------------------------------------------------------------------ #
def _iter_lora_modules(self):
"""Iterate over all LoRA modules."""
from ..modules.layers.lora import LoRALinear
for module in self.modules():
if isinstance(module, LoRALinear):
yield module
def load_lora_weights(self, lora_path: str, device: str = None):
"""
Load LoRA weights from file, supports calling after torch.compile.
Uses named_parameters() to handle compile's _orig_mod wrapper.
Supports both safetensors and pytorch formats.
Args:
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
device: Target device, defaults to model's current device
Returns:
tuple: (loaded_keys, skipped_keys)
"""
from pathlib import Path
device = device or self.device
lora_path = Path(lora_path)
# Try safetensors first, then fallback to .ckpt
if lora_path.is_dir():
safetensors_file = lora_path / "lora_weights.safetensors"
ckpt_file = lora_path / "lora_weights.ckpt"
else:
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
# Load from safetensors if available
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
state_dict = load_file(str(safetensors_file), device=device)
elif ckpt_file and ckpt_file.exists():
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)
else:
raise FileNotFoundError(
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
)
# Build param mapping (handle torch.compile's _orig_mod prefix)
model_params = dict(self.named_parameters())
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
loaded_keys, skipped_keys = [], []
for key, value in state_dict.items():
target_key = key if key in model_params else key_mapping.get(key)
if target_key:
model_params[target_key].data.copy_(value.to(device))
loaded_keys.append(key)
else:
skipped_keys.append(key)
return loaded_keys, skipped_keys
def set_lora_enabled(self, enabled: bool):
"""Enable/disable all LoRA layers."""
for module in self._iter_lora_modules():
module.set_enabled(enabled)
def reset_lora_weights(self):
"""Reset all LoRA weights (A: kaiming, B: zeros), effectively unloading LoRA."""
for module in self._iter_lora_modules():
module.reset_lora_parameters()
def get_lora_state_dict(self) -> dict:
"""Get all LoRA parameters (lora_A/lora_B)."""
return {name: param.data.clone()
for name, param in self.named_parameters()
if "lora_" in name}

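A rough end-to-end sketch of the new LoRA plumbing, assuming the import path voxcpm.model.voxcpm (consistent with "from ..model.voxcpm import VoxCPMConfig" in the training code) and placeholder checkpoint paths:

from voxcpm.model.voxcpm import LoRAConfig, VoxCPMModel

lora_cfg = LoRAConfig(enable_lm=True, enable_dit=True, r=8, alpha=16, dropout=0.05)

# Training: from_local keeps float32 weights, freezes audio_vae and every non-LoRA parameter.
model = VoxCPMModel.from_local("/path/to/VoxCPM1.5", training=True, lora_config=lora_cfg)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
assert all("lora_" in n for n in trainable)

# Inference: rebuild with the same lora_config so the LoRALinear modules exist, then load the adapter.
infer = VoxCPMModel.from_local("/path/to/VoxCPM1.5", optimize=False, lora_config=lora_cfg)
loaded, skipped = infer.load_lora_weights("/path/to/run/lora_weights.safetensors")
infer.set_lora_enabled(True)   # flips the scaling buffer; reset_lora_weights() would undo the adapter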
View File

@@ -1 +1 @@
from .audio_vae import AudioVAE
from .audio_vae import AudioVAE, AudioVAEConfig

View File

@@ -1,11 +1,12 @@
import math
from typing import List, Union
from typing import List, Union, Optional
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from pydantic import BaseModel
def WNConv1d(*args, **kwargs):
@@ -266,6 +267,17 @@ class CausalDecoder(nn.Module):
return self.model(x)
class AudioVAEConfig(BaseModel):
encoder_dim: int = 128
encoder_rates: List[int] = [2, 5, 8, 8]
latent_dim: int = 64
decoder_dim: int = 1536
decoder_rates: List[int] = [8, 8, 5, 2]
depthwise: bool = True
sample_rate: int = 16000
use_noise_block: bool = False
class AudioVAE(nn.Module):
"""
Args:
@@ -273,17 +285,23 @@ class AudioVAE(nn.Module):
def __init__(
self,
encoder_dim: int = 128,
encoder_rates: List[int] = [2, 5, 8, 8],
latent_dim: int = 64,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 5, 2],
depthwise: bool = True,
sample_rate: int = 16000,
use_noise_block: bool = False,
config: Optional[AudioVAEConfig] = None,
):
# Fall back to the default configuration when no config is passed in
if config is None:
config = AudioVAEConfig()
super().__init__()
encoder_dim = config.encoder_dim
encoder_rates = config.encoder_rates
latent_dim = config.latent_dim
decoder_dim = config.decoder_dim
decoder_rates = config.decoder_rates
depthwise = config.depthwise
sample_rate = config.sample_rate
use_noise_block = config.use_noise_block
self.encoder_dim = encoder_dim
self.encoder_rates = encoder_rates
self.decoder_dim = decoder_dim

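A short sketch of the new config-driven construction (the values shown are just the defaults from this diff):

from voxcpm.modules.audiovae import AudioVAE, AudioVAEConfig

vae_cfg = AudioVAEConfig(latent_dim=64, sample_rate=16000)   # other fields keep their defaults
audio_vae = AudioVAE(config=vae_cfg)                         # with config=None, AudioVAEConfig() is used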
View File

@@ -0,0 +1,133 @@
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class LoRALinear(nn.Module):
"""
LoRA linear layer: holds weight/bias directly so the state_dict keys match nn.Linear.
state_dict layout:
- weight: original weight (same as nn.Linear)
- bias: original bias (same as nn.Linear)
- lora_A: LoRA low-rank matrix A
- lora_B: LoRA low-rank matrix B
The benefit of this design: pretrained weights can be loaded without any key remapping.
"""
def __init__(
self,
base: nn.Linear,
r: int,
alpha: float = 1.0,
dropout: float = 0.0,
):
super().__init__()
assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."
self.in_features = base.in_features
self.out_features = base.out_features
self.r = r
self.alpha = alpha
self._base_scaling = alpha / r if r > 0 else 0.0
# Store scaling in a buffer so changing its value does not trigger torch.compile recompilation
# persistent=False keeps it out of the state_dict, avoiding missing-key errors on load
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)
# Hold weight and bias directly (transferred from the original Linear)
self.weight = base.weight
self.bias = base.bias # may be None
# LoRA parameters
if r > 0:
self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
else:
self.register_parameter("lora_A", None)
self.register_parameter("lora_B", None)
self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Base Linear computation
result = F.linear(x, self.weight, self.bias)
if self.r <= 0 or self.lora_A is None:
return result
# LoRA: result + dropout(x @ A^T @ B^T) * scaling
lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
return result + self.dropout(lora_out) * self.scaling
def reset_lora_parameters(self):
"""Reset LoRA parameters to their initial state."""
if self.r > 0 and self.lora_A is not None:
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def set_enabled(self, enabled: bool):
"""Enable/disable LoRA via the scaling buffer; compatible with torch.compile."""
# Modify the buffer value in place with fill_ so no recompilation is triggered
self.scaling.fill_(self._base_scaling if enabled else 0.0)
@property
def enabled(self) -> bool:
return self.scaling.item() != 0.0
def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
"""
Given a full name like 'layers.0.self_attn.q_proj', return the parent module (i.e. the module one level above q_proj).
"""
parts = name.split(".")
if len(parts) == 1:
return root
parent = root
for p in parts[:-1]:
if not hasattr(parent, p):
return None
parent = getattr(parent, p)
return parent
def apply_lora_to_named_linear_modules(
root: nn.Module,
*,
target_submodule_names: list[str],
r: int,
alpha: float,
dropout: float,
) -> None:
"""
Inject LoRA into every Linear layer in the given module (and its submodules) whose
attribute name matches target_submodule_names.
For example, with target_submodule_names=["q_proj", "v_proj"], every nn.Linear named
*.q_proj / *.v_proj is replaced with a LoRALinear.
"""
for full_name, module in list(root.named_modules()):
if not isinstance(module, nn.Linear):
continue
short_name = full_name.split(".")[-1]
if short_name not in target_submodule_names:
continue
parent = _get_parent_module(root, full_name)
if parent is None:
continue
# Replace the original Linear with a LoRALinear
lora_layer = LoRALinear(
base=module,
r=r,
alpha=alpha,
dropout=dropout,
)
setattr(parent, short_name, lora_layer)

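A toy sketch of the injection helper on a hypothetical two-projection module, assuming the package path voxcpm.modules.layers.lora:

import torch
import torch.nn as nn
from voxcpm.modules.layers.lora import LoRALinear, apply_lora_to_named_linear_modules

class TinyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)
        self.v_proj = nn.Linear(16, 16)
        self.o_proj = nn.Linear(16, 16)   # not in the target list below, stays a plain Linear

block = TinyAttn()
apply_lora_to_named_linear_modules(
    block, target_submodule_names=["q_proj", "v_proj"], r=4, alpha=8, dropout=0.0
)
assert isinstance(block.q_proj, LoRALinear) and isinstance(block.o_proj, nn.Linear)

x = torch.randn(2, 16)
out = block.q_proj(x)            # equals the base Linear output at init because lora_B starts at zero
block.q_proj.set_enabled(False)  # zeroes the scaling buffer without touching lora_A/lora_B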
View File

@@ -1,20 +1,29 @@
from typing import List, Tuple
import torch
from typing import List
from .local_dit import VoxCPMLocDiT
import math
import torch.nn.functional as F
from torch.func import jvp
from pydantic import BaseModel
from .local_dit import VoxCPMLocDiT
class CfmConfig(BaseModel):
sigma_min: float = 1e-06
sigma_min: float = 1e-6
solver: str = "euler"
t_scheduler: str = "log-norm"
training_cfg_rate: float = 0.1
inference_cfg_rate: float = 1.0
reg_loss_type: str = "l1"
ratio_r_neq_t_range: Tuple[float, float] = (0.25, 0.75)
noise_cond_prob_range: Tuple[float, float] = (0.0, 0.0)
noise_cond_scale: float = 0.0
class UnifiedCFM(torch.nn.Module):
def __init__(
self,
in_channels,
in_channels: int,
cfm_params: CfmConfig,
estimator: VoxCPMLocDiT,
mean_mode: bool = False,
@@ -23,12 +32,21 @@ class UnifiedCFM(torch.nn.Module):
self.solver = cfm_params.solver
self.sigma_min = cfm_params.sigma_min
self.t_scheduler = cfm_params.t_scheduler
self.training_cfg_rate = cfm_params.training_cfg_rate
self.inference_cfg_rate = cfm_params.inference_cfg_rate
self.reg_loss_type = cfm_params.reg_loss_type
self.ratio_r_neq_t_range = cfm_params.ratio_r_neq_t_range
self.noise_cond_prob_range = cfm_params.noise_cond_prob_range
self.noise_cond_scale = cfm_params.noise_cond_scale
self.in_channels = in_channels
self.mean_mode = mean_mode
# Just change the architecture of the estimator here
self.estimator = estimator
# ------------------------------------------------------------------ #
# Inference
# ------------------------------------------------------------------ #
@torch.inference_mode()
def forward(
self,
@@ -41,33 +59,25 @@ class UnifiedCFM(torch.nn.Module):
sway_sampling_coef: float = 1.0,
use_cfg_zero_star: bool = True,
):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
n_timesteps (int): number of diffusion steps
cond: Not used but kept for future purposes
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
b, c = mu.shape
b, _ = mu.shape
t = patch_size
z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
# Sway sampling strategy
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)
return self.solve_euler(
x=z,
t_span=t_span,
mu=mu,
cond=cond,
cfg_value=cfg_value,
use_cfg_zero_star=use_cfg_zero_star,
)
def optimized_scale(self, positive_flat, negative_flat):
def optimized_scale(self, positive_flat: torch.Tensor, negative_flat: torch.Tensor):
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
st_star = dot_product / squared_norm
return st_star
@@ -80,24 +90,13 @@ class UnifiedCFM(torch.nn.Module):
cfg_value: float = 1.0,
use_cfg_zero_star: bool = True,
):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats)
cond: condition -- prefix prompt
cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
"""
t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
sol = []
zero_init_steps = max(1, int(len(t_span) * 0.04))
for step in range(1, len(t_span)):
if use_cfg_zero_star and step <= zero_init_steps:
dphi_dt = 0.
dphi_dt = torch.zeros_like(x)
else:
# Classifier-Free Guidance inference introduced in VoiceBox
b = x.size(0)
@@ -105,7 +104,7 @@ class UnifiedCFM(torch.nn.Module):
mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
cond_in = torch.zeros([2 * b, self.in_channels, cond.size(2)], device=x.device, dtype=x.dtype)
x_in[:b], x_in[b:] = x, x
mu_in[:b] = mu
t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
@@ -135,3 +134,98 @@ class UnifiedCFM(torch.nn.Module):
dt = t - t_span[step + 1]
return sol[-1]
# ------------------------------------------------------------------ #
# Training loss
# ------------------------------------------------------------------ #
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
weights = 1.0 / ((losses + epsilon).pow(p))
if mask is not None:
weights = weights * mask
return weights.detach()
def sample_r_t(self, x: torch.Tensor, mu: float = -0.4, sigma: float = 1.0, ratio_r_neq_t: float = 0.0):
batch_size = x.shape[0]
if self.t_scheduler == "log-norm":
s_r = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
s_t = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
r = torch.sigmoid(s_r)
t = torch.sigmoid(s_t)
elif self.t_scheduler == "uniform":
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
else:
raise ValueError(f"Unsupported t_scheduler: {self.t_scheduler}")
mask = torch.rand(batch_size, device=x.device, dtype=x.dtype) < ratio_r_neq_t
r, t = torch.where(
mask,
torch.stack([torch.min(r, t), torch.max(r, t)], dim=0),
torch.stack([t, t], dim=0),
)
return r.squeeze(), t.squeeze()
def compute_loss(
self,
x1: torch.Tensor,
mu: torch.Tensor,
cond: torch.Tensor | None = None,
tgt_mask: torch.Tensor | None = None,
progress: float = 0.0,
):
b, _, _ = x1.shape
if self.training_cfg_rate > 0:
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
mu = mu * cfg_mask.view(-1, 1)
if cond is None:
cond = torch.zeros_like(x1)
noisy_mask = torch.rand(b, device=x1.device) > (
1.0
- (
self.noise_cond_prob_range[0]
+ progress * (self.noise_cond_prob_range[1] - self.noise_cond_prob_range[0])
)
)
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
ratio_r_neq_t = (
self.ratio_r_neq_t_range[0]
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
if self.mean_mode
else 0.0
)
r, t = self.sample_r_t(x1, ratio_r_neq_t=ratio_r_neq_t)
r_ = r.detach().clone()
t_ = t.detach().clone()
z = torch.randn_like(x1)
y = (1 - t_.view(-1, 1, 1)) * x1 + t_.view(-1, 1, 1) * z
v = z - x1
def model_fn(z_sample, r_sample, t_sample):
return self.estimator(z_sample, mu, t_sample, cond, dt=t_sample - r_sample)
if self.mean_mode:
v_r = torch.zeros_like(r)
v_t = torch.ones_like(t)
from torch.backends.cuda import sdp_kernel
with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
u_pred, dudt = jvp(model_fn, (y, r, t), (v, v_r, v_t))
u_tgt = v - (t_ - r_).view(-1, 1, 1) * dudt
else:
u_pred = model_fn(y, r, t)
u_tgt = v
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
if tgt_mask is not None:
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
loss = (weights * losses).sum() / torch.sum(tgt_mask)
else:
loss = losses.mean()
return loss

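Outside of training, the inference path in forward warps its timestep grid with sway sampling; a standalone sketch of that schedule, with an illustrative coefficient value:

import torch

n_timesteps, sway_sampling_coef = 10, 1.0
t_span = torch.linspace(1, 0, n_timesteps + 1)
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
print(t_span)   # endpoints stay at 1.0 and 0.0, but interior values shift so step spacing is non-uniform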
View File

@@ -0,0 +1,28 @@
"""
Training utilities for VoxCPM fine-tuning.
This package mirrors the training mechanics used in the minicpm-audio
tooling while relying solely on local audio-text datasets managed via
the HuggingFace ``datasets`` library.
"""
from .accelerator import Accelerator
from .tracker import TrainingTracker
from .data import (
load_audio_text_datasets,
HFVoxCPMDataset,
build_dataloader,
BatchProcessor,
)
from .state import TrainingState
__all__ = [
"Accelerator",
"TrainingTracker",
"HFVoxCPMDataset",
"BatchProcessor",
"TrainingState",
"load_audio_text_datasets",
"build_dataloader",
]

View File

@@ -0,0 +1,166 @@
from __future__ import annotations
import contextlib
import os
import random
import typing
import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data
from torch.nn.parallel import DistributedDataParallel
class Accelerator:
"""
Simplified accelerator that mirrors the behaviour of the minicpm-audio
training utilities. It initializes a distributed process group when
``torchrun`` is used and exposes helpers for AMP, gradient scaling and
preparing models/dataloaders for DDP.
"""
def __init__(self, amp: bool = False, seed: int = 42):
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
if self.world_size > 1 and not dist.is_initialized():
dist.init_process_group("nccl", init_method="env://")
self.rank = dist.get_rank() if dist.is_initialized() else 0
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
self.amp = amp
# Set random seed to ensure model initialization consistency
self._set_seed(seed)
class DummyScaler:
def step(self, optimizer):
optimizer.step()
def scale(self, loss):
return loss
def unscale_(self, optimizer):
return optimizer
def update(self):
pass
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
self.device_ctx = (
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
)
self._ddp_model = None # For no_sync support
def _set_seed(self, seed: int):
"""Set random seed to ensure model initialization consistency across multiple GPUs"""
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def __enter__(self):
if self.device_ctx is not None:
self.device_ctx.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback):
if self.device_ctx is not None:
self.device_ctx.__exit__(exc_type, exc_value, traceback)
def barrier(self):
"""Synchronize all processes"""
if dist.is_initialized():
dist.barrier()
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.AVG):
"""All-reduce tensor across processes"""
if dist.is_initialized():
dist.all_reduce(tensor, op=op)
return tensor
# ------------------------------------------------------------------ #
# Model helpers
# ------------------------------------------------------------------ #
def prepare_model(self, model: torch.nn.Module, **kwargs):
if hasattr(model, 'device'): # keep the model's device attribute in sync so internal tensors land on the correct device
model.device = self.device
model = model.to(self.device)
if self.world_size > 1:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
self._ddp_model = model # Save DDP model reference for no_sync support
return model
@contextlib.contextmanager
def no_sync(self):
"""
Context manager to skip gradient synchronization during gradient accumulation.
Only used outside the last micro-batch.
"""
if self._ddp_model is not None:
with self._ddp_model.no_sync():
yield
else:
yield
@property
def device(self):
if torch.cuda.is_available():
return torch.device("cuda", self.local_rank)
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
# ------------------------------------------------------------------ #
# AMP helpers
# ------------------------------------------------------------------ #
def autocast(self, *args, **kwargs):
return torch.amp.autocast("cuda", enabled=self.amp, *args, **kwargs)
def backward(self, loss: torch.Tensor):
self.scaler.scale(loss).backward()
def step(self, optimizer: torch.optim.Optimizer):
self.scaler.step(optimizer)
def update(self):
self.scaler.update()
# ------------------------------------------------------------------ #
# Data helpers
# ------------------------------------------------------------------ #
def prepare_dataloader(
self,
dataset: typing.Iterable,
*,
batch_size: int,
num_workers: int = 0,
shuffle: bool = True,
collate_fn=None,
drop_last: bool = False,
) -> torch.utils.data.DataLoader:
if self.world_size > 1:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
)
shuffle = False
else:
sampler = None
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=num_workers,
collate_fn=collate_fn,
drop_last=drop_last,
pin_memory=True,
)
@staticmethod
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
return model.module if hasattr(model, "module") else model

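A self-contained sketch of how these helpers compose into a loop; the toy linear model and synthetic data stand in for VoxCPMModel and the real dataloader:

import contextlib
import torch
from torch.utils.data import TensorDataset
from voxcpm.training import Accelerator

accel = Accelerator(amp=torch.cuda.is_available(), seed=42)
model = accel.prepare_model(torch.nn.Linear(8, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataset = TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
loader = accel.prepare_dataloader(dataset, batch_size=4, num_workers=0)

grad_accum = 2
optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    x, y = x.to(accel.device), y.to(accel.device)
    last_micro = (step + 1) % grad_accum == 0
    sync_ctx = contextlib.nullcontext() if last_micro else accel.no_sync()
    with sync_ctx, accel.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)
    accel.backward(loss / grad_accum)
    if last_micro:
        accel.step(optimizer)   # unscales and steps via GradScaler when AMP is on
        accel.update()
        optimizer.zero_grad()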
View File

@@ -0,0 +1,40 @@
from __future__ import annotations
import argbind
import yaml
from pathlib import Path
from typing import Dict, Any
def load_yaml_config(path: str | Path) -> Dict[str, Any]:
"""
Load a YAML configuration file into a dictionary suitable for argbind.
"""
path = Path(path)
with path.open("r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Configuration file {path} must contain a top-level mapping.")
return data
def parse_args_with_config(config_path: str | Path | None = None):
"""
Helper to unify CLI arguments and YAML configuration.
Usage mirrors minicpm-audio:
args = parse_args_with_config("conf/voxcpm/finetune.yml")
with argbind.scope(args):
...
"""
cli_args = argbind.parse_args()
if config_path is None:
return cli_args
yaml_args = load_yaml_config(config_path)
with argbind.scope(cli_args):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args)
return cli_args

src/voxcpm/training/data.py
View File

@@ -0,0 +1,214 @@
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import argbind
import torch
from datasets import Audio, Dataset, DatasetDict, load_dataset
from torch.utils.data import Dataset as TorchDataset
from ..model.voxcpm import VoxCPMConfig
from ..modules.audiovae import AudioVAE
from .packers import AudioFeatureProcessingPacker
DEFAULT_TEXT_COLUMN = "text"
DEFAULT_AUDIO_COLUMN = "audio"
DEFAULT_ID_COLUMN = "dataset_id"
@argbind.bind()
def load_audio_text_datasets(
train_manifest: str,
val_manifest: str = "",
text_column: str = DEFAULT_TEXT_COLUMN,
audio_column: str = DEFAULT_AUDIO_COLUMN,
dataset_id_column: str = DEFAULT_ID_COLUMN,
sample_rate: int = 16_000,
num_proc: int = 1,
) -> Tuple[Dataset, Optional[Dataset]]:
data_files = {"train": train_manifest}
if val_manifest:
data_files["validation"] = val_manifest
dataset_dict: DatasetDict = load_dataset("json", data_files=data_files)
def prepare(ds: Dataset) -> Dataset:
if audio_column not in ds.column_names:
raise ValueError(f"Expected '{audio_column}' column in manifest.")
# We cast to Audio to ensure proper handling during training,
# but for length calculation we might need raw path or duration if available.
# HF datasets usually don't compute duration automatically for 'Audio' column.
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
if audio_column != DEFAULT_AUDIO_COLUMN:
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
if text_column != DEFAULT_TEXT_COLUMN:
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
if dataset_id_column and dataset_id_column in ds.column_names:
if dataset_id_column != DEFAULT_ID_COLUMN:
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
else:
ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
return ds
train_ds = prepare(dataset_dict["train"])
val_ds = prepare(dataset_dict["validation"]) if "validation" in dataset_dict else None
return train_ds, val_ds
def compute_sample_lengths(
ds: Dataset,
audio_vae_fps: int = 25,
patch_size: int = 1,
) -> List[int]:
"""
Estimate each sample's approximate packed sequence length (text + audio) so overly long samples can be filtered out.
The logic mirrors AudioFeatureProcessingPacker / AudioVAE:
- text length: len(text_ids)
- audio length:
duration(s) * audio_vae_fps -> approximate number of VAE frames t_vae
t_seq = ceil(t_vae / patch_size)
- total sequence length is roughly: text_len + t_seq + 2
e.g. a 4 s clip at 25 fps with patch_size 2 contributes t_seq = ceil(100 / 2) = 50 audio positions.
"""
lengths: List[int] = []
has_duration = "duration" in ds.column_names
for i in range(len(ds)):
item = ds[i]
text_len = len(item["text_ids"])
# Audio duration (avoid decoding when possible; prefer the manifest's 'duration' column if present)
if has_duration:
duration = float(item["duration"])
else:
audio = item[DEFAULT_AUDIO_COLUMN]
duration = len(audio["array"]) / float(audio["sampling_rate"])
t_vae = math.ceil(duration * audio_vae_fps)
t_seq = math.ceil(t_vae / patch_size)
total_len = text_len + t_seq + 2
lengths.append(total_len)
return lengths
class HFVoxCPMDataset(TorchDataset):
"""
Thin wrapper around a tokenized HuggingFace dataset that returns
PyTorch-friendly samples.
"""
def __init__(self, dataset: Dataset):
self.dataset = dataset
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx: int):
item = self.dataset[idx]
audio = item[DEFAULT_AUDIO_COLUMN]
return {
"text_ids": item["text_ids"],
"audio_array": audio["array"],
"audio_sampling_rate": audio["sampling_rate"],
"dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
"is_prompt": item.get("is_prompt", False),
}
@staticmethod
def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
if not seqs:
return torch.empty(0)
max_len = max(seq.shape[0] for seq in seqs)
padded = []
for seq in seqs:
if seq.shape[0] < max_len:
pad_width = (0, max_len - seq.shape[0])
seq = torch.nn.functional.pad(seq, pad_width, value=pad_value)
padded.append(seq)
return torch.stack(padded)
@classmethod
def collate_fn(cls, batch: List[Dict]):
text_tensors = [torch.tensor(sample["text_ids"], dtype=torch.int32) for sample in batch]
audio_tensors = [torch.tensor(sample["audio_array"], dtype=torch.float32) for sample in batch]
dataset_ids = torch.tensor([sample["dataset_id"] for sample in batch], dtype=torch.int32)
is_prompts = [bool(sample.get("is_prompt", False)) for sample in batch]
text_padded = cls.pad_sequences(text_tensors, pad_value=-100)
audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)
return {
"text_tokens": text_padded,
"audio_tokens": audio_padded,
"task_ids": task_ids,
"dataset_ids": dataset_ids,
"is_prompts": is_prompts,
}
class BatchProcessor:
"""
Wraps ``AudioFeatureProcessingPacker`` so the training loop can mirror
the minicpm-audio mechanics.
"""
def __init__(
self,
*,
config: VoxCPMConfig,
audio_vae: AudioVAE,
dataset_cnt: int,
device: torch.device,
):
self.device = device
self.dataset_cnt = dataset_cnt
self.audio_vae = audio_vae
self.audio_vae.to(device)
self.packer = AudioFeatureProcessingPacker(
dataset_cnt=dataset_cnt,
max_len=config.max_length,
patch_size=config.patch_size,
feat_dim=config.feat_dim,
audio_vae=self.audio_vae,
)
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
audio_tokens = batch["audio_tokens"].to(self.device)
text_tokens = batch["text_tokens"].to(self.device)
task_ids = batch["task_ids"].to(self.device)
dataset_ids = batch["dataset_ids"].to(self.device)
packed = self.packer(
audio_tokens=audio_tokens,
text_tokens=text_tokens,
task_ids=task_ids,
dataset_ids=dataset_ids,
is_prompts=batch["is_prompts"],
)
return packed
def build_dataloader(
hf_dataset: Dataset,
*,
accelerator,
batch_size: int,
num_workers: int,
drop_last: bool = False,
) -> torch.utils.data.DataLoader:
torch_dataset = HFVoxCPMDataset(hf_dataset)
# Standard padding-based batching; Accelerator will attach DistributedSampler if needed.
return accelerator.prepare_dataloader(
torch_dataset,
batch_size=batch_size,
num_workers=num_workers,
shuffle=True,
collate_fn=HFVoxCPMDataset.collate_fn,
drop_last=drop_last,
)

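A sketch of the expected input: a JSONL manifest with "text" and "audio" columns (a "duration" field is optional but lets compute_sample_lengths skip decoding). The file path and rows are illustrative:

import json

rows = [
    {"text": "hello world", "audio": "clips/utt_0001.wav", "dataset_id": 0, "duration": 2.1},
    {"text": "a fine-tuning sample", "audio": "clips/utt_0002.wav", "dataset_id": 0, "duration": 3.4},
]
with open("train_manifest.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

# Then, in a training script (tokenization into a "text_ids" column happens upstream of HFVoxCPMDataset):
# train_ds, _ = load_audio_text_datasets(train_manifest="train_manifest.jsonl")
# loader = build_dataloader(train_ds, accelerator=accel, batch_size=4, num_workers=2)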
View File

@@ -0,0 +1,289 @@
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from einops import rearrange
class AudioFeatureProcessingPacker:
"""
Adapted from the minicpm-audio training utilities. It converts raw text and
audio tokens into the packed multimodal representation required by VoxCPM.
"""
def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
self.audio_start_id = 101
self.audio_end_id = 102
# unused now
self.audio_prompt_start_id = 103
self.audio_prompt_end_id = 104
self.text_eos_token_id = 2
self.patch_size = patch_size
self.patch_len = audio_vae.hop_length * self.patch_size
self.feat_dim = feat_dim
self.dataset_cnt = max(dataset_cnt, 1)
self.max_len = max_len
self.audio_vae = audio_vae
self.process_functions = {"tts": self.process_tts_data}
self.task_id_map = {"tts": 1}
self.id_to_task = {idx: usage for usage, idx in self.task_id_map.items()}
# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
@staticmethod
def _first_pad_position(tokens: torch.Tensor):
positions = (tokens == -100).nonzero(as_tuple=True)
if positions[0].numel() == 0:
return None
return int(positions[0][0])
def unpad_text_tokens(self, tokens: torch.Tensor):
pad_pos = self._first_pad_position(tokens)
return tokens if pad_pos is None else tokens[:pad_pos]
def unpad_audio_tokens(self, tokens: torch.Tensor):
pad_pos = self._first_pad_position(tokens)
return tokens if pad_pos is None else tokens[:pad_pos]
def encode_audio(self, wav: torch.Tensor):
"""
Encode raw waveform into latent features using AudioVAE.
AudioVAE.encode expects shape [B, 1, T'] and returns [B, D, T].
We then transpose to [B, T, D] to match downstream expectations.
"""
wav = wav.unsqueeze(0) # [1, T]
wav = wav.unsqueeze(1) # [1, 1, T]
wav_len = wav.size(-1)
if wav_len % self.patch_len != 0:
padding_size = self.patch_len - wav_len % self.patch_len
wav = torch.nn.functional.pad(wav, (0, padding_size))
with torch.no_grad():
z = self.audio_vae.encode(wav, self.audio_vae.sample_rate) # [1, D, T']
feat = z.transpose(1, 2) # [1, T', D]
return feat
# ------------------------------------------------------------------ #
# Main entry point
# ------------------------------------------------------------------ #
def __call__(
self,
audio_tokens: torch.Tensor,
text_tokens: torch.Tensor,
task_ids: torch.Tensor,
dataset_ids: torch.Tensor,
is_prompts: List[bool],
) -> Dict[str, torch.Tensor]:
"""
Padding-based batching: each sample in the input batch is processed
independently and then padded to a common length (capped by ``max_len``).
The result tensors all have shape [B, T, ...].
"""
device = audio_tokens.device
max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
dataset_cnt = max(self.dataset_cnt, max_dataset_id + 1)
text_tokens_list: List[torch.Tensor] = []
audio_feats_list: List[torch.Tensor] = []
text_mask_list: List[torch.Tensor] = []
audio_mask_list: List[torch.Tensor] = []
loss_mask_list: List[torch.Tensor] = []
labels_list: List[torch.Tensor] = []
audio_task_ids_list: List[torch.Tensor] = []
audio_dataset_ids_list: List[torch.Tensor] = []
lengths: List[int] = []
audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
):
unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
unpad_text_token = self.unpad_text_tokens(text_token)
usage = self.id_to_task[task_id]
(
packed_text,
audio_feat,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)
audio_duration_consumed[dataset_idx] += audio_duration
text_token_consumed[dataset_idx] += text_token_count
audio_task_id = torch.zeros_like(audio_mask)
audio_task_id[audio_mask == 1] = self.task_id_map[usage]
audio_dataset_id = torch.zeros_like(audio_mask)
audio_dataset_id[audio_mask == 1] = dataset_idx + 1
text_tokens_list.append(packed_text)
text_mask_list.append(text_mask)
audio_feats_list.append(audio_feat)
audio_mask_list.append(audio_mask)
loss_mask_list.append(loss_mask)
labels_list.append(labels)
audio_task_ids_list.append(audio_task_id)
audio_dataset_ids_list.append(audio_dataset_id)
lengths.append(packed_text.shape[0])
# Determine padded length per batch (cap by self.max_len)
if lengths:
max_len = min(self.max_len, max(lengths))
else:
max_len = self.max_len
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
return torch.cat([x, pad], dim=0)
def pad_3d(x: torch.Tensor) -> torch.Tensor:
# x: [T, P, D]
if x.size(0) >= max_len:
return x[: max_len]
pad = torch.zeros(
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
)
return torch.cat([x, pad], dim=0)
if lengths:
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
audio_task_ids_batch = torch.stack(
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
)
audio_dataset_ids_batch = torch.stack(
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
)
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
position_ids_list = []
for L in lengths:
L_clip = min(L, max_len)
pos = torch.arange(0, L_clip, device=device)
if L_clip < max_len:
pad = torch.zeros(max_len - L_clip, dtype=pos.dtype, device=device)
pos = torch.cat([pos, pad], dim=0)
position_ids_list.append(pos)
position_ids = torch.stack(position_ids_list, dim=0)
else:
# Empty batch fallback (shouldn't really happen)
text_tokens_batch = torch.zeros((0, self.max_len), dtype=torch.int32, device=device)
text_mask_batch = torch.zeros_like(text_tokens_batch)
audio_feats_batch = torch.zeros(
(0, self.max_len, self.patch_size, self.feat_dim), dtype=torch.float32, device=device
)
audio_mask_batch = torch.zeros_like(text_tokens_batch)
loss_mask_batch = torch.zeros_like(text_tokens_batch)
labels_batch = torch.zeros_like(text_tokens_batch)
audio_task_ids_batch = torch.zeros_like(text_tokens_batch)
audio_dataset_ids_batch = torch.zeros_like(text_tokens_batch)
position_ids = torch.zeros_like(text_tokens_batch)
audio_duration_consumed = audio_duration_consumed.to(torch.long)
text_token_consumed = text_token_consumed.to(torch.long)
return {
"text_tokens": text_tokens_batch,
"audio_feats": audio_feats_batch,
"text_mask": text_mask_batch,
"audio_mask": audio_mask_batch,
"loss_mask": loss_mask_batch,
"position_ids": position_ids,
"labels": labels_batch,
"audio_task_ids": audio_task_ids_batch,
"audio_dataset_ids": audio_dataset_ids_batch,
"audio_duration_consumed": audio_duration_consumed,
"text_token_consumed": text_token_consumed,
}
# ------------------------------------------------------------------ #
# Feature extraction helpers
# ------------------------------------------------------------------ #
def extract_audio_feats(self, audio_data: torch.Tensor):
audio_feats = self.encode_audio(audio_data)
if audio_feats.size(1) % self.patch_size != 0:
audio_feats_ = audio_feats.transpose(1, 2)
padding = nn.functional.pad(audio_feats_, (0, self.patch_size - audio_feats.size(1) % self.patch_size))
audio_feats = padding.transpose(1, 2)
audio_duration = audio_feats.size(1) / 25
audio_feats = rearrange(audio_feats, "b (t p) c -> b t p c", p=self.patch_size)
return audio_feats, audio_duration
def process_tts_data(self, audio_token: torch.Tensor, text_token: torch.Tensor, is_prompt: bool = False):
text_token_info = torch.cat(
[
text_token,
torch.tensor(
[self.audio_prompt_start_id if is_prompt else self.audio_start_id],
dtype=torch.int32,
device=text_token.device,
),
],
dim=-1,
)
text_token_count = len(text_token)
text_length = text_token_info.shape[0]
audio_feat_info, audio_duration = self.extract_audio_feats(audio_token)
audio_feat_info = audio_feat_info.squeeze(0)
audio_length = audio_feat_info.shape[0]
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
text_token_info = torch.cat(
[
text_token_info,
text_pad_token,
torch.tensor(
[self.audio_prompt_end_id if is_prompt else self.audio_end_id],
dtype=torch.int32,
device=text_token.device,
),
]
)
audio_pad_feat = torch.zeros(
(text_length, self.patch_size, audio_feat_info.size(-1)),
dtype=torch.float32,
device=text_token.device,
)
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
text_token.device
)
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
torch.int32
).to(text_token.device)
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
labels[-2] = 1
return (
text_token_info,
audio_feat_info,
text_mask,
audio_mask,
loss_mask,
labels,
audio_duration,
text_token_count,
)

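A worked sketch of the layout process_tts_data produces for a single non-prompt sample, with small illustrative lengths:

Lt, La = 4, 3           # Lt = text tokens plus the audio_start marker, La = audio patches
total = Lt + La + 1     # trailing +1 is the audio_end marker

text_mask  = [1] * Lt + [0] * La + [1]
audio_mask = [0] * Lt + [1] * La + [0]
loss_mask  = [0] * Lt + [1] * La + [0]   # becomes all zeros when is_prompt=True
labels     = [0] * (total - 2) + [1, 0]  # stop label sits on the last audio position

assert len(text_mask) == len(audio_mask) == len(loss_mask) == len(labels) == total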
View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class TrainingState:
"""
Container that mirrors the object returned in the minicpm-audio training
loop. It holds persistent references to the model, optimizer, scheduler,
dataloaders and tracker.
"""
generator: object
optimizer: object
scheduler: object
train_loader: object
val_loader: object
tracker: object
batch_processor: object

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
import contextlib
import time
from pathlib import Path
from typing import Dict, Optional
class TrainingTracker:
"""
Lightweight tracker inspired by the minicpm-audio training workflow.
It keeps track of the current global step, prints rank-aware messages,
optionally writes to TensorBoard via a provided writer, and stores progress
in a logfile for later inspection.
"""
def __init__(
self,
*,
writer=None,
log_file: Optional[str] = None,
rank: int = 0,
):
self.writer = writer
self.log_file = Path(log_file) if log_file else None
if self.log_file:
self.log_file.parent.mkdir(parents=True, exist_ok=True)
self.rank = rank
self.step = 0
# Record the time of the last log to calculate the interval
self._last_log_time: float | None = None
# ------------------------------------------------------------------ #
# Logging helpers
# ------------------------------------------------------------------ #
def print(self, message: str):
if self.rank == 0:
print(message, flush=True)
if self.log_file:
with self.log_file.open("a", encoding="utf-8") as f:
f.write(message + "\n")
def log_metrics(self, metrics: Dict[str, float], split: str):
if self.rank == 0:
now = time.time()
dt_str = ""
if self._last_log_time is not None:
dt = now - self._last_log_time
dt_str = f", log interval: {dt:.2f}s"
self._last_log_time = now
formatted = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")
if self.writer is not None:
for key, value in metrics.items():
if isinstance(value, (int, float)):
self.writer.add_scalar(f"{split}/{key}", value, self.step)
def done(self, split: str, message: str):
self.print(f"[{split}] {message}")
# ------------------------------------------------------------------ #
# State dict
# ------------------------------------------------------------------ #
def state_dict(self):
return {"step": self.step}
def load_state_dict(self, state):
self.step = int(state.get("step", 0))
# ------------------------------------------------------------------ #
# Context manager compatibility (for parity with minicpm-audio code)
# ------------------------------------------------------------------ #
@contextlib.contextmanager
def live(self):
yield