mirror of
https://github.com/OpenBMB/VoxCPM
synced 2025-12-13 20:18:11 +00:00
init
This commit is contained in:
137
src/voxcpm/modules/locdit/unified_cfm.py
Normal file
137
src/voxcpm/modules/locdit/unified_cfm.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import torch
|
||||
from typing import List
|
||||
from .local_dit import VoxCPMLocDiT
|
||||
import math
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CfmConfig(BaseModel):
    """Hyper-parameters for the conditional flow-matching (CFM) sampler."""

    # Minimum noise level of the flow-matching formulation.
    # NOTE(review): stored on UnifiedCFM but not read in this file — presumably
    # consumed at training time; verify against the trainer.
    sigma_min: float = 1e-06
    # ODE solver identifier; only a fixed-step Euler loop is implemented here.
    solver: str = "euler"
    # Timestep-scheduling strategy identifier (stored, not consulted in this file).
    t_scheduler: str = "log-norm"
|
||||
|
||||
|
||||
class UnifiedCFM(torch.nn.Module):
    """Conditional flow-matching (CFM) sampler wrapped around a VoxCPMLocDiT.

    Integrates the estimator's predicted velocity field from t=1 (pure noise)
    down to t=0 with a fixed-step Euler solver, applying classifier-free
    guidance (CFG) and, optionally, CFG-Zero*-style rescaling plus zero-velocity
    initialization of the first few steps.
    """

    def __init__(
        self,
        in_channels: int,
        cfm_params: CfmConfig,
        estimator: VoxCPMLocDiT,
        mean_mode: bool = False,
    ):
        """Store sampling hyper-parameters and the velocity estimator.

        Args:
            in_channels: channel dimension of the latent tensor being sampled.
            cfm_params: CFM hyper-parameters. They are copied onto attributes
                but none of them is consulted by the sampling loop in this file.
            estimator: network invoked as ``estimator(x, mu, t, cond, dt)`` to
                predict the flow velocity d(phi)/dt.
            mean_mode: when False (the default), the per-step dt handed to the
                estimator is zeroed out before the call (see solve_euler).
        """
        super().__init__()
        self.solver = cfm_params.solver
        self.sigma_min = cfm_params.sigma_min
        self.t_scheduler = cfm_params.t_scheduler
        self.in_channels = in_channels
        self.mean_mode = mean_mode

        # Just change the architecture of the estimator here
        self.estimator = estimator

    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,
        n_timesteps: int,
        patch_size: int,
        cond: torch.Tensor,
        temperature: float = 1.0,
        cfg_value: float = 1.0,
        sway_sampling_coef: float = 1.0,
        use_cfg_zero_star: bool = True,
    ) -> torch.Tensor:
        """Forward diffusion: sample a latent patch by integrating the flow.

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats)
            n_timesteps (int): number of diffusion steps
            patch_size (int): time length of the generated latent patch.
            cond: conditioning tensor forwarded to the estimator
                # assumes shape (batch_size, in_channels, patch_size) — it is
                # copied into a buffer of that shape in solve_euler; confirm.
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            cfg_value (float, optional): classifier-free guidance strength. Defaults to 1.0.
            sway_sampling_coef (float, optional): strength of the sway timestep
                warp; the warp keeps the endpoints t=1 and t=0 fixed.
            use_cfg_zero_star (bool, optional): enable CFG-Zero* rescaling and
                zero-velocity initialization of the first steps.

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, c = mu.shape
        t = patch_size
        # Start from pure Gaussian noise, scaled by the sampling temperature.
        z = torch.randn((b, self.in_channels, t), device=mu.device, dtype=mu.dtype) * temperature

        # Uniform descending schedule 1 -> 0 with n_timesteps intervals.
        t_span = torch.linspace(1, 0, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        # Sway sampling strategy: warp t -> t + s*(cos(pi/2 * t) - 1 + t),
        # which biases step density while leaving t=1 and t=0 unchanged.
        t_span = t_span + sway_sampling_coef * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)

        return self.solve_euler(z, t_span=t_span, mu=mu, cond=cond, cfg_value=cfg_value, use_cfg_zero_star=use_cfg_zero_star)

    def optimized_scale(self, positive_flat: torch.Tensor, negative_flat: torch.Tensor) -> torch.Tensor:
        """Per-sample optimal scale st* = <pos, neg> / ||neg||^2 (CFG-Zero*).

        Args:
            positive_flat: flattened conditional velocity, shape (b, d).
            negative_flat: flattened unconditional velocity, shape (b, d).

        Returns:
            torch.Tensor of shape (b, 1): least-squares projection scale of the
            positive branch onto the negative branch.
        """
        # Batched dot product between the two flattened velocity fields.
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        # Squared norm of the negative branch; epsilon guards divide-by-zero.
        squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8

        st_star = dot_product / squared_norm
        return st_star

    def solve_euler(
        self,
        x: torch.Tensor,
        t_span: torch.Tensor,
        mu: torch.Tensor,
        cond: torch.Tensor,
        cfg_value: float = 1.0,
        use_cfg_zero_star: bool = True,
    ) -> torch.Tensor:
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats)
            cond: conditioning tensor duplicated into both CFG branches
                # assumes shape (batch_size, in_channels, x.size(2)) — confirm.
            cfg_value (float, optional): cfg value for guidance. Defaults to 1.0.
            use_cfg_zero_star (bool, optional): apply CFG-Zero* scale and skip
                (zero-velocity) the first ~4% of steps.

        Returns:
            torch.Tensor: the final integrated sample, same shape as ``x``.
        """
        # t: current time, dt: current step size; t_span is descending, so
        # dt = t_span[0] - t_span[1] is positive and x is updated as x - dt*v.
        t, _, dt = t_span[0], t_span[-1], t_span[0] - t_span[1]

        sol = []
        # CFG-Zero*: emit zero velocity for the first ~4% of steps (at least 1).
        zero_init_steps = max(1, int(len(t_span) * 0.04))
        for step in range(1, len(t_span)):
            if use_cfg_zero_star and step <= zero_init_steps:
                # Zero-init step: x is left unchanged, only time advances.
                dphi_dt = 0.
            else:
                # Classifier-Free Guidance inference introduced in VoiceBox
                # Batch the conditional (first half) and unconditional (second
                # half, mu left as zeros) branches into one estimator call.
                b = x.size(0)
                x_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
                mu_in = torch.zeros([2 * b, mu.size(1)], device=x.device, dtype=x.dtype)
                t_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                dt_in = torch.zeros([2 * b], device=x.device, dtype=x.dtype)
                cond_in = torch.zeros([2 * b, self.in_channels, x.size(2)], device=x.device, dtype=x.dtype)
                x_in[:b], x_in[b:] = x, x
                # Only the first half receives mu; the second half stays zero,
                # which is the unconditional branch for guidance.
                mu_in[:b] = mu
                t_in[:b], t_in[b:] = t.unsqueeze(0), t.unsqueeze(0)
                dt_in[:b], dt_in[b:] = dt.unsqueeze(0), dt.unsqueeze(0)
                # not used now
                if not self.mean_mode:
                    dt_in = torch.zeros_like(dt_in)
                # cond is fed to both branches (it is not dropped for CFG).
                cond_in[:b], cond_in[b:] = cond, cond

                dphi_dt = self.estimator(x_in, mu_in, t_in, cond_in, dt_in)
                # Split back into conditional / unconditional velocities.
                dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)

                if use_cfg_zero_star:
                    positive_flat = dphi_dt.view(b, -1)
                    negative_flat = cfg_dphi_dt.view(b, -1)
                    st_star = self.optimized_scale(positive_flat, negative_flat)
                    # Reshape (b, 1) -> (b, 1, ..., 1) to broadcast over dphi_dt.
                    st_star = st_star.view(b, *([1] * (len(dphi_dt.shape) - 1)))
                else:
                    st_star = 1.0

                # Guided velocity: rescaled negative branch plus cfg_value times
                # the (positive - rescaled negative) difference.
                dphi_dt = cfg_dphi_dt * st_star + cfg_value * (dphi_dt - cfg_dphi_dt * st_star)

            # Euler update; dt > 0 and time runs 1 -> 0, hence the subtraction.
            x = x - dt * dphi_dt
            t = t - dt
            sol.append(x)
            if step < len(t_span) - 1:
                # Next step size from the (possibly non-uniform, swayed) schedule.
                dt = t - t_span[step + 1]

        return sol[-1]
|
||||
Reference in New Issue
Block a user