mirror of
https://github.com/OpenBMB/VoxCPM
synced 2025-12-12 11:58:11 +00:00
capture torch compile error
This commit is contained in:
@@ -148,12 +148,15 @@ class VoxCPMModel(nn.Module):
|
||||
|
||||
|
||||
def optimize(self):
|
||||
if self.device == "cuda":
|
||||
try:
|
||||
if self.device != "cuda":
|
||||
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
||||
else:
|
||||
except:
|
||||
print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
|
||||
self.base_lm.forward_step = self.base_lm.forward_step
|
||||
self.residual_lm.forward_step = self.residual_lm.forward_step
|
||||
self.feat_encoder_step = self.feat_encoder
|
||||
|
||||
Reference in New Issue
Block a user