FIX: Add MPS support

刘鑫
2025-09-28 21:06:35 +08:00
parent fbf8984d4e
commit 2eb4d39719
2 changed files with 22 additions and 8 deletions


@@ -153,7 +153,12 @@ class MiniCPMAttention(nn.Module):
         cos, sin = position_emb
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
@@ -198,6 +203,11 @@ class MiniCPMAttention(nn.Module):
         attn_mask = torch.arange(key_cache.size(2), device=key_cache.device) <= position_id
+        # ref: https://github.com/pytorch/pytorch/issues/163597
+        # there is a bug in MPS for non-contiguous tensors, so we need to make them contiguous
+        query_states = query_states.contiguous()
+        key_cache = key_cache.contiguous()
+        value_cache = value_cache.contiguous()
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_cache,
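
For context, a minimal sketch (not part of the commit) of the same workaround, with the .contiguous() copies guarded behind an MPS device check so other backends skip the extra allocation. The helper name _maybe_contiguous_for_mps and the tensor shapes are illustrative assumptions, not the commit's API.

    import torch
    import torch.nn.functional as F

    def _maybe_contiguous_for_mps(*tensors):
        # Workaround for https://github.com/pytorch/pytorch/issues/163597:
        # scaled_dot_product_attention can misbehave on MPS when its inputs
        # are non-contiguous, so copy only when running on an MPS device.
        return tuple(
            t.contiguous() if t.device.type == "mps" and not t.is_contiguous() else t
            for t in tensors
        )

    # Usage inside an attention forward pass (shapes are illustrative).
    # transpose() yields non-contiguous views, the typical trigger for the bug.
    q = torch.randn(1, 8, 16, 64).transpose(1, 2)
    k = torch.randn(1, 8, 16, 64).transpose(1, 2)
    v = torch.randn(1, 8, 16, 64).transpose(1, 2)
    q, k, v = _maybe_contiguous_for_mps(q, k, v)
    out = F.scaled_dot_product_attention(q, k, v)

On CPU or CUDA the helper returns the tensors unchanged, so the patch only costs memory copies where the MPS bug applies.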