diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index 3af0af9..0d8f1c2 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -283,8 +283,10 @@ class VoxCPMModel(nn.Module): else: break else: - break - return self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() + + decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() + decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio + return decode_audio @torch.inference_mode() def build_prompt_cache( @@ -468,7 +470,8 @@ class VoxCPMModel(nn.Module): else: break decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu() - + decode_audio = decode_audio[..., 640:-640] # trick: trim the start and end of the audio + return ( decode_audio, target_text_token, @@ -580,7 +583,6 @@ class VoxCPMModel(nn.Module): pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size) - feat_pred = feat_pred[..., 1:-1] # trick: remove the first and last token return feat_pred, pred_feat_seq.squeeze(0).cpu() @classmethod