From 2eb4d3971941a97954c6ad9e00570f081b22c302 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=88=98=E9=91=AB?=
Date: Sun, 28 Sep 2025 21:06:35 +0800
Subject: [PATCH] FX: Add MPS support

---
 src/voxcpm/model/voxcpm.py           | 18 +++++++++++-------
 src/voxcpm/modules/minicpm4/model.py | 12 +++++++++++-
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py
index 89b895b..4683d79 100644
--- a/src/voxcpm/model/voxcpm.py
+++ b/src/voxcpm/model/voxcpm.py
@@ -85,11 +85,15 @@ class VoxCPMModel(nn.Module):
         self.patch_size = config.patch_size
         self.device = config.device
         if not torch.cuda.is_available():
-            self.device = "cpu"
+            if torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
+        print(f"Running on device: {self.device}, dtype: {self.config.dtype}")
 
         # Text-Semantic LM
         self.base_lm = MiniCPMModel(config.lm_config)
-        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
 
         self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
         self.audio_start_token = 101
@@ -100,7 +104,7 @@ class VoxCPMModel(nn.Module):
         residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
         residual_lm_config.vocab_size = 0
         self.residual_lm = MiniCPMModel(residual_lm_config)
-        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
 
         # Local Encoder
         encoder_config = config.lm_config.model_copy(deep=True)
@@ -132,7 +136,7 @@ class VoxCPMModel(nn.Module):
             config.lm_config.hidden_size,
             config.scalar_quantization_latent_dim,
             config.scalar_quantization_scale
-        )
+        )
         self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
         self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
         self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
@@ -271,7 +275,7 @@ class VoxCPMModel(nn.Module):
 
         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)
 
         target_text_length = len(self.text_tokenizer(target_text))
@@ -484,7 +488,7 @@ class VoxCPMModel(nn.Module):
 
         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)
 
         # run inference
@@ -670,7 +674,7 @@ class VoxCPMModel(nn.Module):
         )["state_dict"]
 
         model = cls(config, tokenizer, audio_vae)
-        lm_dtype = get_dtype(config.dtype)
+        lm_dtype = get_dtype(model.config.dtype)
         model = model.to(lm_dtype)
 
         model.audio_vae = model.audio_vae.to(torch.float32)
diff --git a/src/voxcpm/modules/minicpm4/model.py b/src/voxcpm/modules/minicpm4/model.py
index 58bce05..8945a46 100644
--- a/src/voxcpm/modules/minicpm4/model.py
+++ b/src/voxcpm/modules/minicpm4/model.py
@@ -153,7 +153,12 @@ class MiniCPMAttention(nn.Module):
         cos, sin = position_emb
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
+
+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):
 
         attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
 
+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_cache = key_cache.contiguous()
+        value_cache = value_cache.contiguous()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_cache,
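
For reference, the device and dtype selection this patch wires into VoxCPMModel reduces to the standalone sketch below. It is a simplified illustration only (it ignores the config-supplied default device); the helper names select_device and resolve_dtype are illustrative and not part of the VoxCPM codebase.

    import torch

    def select_device() -> str:
        # Mirror the patched fallback: keep CUDA when present, otherwise try
        # Apple's MPS backend before dropping to CPU.
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    def resolve_dtype(dtype_str: str) -> torch.dtype:
        # Resolve the dtype from configuration instead of hardcoding bfloat16,
        # since the best-supported dtype differs between CUDA, MPS, and CPU.
        return {"float32": torch.float32,
                "float16": torch.float16,
                "bfloat16": torch.bfloat16}[dtype_str]

    device = select_device()
    print(f"Running on device: {device}, dtype: {resolve_dtype('float32')}")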