capture torch compile error

This commit is contained in:
刘鑫
2025-09-17 18:09:09 +08:00
parent 5390a47862
commit 032c7fe403

View File

@@ -148,12 +148,15 @@ class VoxCPMModel(nn.Module):
def optimize(self):
if self.device == "cuda":
try:
if self.device != "cuda":
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
self.feat_encoder_step = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
else:
except:
print("VoxCPMModel can not be optimized by torch.compile, using original forward_step functions")
self.base_lm.forward_step = self.base_lm.forward_step
self.residual_lm.forward_step = self.residual_lm.forward_step
self.feat_encoder_step = self.feat_encoder