FX: Add MPS support
@@ -153,7 +153,12 @@ class MiniCPMAttention(nn.Module):
         cos, sin = position_emb

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
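The hunk above (and the one below) applies the same workaround: force query/key/value tensors to a contiguous memory layout before calling scaled_dot_product_attention on MPS, per pytorch/pytorch#163597. A minimal standalone sketch of the pattern, assuming only stock PyTorch; the sdpa_mps_safe wrapper name and the tensor shapes are illustrative, not from this repo:

import torch
import torch.nn.functional as F

def sdpa_mps_safe(query, key, value, attn_mask=None):
    # ref: https://github.com/pytorch/pytorch/issues/163597
    # scaled_dot_product_attention can misbehave on MPS when its inputs are
    # non-contiguous, so force contiguous layouts first. .contiguous() is a
    # no-op on tensors that are already contiguous.
    if query.device.type == "mps":
        query, key, value = query.contiguous(), key.contiguous(), value.contiguous()
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)

device = "mps" if torch.backends.mps.is_available() else "cpu"
# transpose(1, 2) yields a non-contiguous view, the layout that triggers the bug
q = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
k = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
v = torch.randn(1, 16, 8, 64, device=device).transpose(1, 2)
out = sdpa_mps_safe(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])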
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):

         attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id

+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_cache = key_cache.contiguous()
+        value_cache = value_cache.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_cache,
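The second hunk also shows how the attention mask is built for the cached-decoding path: a boolean vector over the KV-cache length that is True for every slot at or before the current position_id, which scaled_dot_product_attention interprets as "attend here". A small standalone sketch, assuming a cache laid out as (batch, heads, cache_len, head_dim); the cache size and position are made-up values for illustration:

import torch

# Hypothetical cache of 8 slots; we are decoding the token at position 4.
cache_len, position_id = 8, 4
key_cache = torch.zeros(1, 2, cache_len, 16)  # (batch, heads, cache_len, head_dim)

# Boolean mask over dim 2 of the cache: True for slots [0..position_id],
# False for slots that have not been written yet, so SDPA ignores them.
attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
print(attn_mask)
# tensor([ True,  True,  True,  True,  True, False, False, False])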