Remove segment text logic

This commit is contained in:
刘鑫
2025-09-18 12:01:26 +08:00
parent 1fa9e2ca02
commit e5bcb735f0
3 changed files with 37 additions and 35 deletions

View File

@@ -151,11 +151,16 @@ class VoxCPMModel(nn.Module):
try:
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
try:
import triton
except:
raise ValueError("triton is not installed")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
except:
except Exception as e:
print(e)
print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
self.base_lm.forward_step = self.base_lm.forward_step
self.residual_lm.forward_step = self.residual_lm.forward_step
@@ -317,7 +322,7 @@ class VoxCPMModel(nn.Module):
audio = torch.nn.functional.pad(audio, (0, patch_len - audio.size(1) % patch_len))
# extract audio features
audio_feat = self.audio_vae.encode(audio.cuda(), self.sample_rate).cpu()
audio_feat = self.audio_vae.encode(audio.to(self.device), self.sample_rate).cpu()
audio_feat = audio_feat.view(
self.audio_vae.latent_dim,