FIX: Add MPS support

Author: 刘鑫
Date:   2025-09-28 21:06:35 +08:00
parent fbf8984d4e
commit 2eb4d39719
2 changed files with 22 additions and 8 deletions

Changed file 1 of 2 (file path not shown in this view):

@@ -85,11 +85,15 @@ class VoxCPMModel(nn.Module):
         self.patch_size = config.patch_size
         self.device = config.device
         if not torch.cuda.is_available():
-            self.device = "cpu"
+            if torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
+        print(f"Running on device: {self.device}, dtype: {self.config.dtype}")

         # Text-Semantic LM
         self.base_lm = MiniCPMModel(config.lm_config)
-        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))
         self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
         self.audio_start_token = 101
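
The fallback above keeps the configured device when CUDA is present, and otherwise prefers Apple's MPS backend over plain CPU. As a standalone sketch of the same logic (the resolve_device helper is hypothetical, not part of this commit):

    import torch

    def resolve_device(configured: str) -> str:
        """Keep the configured device when CUDA is available,
        otherwise fall back to MPS, then CPU."""
        if torch.cuda.is_available():
            return configured
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

torch.backends.mps.is_available() returns False both when PyTorch was built without MPS and when no Metal-capable GPU is present, so the CPU branch still covers machines without Metal support.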
@@ -100,7 +104,7 @@ class VoxCPMModel(nn.Module):
         residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
         residual_lm_config.vocab_size = 0
         self.residual_lm = MiniCPMModel(residual_lm_config)
-        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

         # Local Encoder
         encoder_config = config.lm_config.model_copy(deep=True)
@@ -132,7 +136,7 @@ class VoxCPMModel(nn.Module):
             config.lm_config.hidden_size,
             config.scalar_quantization_latent_dim,
             config.scalar_quantization_scale
         )
         self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
         self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
         self.res_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
@@ -271,7 +275,7 @@ class VoxCPMModel(nn.Module):
         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)

         target_text_length = len(self.text_tokenizer(target_text))
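
Note the repeated switch from hard-coded torch.bfloat16 (and from get_dtype(config.dtype)) to get_dtype(self.config.dtype): bfloat16 support on MPS has historically been incomplete, so resolving the dtype from config lets an MPS run select float16 or float32 instead. The repository's get_dtype is not shown in this diff; a minimal sketch of what such a helper typically looks like (a hypothetical stand-in, the real one may differ):

    import torch

    _DTYPES = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    def get_dtype(name: str) -> torch.dtype:
        # Hypothetical stand-in for the repo's helper.
        try:
            return _DTYPES[name]
        except KeyError:
            raise ValueError(f"unsupported dtype: {name!r}")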
@@ -484,7 +488,7 @@ class VoxCPMModel(nn.Module):
         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)

         # run inference
@@ -670,7 +674,7 @@ class VoxCPMModel(nn.Module):
)["state_dict"] )["state_dict"]
model = cls(config, tokenizer, audio_vae) model = cls(config, tokenizer, audio_vae)
lm_dtype = get_dtype(config.dtype) lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype) model = model.to(lm_dtype)
model.audio_vae = model.audio_vae.to(torch.float32) model.audio_vae = model.audio_vae.to(torch.float32)
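
The loading path above casts the whole model to the configured LM dtype, then pins the audio VAE back to float32, a common pattern when one submodule is numerically sensitive to half precision. A self-contained sketch of the same cast-then-repin idea (TinyModel is an illustrative stand-in, not the real class):

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lm = nn.Linear(8, 8)         # runs in the configured dtype
            self.audio_vae = nn.Linear(8, 8)  # should stay float32

    model = TinyModel().to(torch.float16)                # blanket cast first
    model.audio_vae = model.audio_vae.to(torch.float32)  # then re-pin the VAE
    assert next(model.lm.parameters()).dtype == torch.float16
    assert next(model.audio_vae.parameters()).dtype == torch.float32

The order matters: the blanket .to() would overwrite any earlier per-module cast, so the float32 pin has to come last.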

Changed file 2 of 2 (file path not shown in this view):

@@ -153,7 +153,12 @@ class MiniCPMAttention(nn.Module):
         cos, sin = position_emb
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
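
The referenced issue (pytorch/pytorch#163597) reports scaled_dot_product_attention producing incorrect results on MPS when its inputs are non-contiguous, which is exactly what rotary embedding plus the usual [B, T, H, D] -> [B, H, T, D] transpose leaves behind. A minimal sketch of the defensive pattern the commit applies (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    device = "mps" if torch.backends.mps.is_available() else "cpu"

    # transpose(1, 2) yields non-contiguous views of the projections
    q = torch.randn(1, 8, 2, 16, device=device).transpose(1, 2)
    k = torch.randn(1, 8, 2, 16, device=device).transpose(1, 2)
    v = torch.randn(1, 8, 2, 16, device=device).transpose(1, 2)
    assert not q.is_contiguous()

    # workaround: materialize contiguous copies before SDPA
    out = F.scaled_dot_product_attention(
        q.contiguous(), k.contiguous(), v.contiguous()
    )

Tensor.contiguous() returns the tensor unchanged when its layout is already contiguous, so the extra copies are only paid where a transposed or sliced view is actually involved.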
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):
         attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id

+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_cache = key_cache.contiguous()
+        value_cache = value_cache.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_cache,
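
For context on the decode path above: attn_mask marks which KV-cache slots are valid at the current step, so preallocated cache entries beyond position_id are excluded from attention. A small worked example of the mask expression:

    import torch

    cache_len, position_id = 6, 3  # illustrative sizes
    attn_mask = torch.arange(cache_len) <= position_id
    # -> tensor([ True,  True,  True,  True, False, False])

The contiguity workaround is repeated here because key_cache and value_cache are typically views into that preallocated buffer and, like the transposed projections above, are not guaranteed to be contiguous either.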