diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index b8a58fb..7268704 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -148,12 +148,15 @@ class VoxCPMModel(nn.Module): def optimize(self): - if self.device == "cuda": + try: + if self.device != "cuda": + raise ValueError("VoxCPMModel can only be optimized on CUDA device") self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True) self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True) self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True) self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True) - else: + except: + print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions") self.base_lm.forward_step = self.base_lm.forward_step self.residual_lm.forward_step = self.residual_lm.forward_step self.feat_encoder_step = self.feat_encoder