FIX: Add MPS support
@@ -85,11 +85,15 @@ class VoxCPMModel(nn.Module):
         self.patch_size = config.patch_size
         self.device = config.device
         if not torch.cuda.is_available():
-            self.device = "cpu"
+            if torch.backends.mps.is_available():
+                self.device = "mps"
+            else:
+                self.device = "cpu"
+        print(f"Running on device: {self.device}, dtype: {self.config.dtype}")

         # Text-Semantic LM
         self.base_lm = MiniCPMModel(config.lm_config)
-        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.base_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

         self.text_tokenizer = mask_multichar_chinese_tokens(tokenizer)
         self.audio_start_token = 101
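For reference, a minimal standalone sketch of the device fallback the hunk above introduces (keep the configured device when CUDA is usable, otherwise prefer MPS, then CPU); the pick_device name is illustrative and not part of the codebase:

    import torch

    def pick_device(preferred: str = "cuda") -> str:
        # Keep the configured device when CUDA is usable,
        # otherwise fall back to Apple's MPS backend, then CPU.
        if torch.cuda.is_available():
            return preferred
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    print(pick_device())  # e.g. "mps" on an Apple Silicon machine without CUDA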
@@ -100,7 +104,7 @@ class VoxCPMModel(nn.Module):
         residual_lm_config.num_hidden_layers = config.residual_lm_num_layers
         residual_lm_config.vocab_size = 0
         self.residual_lm = MiniCPMModel(residual_lm_config)
-        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(config.dtype))
+        self.residual_lm.setup_cache(1, config.max_length, self.device, get_dtype(self.config.dtype))

         # Local Encoder
         encoder_config = config.lm_config.model_copy(deep=True)
@@ -271,7 +275,7 @@ class VoxCPMModel(nn.Module):

         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)

         target_text_length = len(self.text_tokenizer(target_text))
@@ -484,7 +488,7 @@ class VoxCPMModel(nn.Module):

         text_token = text_token.unsqueeze(0).to(self.device)
         text_mask = text_mask.unsqueeze(0).to(self.device)
-        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(torch.bfloat16)
+        audio_feat = audio_feat.unsqueeze(0).to(self.device).to(get_dtype(self.config.dtype))
         audio_mask = audio_mask.unsqueeze(0).to(self.device)

         # run inference
@@ -670,7 +674,7 @@ class VoxCPMModel(nn.Module):
         )["state_dict"]

         model = cls(config, tokenizer, audio_vae)
-        lm_dtype = get_dtype(config.dtype)
+        lm_dtype = get_dtype(model.config.dtype)
         model = model.to(lm_dtype)
         model.audio_vae = model.audio_vae.to(torch.float32)

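The hunks in this file also stop hard-coding torch.bfloat16 and route every dtype decision through get_dtype(self.config.dtype), presumably so the model can run in a dtype the active backend handles. get_dtype itself is not shown in this diff; a hypothetical sketch of such a string-to-dtype helper, assuming it simply maps config strings to torch dtypes:

    import torch

    def get_dtype(name: str) -> torch.dtype:
        # Hypothetical mapping from a config string to a torch dtype;
        # the project's actual helper may differ.
        return {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
        }[name]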
@@ -154,6 +154,11 @@ class MiniCPMAttention(nn.Module):

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):

         attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id

+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_cache = key_cache.contiguous()
+        value_cache = value_cache.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_cache,
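Both attention hunks work around https://github.com/pytorch/pytorch/issues/163597 by forcing contiguous inputs before scaled_dot_product_attention. A minimal, self-contained sketch of that pattern (tensor shapes are arbitrary; on a machine without MPS it simply runs on CPU):

    import torch
    import torch.nn.functional as F

    device = "mps" if torch.backends.mps.is_available() else "cpu"

    # Attention code usually reshapes to (batch, seq, heads, head_dim) and then
    # transposes to (batch, heads, seq, head_dim), leaving the tensors non-contiguous.
    q = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
    k = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
    v = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)

    # Workaround used in the diff above: make the inputs contiguous before SDPA on MPS.
    out = F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())
    print(out.shape)  # torch.Size([1, 8, 16, 64])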