Update: VoxCPM1.5 and fine-tuning support
@@ -69,7 +69,7 @@ def load_model(args) -> VoxCPM:
# Otherwise, try from_pretrained (Hub); exit on failure
try:
model = VoxCPM.from_pretrained(
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM-0.5B"),
hf_model_id=getattr(args, "hf_model_id", "openbmb/VoxCPM1.5"),
load_denoiser=not getattr(args, "no_denoiser", False),
zipenhancer_model_id=zipenhancer_path,
cache_dir=getattr(args, "cache_dir", None),
@@ -120,11 +120,11 @@ def cmd_clone(args):
)

# Save audio
sf.write(str(output_path), audio_array, 16000)
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}")

# Stats
duration = len(audio_array) / 16000
duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s")


@@ -152,11 +152,11 @@ def cmd_synthesize(args):
)

# Save audio
sf.write(str(output_path), audio_array, 16000)
sf.write(str(output_path), audio_array, model.tts_model.sample_rate)
print(f"Saved audio to: {output_path}")

# Stats
duration = len(audio_array) / 16000
duration = len(audio_array) / model.tts_model.sample_rate
print(f"Duration: {duration:.2f}s")


@@ -198,9 +198,9 @@ def cmd_batch(args):
denoise=args.denoise and prompt_audio_path is not None
)
output_file = output_dir / f"output_{i:03d}.wav"
sf.write(str(output_file), audio_array, 16000)
sf.write(str(output_file), audio_array, model.tts_model.sample_rate)

duration = len(audio_array) / 16000
duration = len(audio_array) / model.tts_model.sample_rate
print(f" Saved: {output_file} ({duration:.2f}s)")
success_count += 1

@@ -250,7 +250,7 @@ Examples:

# Model loading parameters
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path (overrides Hub download)")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM-0.5B", help="Hugging Face repo id (e.g., openbmb/VoxCPM-0.5B)")
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (e.g., openbmb/VoxCPM1.5 or openbmb/VoxCPM-0.5B)")
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
parser.add_argument("--local-files-only", action="store_true", help="Use only local files (no network)")
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")

@@ -45,6 +45,7 @@ class VoxCPM:
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
**kwargs,
):
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
@@ -52,6 +53,7 @@ class VoxCPM:
Args:
hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or local path.
load_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
zipenhancer_model_id: Denoiser model id or path for ModelScope
acoustic noise suppression.
cache_dir: Custom cache directory for the snapshot.
@@ -87,6 +89,7 @@ class VoxCPM:
voxcpm_model_path=local_path,
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
optimize=optimize,
**kwargs,
)

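A minimal usage sketch of the new default (not part of the diff; the public import path and the generate() call are assumptions based on the CLI code above, and the older openbmb/VoxCPM-0.5B repo id remains selectable):

import soundfile as sf
from voxcpm import VoxCPM  # assumed public import path

model = VoxCPM.from_pretrained(
    hf_model_id="openbmb/VoxCPM1.5",   # new default; the 0.5B repo id still works
    load_denoiser=False,               # skip the ZipEnhancer denoiser
    optimize=True,                     # torch.compile; set False when debugging
)
audio_array = model.generate(text="Hello from VoxCPM.")  # assumed synthesis entry point
sf.write("out.wav", audio_array, model.tts_model.sample_rate)  # native rate instead of hard-coded 16000, as the CLI now does
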
@@ -19,19 +19,27 @@ limitations under the License.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Tuple, Union, Generator, List
|
||||
from typing import Tuple, Union, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import warnings
|
||||
from einops import rearrange
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
from safetensors.torch import load_file
|
||||
SAFETENSORS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SAFETENSORS_AVAILABLE = False
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaTokenizerFast
|
||||
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from ..modules.audiovae import AudioVAE, AudioVAEConfig
|
||||
from ..modules.layers import ScalarQuantizationLayer
|
||||
from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
@@ -66,10 +74,31 @@ class VoxCPMConfig(BaseModel):
|
||||
|
||||
encoder_config: VoxCPMEncoderConfig
|
||||
dit_config: VoxCPMDitConfig
|
||||
audio_vae_config: Optional[AudioVAEConfig] = None
|
||||
|
||||
max_length: int = 4096
|
||||
device: str = "cuda"
|
||||
dtype: str = "bfloat16"
|
||||
dit_mean_mode: bool = False
|
||||
|
||||
|
||||
class LoRAConfig(BaseModel):
|
||||
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
|
||||
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
|
||||
enable_proj: bool = False # Apply LoRA to projection Linear layers
|
||||
|
||||
r: int = 8
|
||||
alpha: int = 16
|
||||
dropout: float = 0.0
|
||||
|
||||
# Target linear layer names for LM & DiT (matched by attribute name)
|
||||
target_modules_lm: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
target_modules_dit: list[str] = ["q_proj", "v_proj", "k_proj", "o_proj"]
|
||||
# Projection layer attribute names to find on VoxCPMModel
|
||||
target_proj_modules: list[str] = ["enc_to_lm_proj", "lm_to_dit_proj", "res_to_dit_proj"]
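
A hedged sketch of how these fields are meant to be wired up, based on the from_local(..., training=True, lora_config=...) signature added later in this diff (the checkpoint path is illustrative):

lora_cfg = LoRAConfig(
    enable_lm=True,    # LoRA on base_lm + residual_lm attention projections
    enable_dit=True,   # LoRA on the VoxCPMLocDiT estimator
    enable_proj=False,
    r=8,
    alpha=16,
    dropout=0.05,
)
model = VoxCPMModel.from_local("/path/to/voxcpm_ckpt", training=True, lora_config=lora_cfg)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]  # only lora_A / lora_B remain trainable
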
|
||||
|
||||
|
||||
VoxCPMConfig.model_rebuild()
|
||||
|
||||
|
||||
class VoxCPMModel(nn.Module):
|
||||
@@ -78,9 +107,11 @@ class VoxCPMModel(nn.Module):
|
||||
config: VoxCPMConfig,
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAE,
|
||||
lora_config: LoRAConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
@@ -128,6 +159,7 @@ class VoxCPMModel(nn.Module):
|
||||
in_channels=config.feat_dim,
|
||||
cfm_params=config.dit_config.cfm_config,
|
||||
estimator=VoxCPMLocDiT(decoder_config, in_channels=config.feat_dim),
|
||||
mean_mode=config.dit_mean_mode,
|
||||
)
|
||||
|
||||
# Projection layers
|
||||
@@ -145,17 +177,46 @@ class VoxCPMModel(nn.Module):
|
||||
self.stop_proj = nn.Linear(config.lm_config.hidden_size, config.lm_config.hidden_size)
|
||||
self.stop_actn = nn.SiLU()
|
||||
self.stop_head = nn.Linear(config.lm_config.hidden_size, 2, bias=False)
|
||||
self.stop_loss = nn.CrossEntropyLoss(reduction="none")
|
||||
|
||||
# Audio VAE
|
||||
self.audio_vae = audio_vae
|
||||
self.chunk_size = audio_vae.chunk_size
|
||||
self.sample_rate = audio_vae.sample_rate
|
||||
|
||||
|
||||
if self.lora_config is not None:
|
||||
self._apply_lora()
|
||||
|
||||
def _apply_lora(self):
|
||||
"""注入 LoRA 到 LM / DiT / 投影层"""
|
||||
cfg = self.lora_config
|
||||
lora_kwargs = dict(r=cfg.r, alpha=cfg.alpha, dropout=cfg.dropout)
|
||||
|
||||
# LM: base_lm + residual_lm
|
||||
if cfg.enable_lm:
|
||||
for lm in [self.base_lm, self.residual_lm]:
|
||||
apply_lora_to_named_linear_modules(
|
||||
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
|
||||
)
|
||||
|
||||
# DiT: feat_decoder.estimator
|
||||
if cfg.enable_dit:
|
||||
apply_lora_to_named_linear_modules(
|
||||
self.feat_decoder.estimator, target_submodule_names=cfg.target_modules_dit, **lora_kwargs
|
||||
)
|
||||
|
||||
# Projection layers
|
||||
if cfg.enable_proj:
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
for attr_name in cfg.target_proj_modules:
|
||||
module = getattr(self, attr_name, None)
|
||||
if isinstance(module, nn.Linear):
|
||||
setattr(self, attr_name, LoRALinear(base=module, **lora_kwargs))
|
||||
|
||||
def optimize(self, disable: bool = False):
|
||||
if disable:
|
||||
return self
|
||||
try:
|
||||
if disable:
|
||||
raise ValueError("Optimization disabled by user")
|
||||
if self.device != "cuda":
|
||||
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
||||
try:
|
||||
@@ -164,17 +225,111 @@ class VoxCPMModel(nn.Module):
|
||||
raise ValueError("triton is not installed")
|
||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print("Warning: VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
|
||||
self.base_lm.forward_step = self.base_lm.forward_step
|
||||
self.residual_lm.forward_step = self.residual_lm.forward_step
|
||||
self.feat_encoder_step = self.feat_encoder
|
||||
self.feat_decoder.estimator = self.feat_decoder.estimator
|
||||
print(f"Warning: torch.compile disabled - {e}")
|
||||
return self
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_tokens: torch.Tensor,
|
||||
text_mask: torch.Tensor,
|
||||
audio_feats: torch.Tensor,
|
||||
audio_mask: torch.Tensor,
|
||||
loss_mask: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
*,
|
||||
progress: float = 0.0,
|
||||
sample_generate: bool = False,
|
||||
):
|
||||
del position_ids # not used yet
|
||||
|
||||
text_tokens = text_tokens.to(self.device, dtype=torch.long)
|
||||
text_mask = text_mask.to(self.device, dtype=self._dtype())
|
||||
audio_feats = audio_feats.to(self.device, dtype=self._dtype())
|
||||
audio_mask = audio_mask.to(self.device, dtype=self._dtype())
|
||||
loss_mask = loss_mask.to(self.device, dtype=self._dtype())
|
||||
labels = labels.to(self.device, dtype=torch.long)
|
||||
|
||||
B, T, P, D = audio_feats.shape
|
||||
feat_embed = self.feat_encoder(audio_feats)
|
||||
feat_embed = self.enc_to_lm_proj(feat_embed)
|
||||
|
||||
scale_emb = getattr(self.config.lm_config, "scale_emb", 1.0)
|
||||
if not getattr(self.config.lm_config, "use_mup", False):
|
||||
scale_emb = 1.0
|
||||
text_embed = self.base_lm.embed_tokens(text_tokens) * scale_emb
|
||||
combined_embed = text_mask.unsqueeze(-1) * text_embed + audio_mask.unsqueeze(-1) * feat_embed
|
||||
|
||||
enc_outputs, _ = self.base_lm(inputs_embeds=combined_embed, is_causal=True)
|
||||
enc_outputs = enc_outputs.to(self._dtype())
|
||||
enc_outputs = self.fsq_layer(enc_outputs) * audio_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
|
||||
lm_hidden = torch.cat((torch.zeros_like(enc_outputs[:, 0:1, :]), enc_outputs[:, :-1, :]), dim=1)
|
||||
|
||||
residual_inputs = enc_outputs + audio_mask.unsqueeze(-1) * feat_embed
|
||||
residual_outputs, _ = self.residual_lm(inputs_embeds=residual_inputs, is_causal=True)
|
||||
residual_outputs = residual_outputs.to(self._dtype())
|
||||
residual_hidden = torch.cat(
|
||||
(torch.zeros_like(residual_outputs[:, 0:1, :]), residual_outputs[:, :-1, :]),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
dit_hidden = self.lm_to_dit_proj(lm_hidden) + self.res_to_dit_proj(residual_hidden)
|
||||
dit_hidden = rearrange(dit_hidden, "b t c -> (b t) c")
|
||||
|
||||
# Keep diffusion inputs in the same dtype as the model (e.g., bfloat16)
|
||||
target_dtype = self._dtype()
|
||||
|
||||
feat_gt = rearrange(audio_feats.to(target_dtype), "b t p d -> (b t) p d")
|
||||
feat_cond = torch.cat(
|
||||
(torch.zeros_like(audio_feats[:, 0:1, ...]), audio_feats[:, :-1, ...]),
|
||||
dim=1,
|
||||
)
|
||||
feat_cond = rearrange(feat_cond.to(target_dtype), "b t p d -> (b t) p d")
|
||||
|
||||
loss_seq_mask = loss_mask.unsqueeze(-1).repeat(1, 1, self.patch_size)
|
||||
loss_seq_mask = rearrange(loss_seq_mask, "b t p -> (b t) p 1").to(target_dtype)
|
||||
|
||||
diff_loss = self.feat_decoder.compute_loss(
|
||||
feat_gt.transpose(1, 2).contiguous(),
|
||||
dit_hidden,
|
||||
cond=feat_cond.transpose(1, 2).contiguous(),
|
||||
tgt_mask=loss_seq_mask.transpose(1, 2).contiguous(),
|
||||
progress=progress,
|
||||
)
|
||||
|
||||
stop_logits = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden)))
|
||||
stop_losses = self.stop_loss(stop_logits.transpose(1, 2), labels)
|
||||
denom = torch.clamp(loss_mask.sum(), min=1.0)
|
||||
stop_loss = (stop_losses * loss_mask).sum() / denom
|
||||
|
||||
feat_pred = None
|
||||
if sample_generate:
|
||||
feat_cond_for_sample = feat_cond.transpose(1, 2).contiguous()
|
||||
feat_pred_seq = self.feat_decoder(
|
||||
mu=dit_hidden,
|
||||
patch_size=self.patch_size,
|
||||
cond=feat_cond_for_sample,
|
||||
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
|
||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||
else 10,
|
||||
)
|
||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
feat_gt_tensor = rearrange(feat_gt, "(b t) p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
return {
|
||||
"loss/diff": diff_loss,
|
||||
"loss/stop": stop_loss,
|
||||
"feat_gt": feat_gt_tensor,
|
||||
"feat_pred": feat_pred,
|
||||
}
|
||||
|
||||
def _dtype(self):
|
||||
return get_dtype(self.config.dtype)
|
||||
|
||||
|
||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||
return next(self._generate(*args, streaming=False, **kwargs))
|
||||
@@ -238,25 +393,25 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
if audio.size(0) > 1:
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
audio = audio.mean(dim=0, keepdim=True)
|
||||
|
||||
if sr != self.sample_rate:
|
||||
audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
|
||||
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
# Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
|
||||
padding_size = patch_len - audio.size(1) % patch_len
|
||||
audio = torch.nn.functional.pad(audio, (padding_size, 0))
|
||||
|
||||
# (B, D, T)
|
||||
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
|
||||
|
||||
audio_feat = audio_feat.view(
|
||||
self.audio_vae.latent_dim,
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
audio_length = audio_feat.size(0)
|
||||
text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
|
||||
text_token = torch.cat([text_token, text_pad_token])
|
||||
@@ -288,7 +443,7 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -314,7 +469,6 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
|
||||
yield decode_audio
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -331,13 +485,11 @@ class VoxCPMModel(nn.Module):
|
||||
prompt_wav_path: prompt audio path (required)
|
||||
|
||||
Returns:
|
||||
prompt_cache: dict with text tokens and audio features
|
||||
prompt_cache: dict with prompt_text (raw text) and audio features.
|
||||
Text tokenization will be done during generation for consistency.
|
||||
"""
|
||||
if not prompt_text or not prompt_wav_path:
|
||||
raise ValueError("prompt_text and prompt_wav_path are required")
|
||||
|
||||
# build text tokens
|
||||
text_token = torch.LongTensor(self.text_tokenizer(prompt_text))
|
||||
|
||||
# load audio
|
||||
audio, sr = torchaudio.load(prompt_wav_path)
|
||||
@@ -350,7 +502,9 @@ class VoxCPMModel(nn.Module):
|
||||
patch_len = self.patch_size * self.chunk_size
|
||||
|
||||
if audio.size(1) % patch_len != 0:
|
||||
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
|
||||
# Left padding: pad at the beginning of the audio to keep valid audio data at the end of the sequence
|
||||
padding_size = patch_len - audio.size(1) % patch_len
|
||||
audio = torch.nn.functional.pad(audio, (padding_size, 0))
|
||||
|
||||
# extract audio features
|
||||
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
|
||||
@@ -360,10 +514,9 @@ class VoxCPMModel(nn.Module):
|
||||
-1,
|
||||
self.patch_size,
|
||||
).permute(1, 2, 0) # (D, T, P)
|
||||
audio_feat = audio_feat[:-1, ...] # trick: remove the last padding token
|
||||
# build prompt cache
|
||||
# build prompt cache - only save raw text and audio features
|
||||
prompt_cache = {
|
||||
"text_token": text_token,
|
||||
"prompt_text": prompt_text,
|
||||
"audio_feat": audio_feat,
|
||||
}
|
||||
|
||||
@@ -373,7 +526,7 @@ class VoxCPMModel(nn.Module):
|
||||
def merge_prompt_cache(
|
||||
self,
|
||||
original_cache: dict,
|
||||
new_text_token: torch.Tensor,
|
||||
new_text: str,
|
||||
new_audio_feat: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
@@ -381,38 +534,42 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
Args:
|
||||
original_cache: original prompt cache
|
||||
new_text_token: newly generated text tokens
|
||||
new_text: newly generated text
|
||||
new_audio_feat: newly generated audio features
|
||||
|
||||
Returns:
|
||||
merged_cache: merged cache
|
||||
merged_cache: merged cache with prompt_text and audio_feat
|
||||
"""
|
||||
if original_cache is None:
|
||||
return {
|
||||
"text_token": new_text_token,
|
||||
"prompt_text": new_text,
|
||||
"audio_feat": new_audio_feat,
|
||||
}
|
||||
original_text_token = original_cache["text_token"]
|
||||
original_prompt_text = original_cache["prompt_text"]
|
||||
original_audio_feat = original_cache["audio_feat"]
|
||||
merged_text_token = torch.cat([original_text_token, new_text_token], dim=0)
|
||||
# Merge text by concatenation
|
||||
merged_prompt_text = original_prompt_text + new_text
|
||||
merged_audio_feat = torch.cat([original_audio_feat, new_audio_feat], dim=0)
|
||||
|
||||
# build new cache
|
||||
merged_cache = {
|
||||
"text_token": merged_text_token,
|
||||
"prompt_text": merged_prompt_text,
|
||||
"audio_feat": merged_audio_feat,
|
||||
}
|
||||
|
||||
return merged_cache
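
An illustrative view of the cache layout used here (the tensors are assumed to come from the cache-building helper, which is not shown in this excerpt):

original_cache = {
    "prompt_text": "Reference speaker says hello.",
    "audio_feat": prompt_audio_feat,   # (T, patch_size, latent_dim) features from the audio VAE
}
merged = model.merge_prompt_cache(
    original_cache,
    new_text=" And this is newly generated speech.",
    new_audio_feat=new_audio_feat,
)
# merged["prompt_text"] is the string concatenation; merged["audio_feat"] is concatenated along dim 0.
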
|
||||
|
||||
|
||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||
|
||||
|
||||
def generate_with_prompt_cache_streaming(
|
||||
self, *args, **kwargs
|
||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
|
||||
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _generate_with_prompt_cache(
|
||||
self,
|
||||
@@ -453,14 +610,14 @@ class VoxCPMModel(nn.Module):
|
||||
retry_badcase = False
|
||||
# get prompt from cache
|
||||
if prompt_cache is None:
|
||||
prompt_text_token = torch.empty(0, dtype=torch.int32)
|
||||
prompt_audio_feat = torch.empty((0, self.patch_size, self.audio_vae.latent_dim), dtype=torch.float32)
|
||||
text = target_text
|
||||
else:
|
||||
prompt_text_token = prompt_cache["text_token"]
|
||||
prompt_audio_feat = prompt_cache["audio_feat"]
|
||||
# build target text tokens
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
text_token = torch.cat([prompt_text_token, target_text_token], dim=0)
|
||||
prompt_text = prompt_cache["prompt_text"]
|
||||
text = prompt_text + target_text
|
||||
|
||||
text_token = torch.LongTensor(self.text_tokenizer(text))
|
||||
text_token = torch.cat(
|
||||
[
|
||||
text_token,
|
||||
@@ -472,6 +629,8 @@ class VoxCPMModel(nn.Module):
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
target_text_token = torch.LongTensor(self.text_tokenizer(target_text))
|
||||
|
||||
audio_length = prompt_audio_feat.size(0)
|
||||
text_length = text_token.shape[0]
|
||||
@@ -501,7 +660,7 @@ class VoxCPMModel(nn.Module):
|
||||
audio_feat,
|
||||
audio_mask,
|
||||
min_len=min_len,
|
||||
max_len=int(target_text_length * retry_badcase_ratio_threshold + 10) if retry_badcase else max_len,
|
||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
||||
inference_timesteps=inference_timesteps,
|
||||
cfg_value=cfg_value,
|
||||
streaming=streaming,
|
||||
@@ -530,7 +689,6 @@ class VoxCPMModel(nn.Module):
|
||||
break
|
||||
if not streaming:
|
||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
|
||||
decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio
|
||||
|
||||
yield (
|
||||
decode_audio,
|
||||
@@ -556,6 +714,7 @@ class VoxCPMModel(nn.Module):
|
||||
inference_timesteps: int = 10,
|
||||
cfg_value: float = 2.0,
|
||||
streaming: bool = False,
|
||||
streaming_prefix_len: int = 3,
|
||||
) -> Generator[Tuple[torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
|
||||
"""Core inference method for audio generation.
|
||||
|
||||
@@ -628,7 +787,7 @@ class VoxCPMModel(nn.Module):
|
||||
1, 2
|
||||
) # [b, p, d]
|
||||
|
||||
curr_embed = self.feat_encoder_step(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.feat_encoder(pred_feat.unsqueeze(1)) # b, 1, c
|
||||
curr_embed = self.enc_to_lm_proj(curr_embed)
|
||||
|
||||
pred_feat_seq.append(pred_feat.unsqueeze(1)) # b, 1, p, d
|
||||
@@ -636,8 +795,9 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
if streaming:
|
||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-3:], dim=1)
|
||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
|
||||
yield feat_pred, pred_feat_seq
|
||||
|
||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||
@@ -656,35 +816,138 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
if not streaming:
|
||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True):
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
|
||||
audio_vae = AudioVAE()
|
||||
audio_vae_config = getattr(config, 'audio_vae_config', None)
|
||||
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
||||
vae_state_dict = torch.load(
|
||||
os.path.join(path, "audiovae.pth"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
|
||||
model = cls(config, tokenizer, audio_vae)
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
else: # training mode
|
||||
for name, param in model.named_parameters():
|
||||
if "audio_vae" in name: # freeze VAE weights
|
||||
param.requires_grad = False
|
||||
continue
|
||||
if lora_config is not None:
|
||||
if "lora" not in name: # freeze non-LoRA weights
|
||||
param.requires_grad = False
|
||||
model.audio_vae = model.audio_vae.to(torch.float32)
|
||||
|
||||
model_state_dict = torch.load(
|
||||
os.path.join(path, "pytorch_model.bin"),
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)["state_dict"]
|
||||
|
||||
|
||||
# Try to load from safetensors first, fallback to pytorch_model.bin
|
||||
safetensors_path = os.path.join(path, "model.safetensors")
|
||||
pytorch_model_path = os.path.join(path, "pytorch_model.bin")
|
||||
|
||||
if os.path.exists(safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||
print(f"Loading model from safetensors: {safetensors_path}")
|
||||
model_state_dict = load_file(safetensors_path)
|
||||
elif os.path.exists(pytorch_model_path):
|
||||
print(f"Loading model from pytorch_model.bin: {pytorch_model_path}")
|
||||
checkpoint = torch.load(
|
||||
pytorch_model_path,
|
||||
map_location="cpu",
|
||||
weights_only=True,
|
||||
)
|
||||
model_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
|
||||
)
|
||||
|
||||
for kw, val in vae_state_dict.items():
|
||||
model_state_dict[f"audio_vae.{kw}"] = val
|
||||
model.load_state_dict(model_state_dict, strict=True)
|
||||
|
||||
# LoRALinear holds weight/bias directly, compatible with nn.Linear state_dict keys.
|
||||
# Using strict=False since pretrained weights don't contain lora_A/lora_B.
|
||||
model.load_state_dict(model_state_dict, strict=False)
|
||||
if training:
|
||||
return model
|
||||
return model.to(model.device).eval().optimize(disable=not optimize)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# LoRA Weight Management
|
||||
# ------------------------------------------------------------------ #
|
||||
def _iter_lora_modules(self):
|
||||
"""Iterate over all LoRA modules."""
|
||||
from ..modules.layers.lora import LoRALinear
|
||||
for module in self.modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
yield module
|
||||
|
||||
def load_lora_weights(self, lora_path: str, device: str = None):
|
||||
"""
|
||||
Load LoRA weights from file, supports calling after torch.compile.
|
||||
Uses named_parameters() to handle compile's _orig_mod wrapper.
|
||||
Supports both safetensors and pytorch formats.
|
||||
|
||||
Args:
|
||||
lora_path: Checkpoint path (directory or .safetensors/.ckpt file)
|
||||
device: Target device, defaults to model's current device
|
||||
Returns:
|
||||
tuple: (loaded_keys, skipped_keys)
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
device = device or self.device
|
||||
lora_path = Path(lora_path)
|
||||
|
||||
# Try safetensors first, then fallback to .ckpt
|
||||
if lora_path.is_dir():
|
||||
safetensors_file = lora_path / "lora_weights.safetensors"
|
||||
ckpt_file = lora_path / "lora_weights.ckpt"
|
||||
else:
|
||||
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
|
||||
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
|
||||
|
||||
# Load from safetensors if available
|
||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||
state_dict = load_file(str(safetensors_file), device=device)
|
||||
elif ckpt_file and ckpt_file.exists():
|
||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||
state_dict = ckpt.get("state_dict", ckpt)
|
||||
else:
|
||||
raise FileNotFoundError(
|
||||
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
|
||||
)
|
||||
|
||||
# Build param mapping (handle torch.compile's _orig_mod prefix)
|
||||
model_params = dict(self.named_parameters())
|
||||
key_mapping = {k.replace("._orig_mod.", "."): k for k in model_params if "._orig_mod." in k}
|
||||
|
||||
loaded_keys, skipped_keys = [], []
|
||||
for key, value in state_dict.items():
|
||||
target_key = key if key in model_params else key_mapping.get(key)
|
||||
if target_key:
|
||||
model_params[target_key].data.copy_(value.to(device))
|
||||
loaded_keys.append(key)
|
||||
else:
|
||||
skipped_keys.append(key)
|
||||
|
||||
return loaded_keys, skipped_keys
|
||||
|
||||
def set_lora_enabled(self, enabled: bool):
|
||||
"""Enable/disable all LoRA layers."""
|
||||
for module in self._iter_lora_modules():
|
||||
module.set_enabled(enabled)
|
||||
|
||||
def reset_lora_weights(self):
|
||||
"""Reset all LoRA weights (A: kaiming, B: zeros), effectively unloading LoRA."""
|
||||
for module in self._iter_lora_modules():
|
||||
module.reset_lora_parameters()
|
||||
|
||||
def get_lora_state_dict(self) -> dict:
|
||||
"""Get all LoRA parameters (lora_A/lora_B)."""
|
||||
return {name: param.data.clone()
|
||||
for name, param in self.named_parameters()
|
||||
if "lora_" in name}
|
||||
|
||||
@@ -1 +1 @@
|
||||
from .audio_vae import AudioVAE
|
||||
from .audio_vae import AudioVAE, AudioVAEConfig
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
from typing import List, Union, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def WNConv1d(*args, **kwargs):
|
||||
@@ -266,6 +267,17 @@ class CausalDecoder(nn.Module):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
class AudioVAEConfig(BaseModel):
|
||||
encoder_dim: int = 128
|
||||
encoder_rates: List[int] = [2, 5, 8, 8]
|
||||
latent_dim: int = 64
|
||||
decoder_dim: int = 1536
|
||||
decoder_rates: List[int] = [8, 8, 5, 2]
|
||||
depthwise: bool = True
|
||||
sample_rate: int = 16000
|
||||
use_noise_block: bool = False
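
For orientation (not part of the diff): with these defaults the encoder downsamples by 2*5*8*8 = 640 samples per latent frame, i.e. 16000 / 640 = 25 latent frames per second, which matches the audio_vae_fps=25 assumption used by the training data utilities later in this commit. A minimal construction sketch:

vae_cfg = AudioVAEConfig(latent_dim=64)   # override any field as needed
audio_vae = AudioVAE(config=vae_cfg)      # config-driven constructor added below
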
|
||||
|
||||
|
||||
class AudioVAE(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
@@ -273,17 +285,23 @@ class AudioVAE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int = 128,
|
||||
encoder_rates: List[int] = [2, 5, 8, 8],
|
||||
latent_dim: int = 64,
|
||||
decoder_dim: int = 1536,
|
||||
decoder_rates: List[int] = [8, 8, 5, 2],
|
||||
depthwise: bool = True,
|
||||
sample_rate: int = 16000,
|
||||
use_noise_block: bool = False,
|
||||
config: Optional[AudioVAEConfig] = None,
|
||||
):
|
||||
# If no config is passed in, use the default configuration
|
||||
if config is None:
|
||||
config = AudioVAEConfig()
|
||||
|
||||
super().__init__()
|
||||
|
||||
encoder_dim = config.encoder_dim
|
||||
encoder_rates = config.encoder_rates
|
||||
latent_dim = config.latent_dim
|
||||
decoder_dim = config.decoder_dim
|
||||
decoder_rates = config.decoder_rates
|
||||
depthwise = config.depthwise
|
||||
sample_rate = config.sample_rate
|
||||
use_noise_block = config.use_noise_block
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
self.encoder_rates = encoder_rates
|
||||
self.decoder_dim = decoder_dim
|
||||
|
||||
src/voxcpm/modules/layers/lora.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinear(nn.Module):
"""
LoRA linear layer: holds weight/bias directly, keeping the same state_dict key structure as nn.Linear.

state_dict layout:
- weight: original weight (same as nn.Linear)
- bias: original bias (same as nn.Linear)
- lora_A: LoRA low-rank matrix A
- lora_B: LoRA low-rank matrix B

The benefit of this design: no key remapping is needed when loading pretrained weights.
"""

def __init__(
self,
base: nn.Linear,
r: int,
alpha: float = 1.0,
dropout: float = 0.0,
):
super().__init__()
assert isinstance(base, nn.Linear), "LoRALinear only supports wrapping nn.Linear."

self.in_features = base.in_features
self.out_features = base.out_features
self.r = r
self.alpha = alpha
self._base_scaling = alpha / r if r > 0 else 0.0

# Store scaling in a buffer so that changing its value does not trigger torch.compile recompilation
# persistent=False means it is not saved to the state_dict, avoiding missing keys when loading
self.register_buffer("scaling", torch.tensor(self._base_scaling), persistent=False)

# Hold weight and bias directly (transferred from the original Linear)
self.weight = base.weight
self.bias = base.bias  # may be None

# LoRA parameters
if r > 0:
self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))
self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
else:
self.register_parameter("lora_A", None)
self.register_parameter("lora_B", None)

self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Base Linear computation
result = F.linear(x, self.weight, self.bias)
if self.r <= 0 or self.lora_A is None:
return result
# LoRA: result + dropout(x @ A^T @ B^T) * scaling
lora_out = F.linear(F.linear(x, self.lora_A), self.lora_B)
return result + self.dropout(lora_out) * self.scaling

def reset_lora_parameters(self):
"""Reset the LoRA parameters to their initial state."""
if self.r > 0 and self.lora_A is not None:
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)

def set_enabled(self, enabled: bool):
"""Enable/disable LoRA (controlled via scaling, compatible with torch.compile)."""
# Modify the buffer value in place with fill_; this does not trigger recompilation
self.scaling.fill_(self._base_scaling if enabled else 0.0)

@property
def enabled(self) -> bool:
return self.scaling.item() != 0.0


def _get_parent_module(root: nn.Module, name: str) -> Optional[nn.Module]:
"""
Given a full name such as 'layers.0.self_attn.q_proj', return the parent module (i.e., the module that owns q_proj).
"""
parts = name.split(".")
if len(parts) == 1:
return root
parent = root
for p in parts[:-1]:
if not hasattr(parent, p):
return None
parent = getattr(parent, p)
return parent


def apply_lora_to_named_linear_modules(
root: nn.Module,
*,
target_submodule_names: list[str],
r: int,
alpha: float,
dropout: float,
) -> None:
"""
Within the given module and its submodules, inject LoRA into every Linear layer whose name ends with one of target_submodule_names.

For example, with target_submodule_names=["q_proj", "v_proj"],
every nn.Linear named *.q_proj / *.v_proj is replaced with a LoRALinear.
"""
for full_name, module in list(root.named_modules()):
if not isinstance(module, nn.Linear):
continue
short_name = full_name.split(".")[-1]
if short_name not in target_submodule_names:
continue

parent = _get_parent_module(root, full_name)
if parent is None:
continue

# Replace the original Linear with LoRALinear
lora_layer = LoRALinear(
base=module,
r=r,
alpha=alpha,
dropout=dropout,
)
setattr(parent, short_name, lora_layer)

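A minimal, self-contained demo of the two helpers above (not part of the diff; the import path follows the new file location):

import torch
import torch.nn as nn
from voxcpm.modules.layers.lora import LoRALinear, apply_lora_to_named_linear_modules

class ToyAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(32, 32)
        self.out_proj = nn.Linear(32, 32)

toy = ToyAttn()
apply_lora_to_named_linear_modules(toy, target_submodule_names=["q_proj"], r=4, alpha=8, dropout=0.0)
assert isinstance(toy.q_proj, LoRALinear) and isinstance(toy.out_proj, nn.Linear)
y = toy.q_proj(torch.randn(2, 32))  # equals the base projection at init, since lora_B starts at zero
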
@@ -1,20 +1,29 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from typing import List
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
import math
|
||||
import torch.nn.functional as F
|
||||
from torch.func import jvp
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
|
||||
|
||||
class CfmConfig(BaseModel):
|
||||
sigma_min: float = 1e-06
|
||||
sigma_min: float = 1e-6
|
||||
solver: str = "euler"
|
||||
t_scheduler: str = "log-norm"
|
||||
training_cfg_rate: float = 0.1
|
||||
inference_cfg_rate: float = 1.0
|
||||
reg_loss_type: str = "l1"
|
||||
ratio_r_neq_t_range: Tuple[float, float] = (0.25, 0.75)
|
||||
noise_cond_prob_range: Tuple[float, float] = (0.0, 0.0)
|
||||
noise_cond_scale: float = 0.0
|
||||
|
||||
|
||||
class UnifiedCFM(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
in_channels: int,
|
||||
cfm_params: CfmConfig,
|
||||
estimator: VoxCPMLocDiT,
|
||||
mean_mode: bool = False,
|
||||
@@ -23,12 +32,21 @@ class UnifiedCFM(torch.nn.Module):
|
||||
self.solver = cfm_params.solver
|
||||
self.sigma_min = cfm_params.sigma_min
|
||||
self.t_scheduler = cfm_params.t_scheduler
|
||||
self.training_cfg_rate = cfm_params.training_cfg_rate
|
||||
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
||||
self.reg_loss_type = cfm_params.reg_loss_type
|
||||
self.ratio_r_neq_t_range = cfm_params.ratio_r_neq_t_range
|
||||
self.noise_cond_prob_range = cfm_params.noise_cond_prob_range
|
||||
self.noise_cond_scale = cfm_params.noise_cond_scale
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mean_mode = mean_mode
|
||||
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Inference
|
||||
# ------------------------------------------------------------------ #
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
@@ -41,33 +59,25 @@ class UnifiedCFM(torch.nn.Module):
|
||||
sway_sampling_coef: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats)
|
||||
n_timesteps (int): number of diffusion steps
|
||||
cond: Not used but kept for future purposes
|
||||
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
b, c = mu.shape
|
||||
b, _ = mu.shape
|
||||
t = patch_size
|
||||
z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature
|
||||
|
||||
t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
# Sway sampling strategy
|
||||
t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
|
||||
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)
|
||||
return self.solve_euler(
|
||||
x=z,
|
||||
t_span=t_span,
|
||||
mu=mu,
|
||||
cond=cond,
|
||||
cfg_value=cfg_value,
|
||||
use_cfg_zero_star=use_cfg_zero_star,
|
||||
)
|
||||
|
||||
def optimized_scale(self, positive_flat, negative_flat):
|
||||
def optimized_scale(self, positive_flat: torch.Tensor, negative_flat: torch.Tensor):
|
||||
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
||||
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
||||
|
||||
squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
|
||||
st_star = dot_product / squared_norm
|
||||
return st_star
|
||||
|
||||
@@ -80,24 +90,13 @@ class UnifiedCFM(torch.nn.Module):
|
||||
cfg_value: float = 1.0,
|
||||
use_cfg_zero_star: bool = True,
|
||||
):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
x (torch.Tensor): random noise
|
||||
t_span (torch.Tensor): n_timesteps interpolated
|
||||
shape: (n_timesteps + 1,)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats)
|
||||
cond: condition -- prefix prompt
|
||||
cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]
|
||||
|
||||
sol = []
|
||||
zero_init_steps = max(1, int(len(t_span) * 0.04))
|
||||
for step in range(1, len(t_span)):
|
||||
if use_cfg_zero_star and step <= zero_init_steps:
|
||||
dphi_dt = 0.
|
||||
dphi_dt = torch.zeros_like(x)
|
||||
else:
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
b = x.size(0)
|
||||
@@ -105,7 +104,7 @@ class UnifiedCFM(torch.nn.Module):
|
||||
mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
|
||||
t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
|
||||
dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
|
||||
cond_in = torch.zeros([2 * b, self.in_channels, cond.size(2)], device=x.device, dtype=x.dtype)
|
||||
x_in[:b], x_in[b:] = x, x
|
||||
mu_in[:b] = mu
|
||||
t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
|
||||
@@ -135,3 +134,98 @@ class UnifiedCFM(torch.nn.Module):
|
||||
dt = t - t_span[step + 1]
|
||||
|
||||
return sol[-1]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Training loss
|
||||
# ------------------------------------------------------------------ #
|
||||
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
|
||||
weights = 1.0 / ((losses + epsilon).pow(p))
|
||||
if mask is not None:
|
||||
weights = weights * mask
|
||||
return weights.detach()
|
||||
|
||||
def sample_r_t(self, x: torch.Tensor, mu: float = -0.4, sigma: float = 1.0, ratio_r_neq_t: float = 0.0):
|
||||
batch_size = x.shape[0]
|
||||
if self.t_scheduler == "log-norm":
|
||||
s_r = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
|
||||
s_t = torch.randn(batch_size, device=x.device, dtype=x.dtype) * sigma + mu
|
||||
r = torch.sigmoid(s_r)
|
||||
t = torch.sigmoid(s_t)
|
||||
elif self.t_scheduler == "uniform":
|
||||
r = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
t = torch.rand(batch_size, device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported t_scheduler: {self.t_scheduler}")
|
||||
|
||||
mask = torch.rand(batch_size, device=x.device, dtype=x.dtype) < ratio_r_neq_t
|
||||
r, t = torch.where(
|
||||
mask,
|
||||
torch.stack([torch.min(r, t), torch.max(r, t)], dim=0),
|
||||
torch.stack([t, t], dim=0),
|
||||
)
|
||||
|
||||
return r.squeeze(), t.squeeze()
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
x1: torch.Tensor,
|
||||
mu: torch.Tensor,
|
||||
cond: torch.Tensor | None = None,
|
||||
tgt_mask: torch.Tensor | None = None,
|
||||
progress: float = 0.0,
|
||||
):
|
||||
b, _, _ = x1.shape
|
||||
|
||||
if self.training_cfg_rate > 0:
|
||||
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
||||
mu = mu * cfg_mask.view(-1, 1)
|
||||
|
||||
if cond is None:
|
||||
cond = torch.zeros_like(x1)
|
||||
|
||||
noisy_mask = torch.rand(b, device=x1.device) > (
|
||||
1.0
|
||||
- (
|
||||
self.noise_cond_prob_range[0]
|
||||
+ progress * (self.noise_cond_prob_range[1] - self.noise_cond_prob_range[0])
|
||||
)
|
||||
)
|
||||
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
|
||||
|
||||
ratio_r_neq_t = (
|
||||
self.ratio_r_neq_t_range[0]
|
||||
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||
if self.mean_mode
|
||||
else 0.0
|
||||
)
|
||||
|
||||
r, t = self.sample_r_t(x1, ratio_r_neq_t=ratio_r_neq_t)
|
||||
r_ = r.detach().clone()
|
||||
t_ = t.detach().clone()
|
||||
z = torch.randn_like(x1)
|
||||
y = (1 - t_.view(-1, 1, 1)) * x1 + t_.view(-1, 1, 1) * z
|
||||
v = z - x1
|
||||
|
||||
def model_fn(z_sample, r_sample, t_sample):
|
||||
return self.estimator(z_sample, mu, t_sample, cond, dt=t_sample - r_sample)
|
||||
|
||||
if self.mean_mode:
|
||||
v_r = torch.zeros_like(r)
|
||||
v_t = torch.ones_like(t)
|
||||
from torch.backends.cuda import sdp_kernel
|
||||
|
||||
with sdp_kernel(enable_flash=False, enable_mem_efficient=False):
|
||||
u_pred, dudt = jvp(model_fn, (y, r, t), (v, v_r, v_t))
|
||||
u_tgt = v - (t_ - r_).view(-1, 1, 1) * dudt
|
||||
else:
|
||||
u_pred = model_fn(y, r, t)
|
||||
u_tgt = v
|
||||
|
||||
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
|
||||
if tgt_mask is not None:
|
||||
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
|
||||
loss = (weights * losses).sum() / torch.sum(tgt_mask)
|
||||
else:
|
||||
loss = losses.mean()
|
||||
|
||||
return loss
|
||||
|
||||
src/voxcpm/training/__init__.py (new file, 28 lines)
@@ -0,0 +1,28 @@
"""
Training utilities for VoxCPM fine-tuning.

This package mirrors the training mechanics used in the minicpm-audio
tooling while relying solely on local audio-text datasets managed via
the HuggingFace ``datasets`` library.
"""

from .accelerator import Accelerator
from .tracker import TrainingTracker
from .data import (
load_audio_text_datasets,
HFVoxCPMDataset,
build_dataloader,
BatchProcessor,
)
from .state import TrainingState

__all__ = [
"Accelerator",
"TrainingTracker",
"HFVoxCPMDataset",
"BatchProcessor",
"TrainingState",
"load_audio_text_datasets",
"build_dataloader",
]

src/voxcpm/training/accelerator.py (new file, 166 lines)
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
import typing
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
|
||||
|
||||
class Accelerator:
|
||||
"""
|
||||
Simplified accelerator that mirrors the behaviour of the minicpm-audio
|
||||
training utilities. It initializes a distributed process group when
|
||||
``torchrun`` is used and exposes helpers for AMP, gradient scaling and
|
||||
preparing models/dataloaders for DDP.
|
||||
"""
|
||||
|
||||
def __init__(self, amp: bool = False, seed: int = 42):
|
||||
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
if self.world_size > 1 and not dist.is_initialized():
|
||||
dist.init_process_group("nccl", init_method="env://")
|
||||
|
||||
self.rank = dist.get_rank() if dist.is_initialized() else 0
|
||||
self.local_rank = int(os.environ.get("LOCAL_RANK", "0"))
|
||||
self.amp = amp
|
||||
|
||||
# Set random seed to ensure model initialization consistency
|
||||
self._set_seed(seed)
|
||||
|
||||
class DummyScaler:
|
||||
def step(self, optimizer):
|
||||
optimizer.step()
|
||||
|
||||
def scale(self, loss):
|
||||
return loss
|
||||
|
||||
def unscale_(self, optimizer):
|
||||
return optimizer
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
|
||||
self.device_ctx = (
|
||||
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||
)
|
||||
self._ddp_model = None # For no_sync support
|
||||
|
||||
def _set_seed(self, seed: int):
|
||||
"""Set random seed to ensure model initialization consistency across multiple GPUs"""
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def __enter__(self):
|
||||
if self.device_ctx is not None:
|
||||
self.device_ctx.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if self.device_ctx is not None:
|
||||
self.device_ctx.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
def barrier(self):
|
||||
"""Synchronize all processes"""
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
|
||||
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.AVG):
|
||||
"""All-reduce tensor across processes"""
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(tensor, op=op)
|
||||
return tensor
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Model helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
||||
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
|
||||
model.device = self.device
|
||||
model = model.to(self.device)
|
||||
if self.world_size > 1:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
model = DistributedDataParallel(model, device_ids=[self.local_rank], **kwargs)
|
||||
self._ddp_model = model # Save DDP model reference for no_sync support
|
||||
return model
|
||||
|
||||
@contextlib.contextmanager
|
||||
def no_sync(self):
|
||||
"""
|
||||
Context manager to skip gradient synchronization during gradient accumulation.
|
||||
Only used outside the last micro-batch.
|
||||
"""
|
||||
if self._ddp_model is not None:
|
||||
with self._ddp_model.no_sync():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda", self.local_rank)
|
||||
if torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
return torch.device("cpu")
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# AMP helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def autocast(self, *args, **kwargs):
|
||||
return torch.amp.autocast("cuda", enabled=self.amp, *args, **kwargs)
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
self.scaler.scale(loss).backward()
|
||||
|
||||
def step(self, optimizer: torch.optim.Optimizer):
|
||||
self.scaler.step(optimizer)
|
||||
|
||||
def update(self):
|
||||
self.scaler.update()
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Data helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
def prepare_dataloader(
|
||||
self,
|
||||
dataset: typing.Iterable,
|
||||
*,
|
||||
batch_size: int,
|
||||
num_workers: int = 0,
|
||||
shuffle: bool = True,
|
||||
collate_fn=None,
|
||||
drop_last: bool = False,
|
||||
) -> torch.utils.data.DataLoader:
|
||||
if self.world_size > 1:
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle
|
||||
)
|
||||
shuffle = False
|
||||
else:
|
||||
sampler = None
|
||||
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle if sampler is None else False,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
drop_last=drop_last,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
|
||||
return model.module if hasattr(model, "module") else model
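
A hedged sketch of the intended training-step flow with this class (model, dataset, and collate_fn are placeholders; a real loop would also use no_sync() during gradient accumulation):

acc = Accelerator(amp=True)
model = acc.prepare_model(model)    # move to device, wrap in DDP when WORLD_SIZE > 1
loader = acc.prepare_dataloader(dataset, batch_size=8, num_workers=4, collate_fn=collate_fn)
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)

for batch in loader:
    with acc.autocast():
        losses = model(**batch)                 # illustrative; see VoxCPMModel.forward above
        loss = losses["loss/diff"] + losses["loss/stop"]
    acc.backward(loss)                          # scaler.scale(loss).backward()
    acc.step(optimizer)                         # scaler.step(optimizer)
    acc.update()                                # scaler.update()
    optimizer.zero_grad(set_to_none=True)
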
|
||||
|
||||
src/voxcpm/training/config.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from __future__ import annotations

import argbind
import yaml
from pathlib import Path
from typing import Dict, Any


def load_yaml_config(path: str | Path) -> Dict[str, Any]:
"""
Load a YAML configuration file into a dictionary suitable for argbind.
"""
path = Path(path)
with path.open("r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ValueError(f"Configuration file {path} must contain a top-level mapping.")
return data


def parse_args_with_config(config_path: str | Path | None = None):
"""
Helper to unify CLI arguments and YAML configuration.

Usage mirrors minicpm-audio:
args = parse_args_with_config("conf/voxcpm/finetune.yml")
with argbind.scope(args):
...
"""
cli_args = argbind.parse_args()
if config_path is None:
return cli_args

yaml_args = load_yaml_config(config_path)
with argbind.scope(cli_args):
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
cli_args.update(yaml_args)
return cli_args

src/voxcpm/training/data.py (new file, 214 lines)
@@ -0,0 +1,214 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import argbind
|
||||
import torch
|
||||
from datasets import Audio, Dataset, DatasetDict, load_dataset
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
|
||||
from ..model.voxcpm import VoxCPMConfig
|
||||
from ..modules.audiovae import AudioVAE
|
||||
from .packers import AudioFeatureProcessingPacker
|
||||
|
||||
|
||||
DEFAULT_TEXT_COLUMN = "text"
|
||||
DEFAULT_AUDIO_COLUMN = "audio"
|
||||
DEFAULT_ID_COLUMN = "dataset_id"
|
||||
|
||||
|
||||
@argbind.bind()
|
||||
def load_audio_text_datasets(
|
||||
train_manifest: str,
|
||||
val_manifest: str = "",
|
||||
text_column: str = DEFAULT_TEXT_COLUMN,
|
||||
audio_column: str = DEFAULT_AUDIO_COLUMN,
|
||||
dataset_id_column: str = DEFAULT_ID_COLUMN,
|
||||
sample_rate: int = 16_000,
|
||||
num_proc: int = 1,
|
||||
) -> Tuple[Dataset, Optional[Dataset]]:
|
||||
data_files = {"train": train_manifest}
|
||||
if val_manifest:
|
||||
data_files["validation"] = val_manifest
|
||||
|
||||
dataset_dict: DatasetDict = load_dataset("json", data_files=data_files)
|
||||
|
||||
def prepare(ds: Dataset) -> Dataset:
|
||||
if audio_column not in ds.column_names:
|
||||
raise ValueError(f"Expected '{audio_column}' column in manifest.")
|
||||
# We cast to Audio to ensure proper handling during training,
|
||||
# but for length calculation we might need raw path or duration if available.
|
||||
# HF datasets usually don't compute duration automatically for 'Audio' column.
|
||||
ds = ds.cast_column(audio_column, Audio(sampling_rate=sample_rate))
|
||||
if audio_column != DEFAULT_AUDIO_COLUMN:
|
||||
ds = ds.rename_column(audio_column, DEFAULT_AUDIO_COLUMN)
|
||||
if text_column != DEFAULT_TEXT_COLUMN:
|
||||
ds = ds.rename_column(text_column, DEFAULT_TEXT_COLUMN)
|
||||
if dataset_id_column and dataset_id_column in ds.column_names:
|
||||
if dataset_id_column != DEFAULT_ID_COLUMN:
|
||||
ds = ds.rename_column(dataset_id_column, DEFAULT_ID_COLUMN)
|
||||
else:
|
||||
ds = ds.add_column(DEFAULT_ID_COLUMN, [0] * len(ds))
|
||||
return ds
|
||||
|
||||
train_ds = prepare(dataset_dict["train"])
|
||||
val_ds = prepare(dataset_dict["validation"]) if "validation" in dataset_dict else None
|
||||
return train_ds, val_ds
|
||||
|
||||
|
||||
def compute_sample_lengths(
|
||||
ds: Dataset,
|
||||
audio_vae_fps: int = 25,
|
||||
patch_size: int = 1,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Estimate the approximate post-packer sequence length (text + audio) of each sample,
used to filter out overly long samples.

The logic matches AudioFeatureProcessingPacker / AudioVAE:
- text length: len(text_ids)
- audio length:
duration(s) * audio_vae_fps -> approximate number of VAE frames t_vae
t_seq = ceil(t_vae / patch_size)
- total sequence length is roughly: text_len + t_seq + 2
|
||||
"""
|
||||
lengths: List[int] = []
|
||||
|
||||
has_duration = "duration" in ds.column_names
|
||||
|
||||
for i in range(len(ds)):
|
||||
item = ds[i]
|
||||
text_len = len(item["text_ids"])
|
||||
|
||||
# Audio duration (avoid decoding where possible; prefer the duration column if the manifest already provides one)
|
||||
if has_duration:
|
||||
duration = float(item["duration"])
|
||||
else:
|
||||
audio = item[DEFAULT_AUDIO_COLUMN]
|
||||
duration = len(audio["array"]) / float(audio["sampling_rate"])
|
||||
|
||||
t_vae = math.ceil(duration * audio_vae_fps)
|
||||
t_seq = math.ceil(t_vae / patch_size)
|
||||
|
||||
total_len = text_len + t_seq + 2
|
||||
lengths.append(total_len)
|
||||
|
||||
return lengths
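
A quick worked example of the estimate above (illustrative numbers only):

import math
duration, audio_vae_fps, patch_size, text_len = 3.2, 25, 2, 10
t_vae = math.ceil(duration * audio_vae_fps)   # 80 VAE frames
t_seq = math.ceil(t_vae / patch_size)         # 40 audio positions
assert text_len + t_seq + 2 == 52             # approximate sequence length for this sample
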
|
||||
|
||||
|
||||
class HFVoxCPMDataset(TorchDataset):
    """
    Thin wrapper around a tokenized HuggingFace dataset that returns
    PyTorch-friendly samples.
    """

    def __init__(self, dataset: Dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx: int):
        item = self.dataset[idx]
        audio = item[DEFAULT_AUDIO_COLUMN]
        return {
            "text_ids": item["text_ids"],
            "audio_array": audio["array"],
            "audio_sampling_rate": audio["sampling_rate"],
            "dataset_id": item.get(DEFAULT_ID_COLUMN, 0),
            "is_prompt": item.get("is_prompt", False),
        }

    @staticmethod
    def pad_sequences(seqs: List[torch.Tensor], pad_value: float):
        if not seqs:
            return torch.empty(0)
        max_len = max(seq.shape[0] for seq in seqs)
        padded = []
        for seq in seqs:
            if seq.shape[0] < max_len:
                pad_width = (0, max_len - seq.shape[0])
                seq = torch.nn.functional.pad(seq, pad_width, value=pad_value)
            padded.append(seq)
        return torch.stack(padded)

    @classmethod
    def collate_fn(cls, batch: List[Dict]):
        text_tensors = [torch.tensor(sample["text_ids"], dtype=torch.int32) for sample in batch]
        audio_tensors = [torch.tensor(sample["audio_array"], dtype=torch.float32) for sample in batch]
        dataset_ids = torch.tensor([sample["dataset_id"] for sample in batch], dtype=torch.int32)
        is_prompts = [bool(sample.get("is_prompt", False)) for sample in batch]

        text_padded = cls.pad_sequences(text_tensors, pad_value=-100)
        audio_padded = cls.pad_sequences(audio_tensors, pad_value=-100.0)
        task_ids = torch.ones(text_padded.size(0), dtype=torch.int32)

        return {
            "text_tokens": text_padded,
            "audio_tokens": audio_padded,
            "task_ids": task_ids,
            "dataset_ids": dataset_ids,
            "is_prompts": is_prompts,
        }
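A small sketch of what `collate_fn` produces for two uneven samples (the values are illustrative; the -100 padding is what the packer later strips again with its unpad helpers):

```python
batch = [
    {"text_ids": [5, 6, 7], "audio_array": [0.0] * 16000, "audio_sampling_rate": 16000, "dataset_id": 0},
    {"text_ids": [8, 9], "audio_array": [0.0] * 8000, "audio_sampling_rate": 16000, "dataset_id": 1},
]
out = HFVoxCPMDataset.collate_fn(batch)
print(out["text_tokens"].shape)   # torch.Size([2, 3])     -- shorter text padded with -100
print(out["audio_tokens"].shape)  # torch.Size([2, 16000]) -- shorter waveform padded with -100.0
print(out["task_ids"])            # tensor([1, 1], dtype=torch.int32)
```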
class BatchProcessor:
    """
    Wraps ``AudioFeatureProcessingPacker`` so the training loop can mirror
    the minicpm-audio mechanics.
    """

    def __init__(
        self,
        *,
        config: VoxCPMConfig,
        audio_vae: AudioVAE,
        dataset_cnt: int,
        device: torch.device,
    ):
        self.device = device
        self.dataset_cnt = dataset_cnt
        self.audio_vae = audio_vae
        self.audio_vae.to(device)
        self.packer = AudioFeatureProcessingPacker(
            dataset_cnt=dataset_cnt,
            max_len=config.max_length,
            patch_size=config.patch_size,
            feat_dim=config.feat_dim,
            audio_vae=self.audio_vae,
        )

    def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        audio_tokens = batch["audio_tokens"].to(self.device)
        text_tokens = batch["text_tokens"].to(self.device)
        task_ids = batch["task_ids"].to(self.device)
        dataset_ids = batch["dataset_ids"].to(self.device)

        packed = self.packer(
            audio_tokens=audio_tokens,
            text_tokens=text_tokens,
            task_ids=task_ids,
            dataset_ids=dataset_ids,
            is_prompts=batch["is_prompts"],
        )
        return packed
def build_dataloader(
    hf_dataset: Dataset,
    *,
    accelerator,
    batch_size: int,
    num_workers: int,
    drop_last: bool = False,
) -> torch.utils.data.DataLoader:
    torch_dataset = HFVoxCPMDataset(hf_dataset)
    # Standard padding-based batching; Accelerator will attach DistributedSampler if needed.
    return accelerator.prepare_dataloader(
        torch_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        collate_fn=HFVoxCPMDataset.collate_fn,
        drop_last=drop_last,
    )
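Putting the pieces of this module together, a hedged sketch of the intended flow; the `tokenizer`, `accelerator`, `config` (a VoxCPMConfig) and `audio_vae` objects are assumptions supplied by the surrounding training script, while the functions themselves are the ones defined above:

```python
# Hedged end-to-end sketch: manifest -> HF dataset -> dataloader -> packed batch.
train_ds, _ = load_audio_text_datasets(train_manifest="train_manifest.jsonl")
train_ds = train_ds.map(lambda ex: {"text_ids": tokenizer.encode(ex["text"])})  # assumed tokenizer API

train_loader = build_dataloader(
    train_ds,
    accelerator=accelerator,   # assumed accelerator wrapper
    batch_size=8,
    num_workers=2,
)
batch_processor = BatchProcessor(
    config=config,             # assumed VoxCPMConfig instance
    audio_vae=audio_vae,       # assumed preloaded AudioVAE
    dataset_cnt=1,
    device=torch.device("cuda"),
)
for batch in train_loader:
    packed = batch_processor(batch)  # text_tokens, audio_feats, masks, labels, ...
    break
```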
src/voxcpm/training/packers.py (new file, 289 lines)
@@ -0,0 +1,289 @@
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from einops import rearrange


class AudioFeatureProcessingPacker:
    """
    Adapted from the minicpm-audio training utilities. It converts raw text and
    audio tokens into the packed multimodal representation required by VoxCPM.
    """

    def __init__(self, dataset_cnt: int, max_len: int, patch_size: int, feat_dim: int, audio_vae: nn.Module):
        self.audio_start_id = 101
        self.audio_end_id = 102
        # only used for prompt segments (is_prompt=True)
        self.audio_prompt_start_id = 103
        self.audio_prompt_end_id = 104
        self.text_eos_token_id = 2

        self.patch_size = patch_size
        self.patch_len = audio_vae.hop_length * self.patch_size
        self.feat_dim = feat_dim
        self.dataset_cnt = max(dataset_cnt, 1)
        self.max_len = max_len

        self.audio_vae = audio_vae

        self.process_functions = {"tts": self.process_tts_data}
        self.task_id_map = {"tts": 1}
        self.id_to_task = {idx: usage for usage, idx in self.task_id_map.items()}

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _first_pad_position(tokens: torch.Tensor):
        positions = (tokens == -100).nonzero(as_tuple=True)
        if positions[0].numel() == 0:
            return None
        return int(positions[0][0])

    def unpad_text_tokens(self, tokens: torch.Tensor):
        pad_pos = self._first_pad_position(tokens)
        return tokens if pad_pos is None else tokens[:pad_pos]

    def unpad_audio_tokens(self, tokens: torch.Tensor):
        pad_pos = self._first_pad_position(tokens)
        return tokens if pad_pos is None else tokens[:pad_pos]

    def encode_audio(self, wav: torch.Tensor):
        """
        Encode a raw waveform into latent features using AudioVAE.

        AudioVAE.encode expects shape [B, 1, T'] and returns [B, D, T].
        We then transpose to [B, T, D] to match downstream expectations.
        """
        wav = wav.unsqueeze(0)  # [1, T]
        wav = wav.unsqueeze(1)  # [1, 1, T]
        wav_len = wav.size(-1)
        if wav_len % self.patch_len != 0:
            padding_size = self.patch_len - wav_len % self.patch_len
            wav = torch.nn.functional.pad(wav, (0, padding_size))

        with torch.no_grad():
            z = self.audio_vae.encode(wav, self.audio_vae.sample_rate)  # [1, D, T']
        feat = z.transpose(1, 2)  # [1, T', D]
        return feat
    # ------------------------------------------------------------------ #
    # Main entry point
    # ------------------------------------------------------------------ #
    def __call__(
        self,
        audio_tokens: torch.Tensor,
        text_tokens: torch.Tensor,
        task_ids: torch.Tensor,
        dataset_ids: torch.Tensor,
        is_prompts: List[bool],
    ) -> Dict[str, torch.Tensor]:
        """
        Padding-based batching: each sample in the input batch is processed
        independently and then padded to a common length (capped by ``max_len``).
        The result tensors all have shape [B, T, ...].
        """
        device = audio_tokens.device
        max_dataset_id = int(dataset_ids.max().item()) if dataset_ids.numel() > 0 else -1
        dataset_cnt = max(self.dataset_cnt, max_dataset_id + 1)

        text_tokens_list: List[torch.Tensor] = []
        audio_feats_list: List[torch.Tensor] = []
        text_mask_list: List[torch.Tensor] = []
        audio_mask_list: List[torch.Tensor] = []
        loss_mask_list: List[torch.Tensor] = []
        labels_list: List[torch.Tensor] = []
        audio_task_ids_list: List[torch.Tensor] = []
        audio_dataset_ids_list: List[torch.Tensor] = []
        lengths: List[int] = []

        audio_duration_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)
        text_token_consumed = torch.zeros(dataset_cnt, dtype=torch.float32, device=device)

        for audio_token, text_token, task_id, dataset_idx, is_prompt in zip(
            audio_tokens, text_tokens, task_ids.tolist(), dataset_ids.tolist(), is_prompts
        ):
            unpad_audio_token = self.unpad_audio_tokens(audio_token).to(torch.float32)
            unpad_text_token = self.unpad_text_tokens(text_token)
            usage = self.id_to_task[task_id]

            (
                packed_text,
                audio_feat,
                text_mask,
                audio_mask,
                loss_mask,
                labels,
                audio_duration,
                text_token_count,
            ) = self.process_functions[usage](unpad_audio_token, unpad_text_token, is_prompt)

            audio_duration_consumed[dataset_idx] += audio_duration
            text_token_consumed[dataset_idx] += text_token_count

            audio_task_id = torch.zeros_like(audio_mask)
            audio_task_id[audio_mask == 1] = self.task_id_map[usage]

            audio_dataset_id = torch.zeros_like(audio_mask)
            audio_dataset_id[audio_mask == 1] = dataset_idx + 1

            text_tokens_list.append(packed_text)
            text_mask_list.append(text_mask)
            audio_feats_list.append(audio_feat)
            audio_mask_list.append(audio_mask)
            loss_mask_list.append(loss_mask)
            labels_list.append(labels)
            audio_task_ids_list.append(audio_task_id)
            audio_dataset_ids_list.append(audio_dataset_id)
            lengths.append(packed_text.shape[0])

        # Determine padded length per batch (cap by self.max_len)
        if lengths:
            max_len = min(self.max_len, max(lengths))
        else:
            max_len = self.max_len

        def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
            if x.size(0) >= max_len:
                return x[:max_len]
            pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
            return torch.cat([x, pad], dim=0)

        def pad_3d(x: torch.Tensor) -> torch.Tensor:
            # x: [T, P, D]
            if x.size(0) >= max_len:
                return x[:max_len]
            pad = torch.zeros(
                (max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
            )
            return torch.cat([x, pad], dim=0)

        if lengths:
            text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
            text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
            audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
            audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
            loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
            labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
            audio_task_ids_batch = torch.stack(
                [pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
            )
            audio_dataset_ids_batch = torch.stack(
                [pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
            )

            # Position ids: [B, T], simple 0..L_i-1 then padded with 0
            position_ids_list = []
            for L in lengths:
                L_clip = min(L, max_len)
                pos = torch.arange(0, L_clip, device=device)
                if L_clip < max_len:
                    pad = torch.zeros(max_len - L_clip, dtype=pos.dtype, device=device)
                    pos = torch.cat([pos, pad], dim=0)
                position_ids_list.append(pos)
            position_ids = torch.stack(position_ids_list, dim=0)
        else:
            # Empty batch fallback (shouldn't really happen)
            text_tokens_batch = torch.zeros((0, self.max_len), dtype=torch.int32, device=device)
            text_mask_batch = torch.zeros_like(text_tokens_batch)
            audio_feats_batch = torch.zeros(
                (0, self.max_len, self.patch_size, self.feat_dim), dtype=torch.float32, device=device
            )
            audio_mask_batch = torch.zeros_like(text_tokens_batch)
            loss_mask_batch = torch.zeros_like(text_tokens_batch)
            labels_batch = torch.zeros_like(text_tokens_batch)
            audio_task_ids_batch = torch.zeros_like(text_tokens_batch)
            audio_dataset_ids_batch = torch.zeros_like(text_tokens_batch)
            position_ids = torch.zeros_like(text_tokens_batch)

        audio_duration_consumed = audio_duration_consumed.to(torch.long)
        text_token_consumed = text_token_consumed.to(torch.long)

        return {
            "text_tokens": text_tokens_batch,
            "audio_feats": audio_feats_batch,
            "text_mask": text_mask_batch,
            "audio_mask": audio_mask_batch,
            "loss_mask": loss_mask_batch,
            "position_ids": position_ids,
            "labels": labels_batch,
            "audio_task_ids": audio_task_ids_batch,
            "audio_dataset_ids": audio_dataset_ids_batch,
            "audio_duration_consumed": audio_duration_consumed,
            "text_token_consumed": text_token_consumed,
        }
    # ------------------------------------------------------------------ #
    # Feature extraction helpers
    # ------------------------------------------------------------------ #
    def extract_audio_feats(self, audio_data: torch.Tensor):
        audio_feats = self.encode_audio(audio_data)
        if audio_feats.size(1) % self.patch_size != 0:
            audio_feats_ = audio_feats.transpose(1, 2)
            padding = nn.functional.pad(audio_feats_, (0, self.patch_size - audio_feats.size(1) % self.patch_size))
            audio_feats = padding.transpose(1, 2)

        # 25 latent frames per second (the AudioVAE frame rate)
        audio_duration = audio_feats.size(1) / 25
        audio_feats = rearrange(audio_feats, "b (t p) c -> b t p c", p=self.patch_size)
        return audio_feats, audio_duration

    def process_tts_data(self, audio_token: torch.Tensor, text_token: torch.Tensor, is_prompt: bool = False):
        text_token_info = torch.cat(
            [
                text_token,
                torch.tensor(
                    [self.audio_prompt_start_id if is_prompt else self.audio_start_id],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ],
            dim=-1,
        )
        text_token_count = len(text_token)
        text_length = text_token_info.shape[0]
        audio_feat_info, audio_duration = self.extract_audio_feats(audio_token)
        audio_feat_info = audio_feat_info.squeeze(0)
        audio_length = audio_feat_info.shape[0]

        text_pad_token = torch.zeros(audio_length, dtype=torch.int32, device=text_token.device)
        text_token_info = torch.cat(
            [
                text_token_info,
                text_pad_token,
                torch.tensor(
                    [self.audio_prompt_end_id if is_prompt else self.audio_end_id],
                    dtype=torch.int32,
                    device=text_token.device,
                ),
            ]
        )
        audio_pad_feat = torch.zeros(
            (text_length, self.patch_size, audio_feat_info.size(-1)),
            dtype=torch.float32,
            device=text_token.device,
        )
        audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)

        text_mask = torch.cat(
            [torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]
        ).type(torch.int32).to(text_token.device)
        audio_mask = torch.cat(
            [torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]
        ).type(torch.int32).to(text_token.device)
        loss_mask = torch.cat(
            [
                torch.zeros(text_length),
                torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
                torch.zeros(1),
            ]
        ).type(torch.int32).to(text_token.device)

        labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
        labels[-2] = 1

        return (
            text_token_info,
            audio_feat_info,
            text_mask,
            audio_mask,
            loss_mask,
            labels,
            audio_duration,
            text_token_count,
        )
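To make the packed layout concrete, here is a small illustrative sketch (not library code) of the masks `process_tts_data` builds for a non-prompt sample with 3 text tokens (plus the appended audio_start marker) and 10 patched audio frames; the chosen sizes are assumptions:

```python
import torch

# layout: [text ..., audio_start] + [audio frames ...] + [audio_end]
text_length, audio_length = 4, 10  # 3 text tokens + start marker, 10 patched frames
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).int()
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).int()
loss_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).int()
labels = torch.zeros(text_length + audio_length + 1).int()
labels[-2] = 1  # stop label on the final audio frame

print(text_mask.tolist())   # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
print(audio_mask.tolist())  # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
print(labels.tolist()[-3:]) # [0, 1, 0]
```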
src/voxcpm/training/state.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class TrainingState:
    """
    Container that mirrors the object returned in the minicpm-audio training
    loop. It holds persistent references to the model, optimizer, scheduler,
    dataloaders and tracker.
    """

    generator: object
    optimizer: object
    scheduler: object
    train_loader: object
    val_loader: object
    tracker: object
    batch_processor: object
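A brief sketch of how a training script might assemble this container; every bound object here (model, optimizer, and so on) is an assumption of the sketch rather than something defined in this file:

```python
state = TrainingState(
    generator=model,                 # the VoxCPM generator being fine-tuned (assumed)
    optimizer=optimizer,             # assumed torch optimizer
    scheduler=scheduler,             # assumed LR scheduler
    train_loader=train_loader,
    val_loader=val_loader,
    tracker=tracker,
    batch_processor=batch_processor,
)
```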
src/voxcpm/training/tracker.py (new file, 78 lines)
@@ -0,0 +1,78 @@
from __future__ import annotations

import contextlib
import time
from pathlib import Path
from typing import Dict, Optional


class TrainingTracker:
    """
    Lightweight tracker inspired by the minicpm-audio training workflow.

    It keeps track of the current global step, prints rank-aware messages,
    optionally writes to TensorBoard via a provided writer, and stores progress
    in a logfile for later inspection.
    """

    def __init__(
        self,
        *,
        writer=None,
        log_file: Optional[str] = None,
        rank: int = 0,
    ):
        self.writer = writer
        self.log_file = Path(log_file) if log_file else None
        if self.log_file:
            self.log_file.parent.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        self.step = 0
        # Record the time of the last log to calculate the interval
        self._last_log_time: float | None = None

    # ------------------------------------------------------------------ #
    # Logging helpers
    # ------------------------------------------------------------------ #
    def print(self, message: str):
        if self.rank == 0:
            print(message, flush=True)
            if self.log_file:
                with self.log_file.open("a", encoding="utf-8") as f:
                    f.write(message + "\n")

    def log_metrics(self, metrics: Dict[str, float], split: str):
        if self.rank == 0:
            now = time.time()
            dt_str = ""
            if self._last_log_time is not None:
                dt = now - self._last_log_time
                dt_str = f", log interval: {dt:.2f}s"
            self._last_log_time = now

            formatted = ", ".join(f"{k}: {v:.6f}" for k, v in metrics.items())
            self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")
            if self.writer is not None:
                for key, value in metrics.items():
                    if isinstance(value, (int, float)):
                        self.writer.add_scalar(f"{split}/{key}", value, self.step)

    def done(self, split: str, message: str):
        self.print(f"[{split}] {message}")

    # ------------------------------------------------------------------ #
    # State dict
    # ------------------------------------------------------------------ #
    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state):
        self.step = int(state.get("step", 0))

    # ------------------------------------------------------------------ #
    # Context manager compatibility (for parity with minicpm-audio code)
    # ------------------------------------------------------------------ #
    @contextlib.contextmanager
    def live(self):
        yield
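A short usage sketch; the SummaryWriter and the run paths are assumptions for illustration, and passing `writer=None` disables TensorBoard logging entirely:

```python
from torch.utils.tensorboard import SummaryWriter

tracker = TrainingTracker(
    writer=SummaryWriter("runs/voxcpm_ft"),   # optional
    log_file="runs/voxcpm_ft/train.log",
    rank=0,
)

for step in range(1, 3):
    tracker.step = step
    tracker.log_metrics({"loss": 1.0 / step, "lr": 2e-5}, split="train")
tracker.done("train", "finished smoke test")
```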