Mirror of https://github.com/OpenBMB/VoxCPM
FIX: Add MPS support
@@ -85,11 +85,15 @@ class VoxCPMModel(nn.Module):
         self.patch_size = config.patch_size
         self.device = config.device
         if not torch.cuda.is_available():
+            if torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
                 self.device = "cpu"
+        print(f"Running on device: {self.device}, dtype: {self.config.dtype}")

         # Text-Semantic LM
         self.base_lm = MiniCPMModel(config.lm_config)
-        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

         self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
         self.audio_start_token = 101
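The device fallback added above goes CUDA, then Apple MPS, then CPU. It can be reproduced in isolation; a minimal sketch assuming only PyTorch, where pick_device is a hypothetical helper and not part of VoxCPM:

import torch

def pick_device(preferred: str = "cuda") -> str:
    # Same fallback order as the diff: CUDA if present, otherwise MPS, otherwise CPU.
    if preferred == "cuda" and torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(pick_device())  # e.g. "mps" on Apple Silicon without CUDA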
@@ -100,7 +104,7 @@ class VoxCPMModel(nn.Module):
         residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
         residual_lm_config.vocab_size = 0
         self.residual_lm = MiniCPMModel(residual_lm_config)
-        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

         # Local Encoder
         encoder_config = config.lm_config.model_copy(deep=True)
@@ -271,7 +275,7 @@ class VoxCPMModel(nn.Module):

         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)

         target_text_length = len(self.text_tokenizer(target_text))
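This hunk and the setup_cache changes above both replace a hard-coded or stale dtype with get_dtype(self.config.dtype). get_dtype is not shown in the diff; a plausible sketch of such a helper (the real VoxCPM implementation may differ), illustrating why the change matters on backends like MPS where a hard-coded bfloat16 cast can disagree with the dtype the model actually runs in:

import torch

_DTYPE_MAP = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}

def get_dtype(name: str) -> torch.dtype:
    # Map a config dtype string to the corresponding torch dtype.
    return _DTYPE_MAP[name]

# Audio features now follow the configured dtype instead of torch.bfloat16.
audio_feat = torch.randn(10, 64)
audio_feat = audio_feat.unsqueeze(0).to("cpu").to(get_dtype("float16"))
print(audio_feat.dtype)  # torch.float16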
@@ -484,7 +488,7 @@ class VoxCPMModel(nn.Module):

         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)

         # run inference
@@ -670,7 +674,7 @@ class VoxCPMModel(nn.Module):
        )["state_dict"]

        model = cls(config, tokenizer, audio_vae)
-        lm_dtype = get_dtype(config.dtype)
+        lm_dtype = get_dtype(model.config.dtype)
        model = model.to(lm_dtype)
        model.audio_vae = model.audio_vae.to(torch.float32)

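The loading path above casts the whole model to the configured LM dtype and then restores the audio VAE to float32. A self-contained illustration of that cast-then-restore pattern, with a made-up ToyModel standing in for VoxCPMModel:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_lm = nn.Linear(8, 8)    # stand-in for the LM
        self.audio_vae = nn.Linear(8, 8)  # stand-in for the audio VAE

model = ToyModel()
lm_dtype = torch.float16                              # e.g. get_dtype(model.config.dtype)
model = model.to(lm_dtype)                            # cast everything to the LM dtype...
model.audio_vae = model.audio_vae.to(torch.float32)   # ...then keep the VAE in float32

print(model.base_lm.weight.dtype, model.audio_vae.weight.dtype)
# torch.float16 torch.float32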
@@ -154,6 +154,11 @@ class MiniCPMAttention(nn.Module):

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
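The contiguous() calls reference pytorch/pytorch#163597, where scaled_dot_product_attention misbehaves on MPS for non-contiguous inputs. A minimal sketch of the same workaround with arbitrary shapes; it falls back to CPU when MPS is unavailable:

import torch
import torch.nn.functional as F

device = "mps" if torch.backends.mps.is_available() else "cpu"

# The usual (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
# transpose produces non-contiguous Q/K/V tensors.
q = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
k = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
v = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)

# Workaround from the diff: force contiguity before SDPA.
out = F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())
print(out.shape)  # torch.Size([1, 8, 16, 64])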
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):

        attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id

+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_cache = key_cache.contiguous()
+        value_cache = value_cache.contiguous()
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_cache,
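The decode path above builds its attention mask as torch.arange(key_cache.size(2)) <= position_id, so only cache slots up to the current position are attended to. A small illustration with made-up sizes:

import torch

cache_len = 8      # key_cache.size(2) in the diff
position_id = 3    # current decoding position
attn_mask = torch.arange(cache_len) <= position_id
print(attn_mask)
# tensor([ True,  True,  True,  True, False, False, False, False])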