mlx-raclate 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
+
File without changes
@@ -0,0 +1,225 @@
+ import inspect
+ from dataclasses import dataclass
+ import mlx.core as mx
+ import mlx.nn as nn
+ from typing import Optional
+
+ class RaclateBaseModel(nn.Module):
+     """Base class for Raclate models."""
+     def __init__(self):
+         super().__init__()
+
+     def get_hf_transformers_arch(self):
+         return self.hf_transformers_arch if hasattr(self, "hf_transformers_arch") else None
+
+     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
+         """
+         Resizes the input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
+         """
+         old_embeddings = self.get_input_embeddings()
+         if old_embeddings is None:
+             raise ValueError("Model does not support get_input_embeddings")
+
+         if new_num_tokens is None:
+             return old_embeddings
+
+         old_num_tokens, old_embedding_dim = old_embeddings.weight.shape
+
+         if new_num_tokens == old_num_tokens:
+             return old_embeddings
+
+         # Create new embeddings
+         new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+
+         # Initialize new weights (e.g. Normal)
+         new_embeddings.weight = mx.random.normal(shape=(new_num_tokens, old_embedding_dim)) * 0.02
+
+         # We copy up to the min size to handle both expansion and shrinking (though usually expansion)
+         n = min(old_num_tokens, new_num_tokens)
+
+         # Combine old relevant weights with new random weights for the extension
+         combined_weight = mx.concatenate([
+             old_embeddings.weight[:n],
+             new_embeddings.weight[n:]
+         ], axis=0) if new_num_tokens > old_num_tokens else old_embeddings.weight[:n]
+
+         new_embeddings.weight = combined_weight
+
+         self.set_input_embeddings(new_embeddings)
+
+         # Update config if present
+         if hasattr(self, "config"):
+             self.config.vocab_size = new_num_tokens
+
+         return new_embeddings
+
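For illustration only, a minimal sketch of how resize_token_embeddings might be used. The ToyModel subclass and its embed_tokens attribute are hypothetical stand-ins for a concrete model that implements get_input_embeddings/set_input_embeddings:

# Hypothetical toy subclass for illustration (not part of mlx-raclate)
class ToyModel(RaclateBaseModel):
    def __init__(self, vocab_size=100, dims=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dims)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

model = ToyModel()
resized = model.resize_token_embeddings(110)  # grow the vocabulary by 10 rows
print(resized.weight.shape)                   # (110, 16); the first 100 rows are copied over
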
+ @dataclass
+ class BaseModelArgs:
+     @classmethod
+     def from_dict(cls, params):
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
+
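A quick sketch of the from_dict filtering, with a hypothetical args subclass: keys that are not fields of the dataclass are silently dropped rather than raising a TypeError.

# Hypothetical args subclass for illustration (not part of mlx-raclate)
@dataclass
class ToyArgs(BaseModelArgs):
    hidden_size: int = 64
    vocab_size: int = 100

args = ToyArgs.from_dict({"hidden_size": 128, "unexpected_key": "ignored"})
print(args.hidden_size, args.vocab_size)  # 128 100
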
+ def compute_similarity(query_embeddings: mx.array, reference_embeddings: mx.array) -> mx.array:
+     """Computes cosine similarity between query embeddings and reference embeddings.
+
+     Args:
+         query_embeddings: Shape [batch_size, hidden_size]
+             These are the embeddings we want to classify/compare - already normalized
+         reference_embeddings: Shape [num_references, hidden_size]
+             These are our label descriptions or comparison sentences - already normalized
+
+     Returns:
+         Similarity matrix of shape [batch_size, num_references]
+         Each row contains similarities between one query and all references
+     """
+     # Compute similarities - results in [batch_size, num_references]
+     # Each row contains similarities between one input and all references
+     similarities = mx.matmul(query_embeddings, reference_embeddings.T)
+
+     return similarities
+
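As a sketch, with tiny hand-normalized vectors (hidden_size = 2), each row of the result holds one query's cosine similarities against every reference:

queries = mx.array([[1.0, 0.0], [0.0, 1.0]])
references = mx.array([[1.0, 0.0], [0.0, 1.0], [0.7071, 0.7071]])
sims = compute_similarity(queries, references)
print(sims.shape)  # (2, 3); sims[0] is roughly [1.0, 0.0, 0.7071]
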
+ def mean_pooling(token_embeddings: mx.array, attention_mask: mx.array):
+     """Masked mean pooling over the sequence dimension; padded positions do not contribute."""
+     input_mask_expanded = mx.expand_dims(attention_mask, -1)
+     input_mask_expanded = mx.broadcast_to(
+         input_mask_expanded, token_embeddings.shape
+     ).astype(mx.float32)
+
+     sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, axis=1)
+     sum_mask = mx.maximum(mx.sum(input_mask_expanded, axis=1), 1e-9)
+
+     return sum_embeddings / sum_mask
+
+
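A small sketch showing that padded positions are excluded from the average:

# Batch of 1, seq_len 3, hidden 2; the last position is padding (mask = 0)
token_embeddings = mx.array([[[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]]])
attention_mask = mx.array([[1, 1, 0]])
print(mean_pooling(token_embeddings, attention_mask))  # [[2.0, 2.0]]; the padded 100s are ignored
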
+ def last_token_pooling(
+     last_hidden_states: mx.array, attention_mask: Optional[mx.array] = None
+ ) -> mx.array:
+     """
+     Last token pooling, compatible with MLX compilation/grad
+
+     Args:
+         last_hidden_states: Hidden states from the model, shape (batch_size, seq_len, hidden_size)
+         attention_mask: Attention mask, shape (batch_size, seq_len). If None, uses last position.
+
+     Returns:
+         Pooled embeddings, shape (batch_size, hidden_size)
+     """
+     if attention_mask is None:
+         return last_hidden_states[:, -1]
+
+     B, S, _ = last_hidden_states.shape
+     indices = mx.arange(S)
+
+     # Only keep the unpadded tokens
+     masked_indices = indices * attention_mask.astype(indices.dtype)
+
+     # Find the last valid index (max index) for each batch item
+     last_token_indices = masked_indices.max(axis=1)
+
+     batch_indices = mx.arange(B)
+
+     # Select specific [batch, token] pairs
+     return last_hidden_states[batch_indices, last_token_indices]
+
+
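A sketch with right-padded input: the pooled vector for each sequence is the hidden state at its last unmasked position.

hidden = mx.arange(16, dtype=mx.float32).reshape(2, 4, 2)  # batch 2, seq_len 4, hidden 2
mask = mx.array([[1, 1, 1, 1], [1, 1, 1, 0]])              # second sequence has one padded slot
pooled = last_token_pooling(hidden, mask)
print(pooled)  # row 0 -> hidden[0, 3], row 1 -> hidden[1, 2]
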
+ def normalize_embeddings(embeddings, p=2, axis=-1, keepdims=True, eps=1e-9):
+     return embeddings / mx.maximum(
+         mx.linalg.norm(embeddings, ord=p, axis=axis, keepdims=keepdims), eps
+     )
+
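For example, a (1, 2) vector is scaled to unit L2 norm:

print(normalize_embeddings(mx.array([[3.0, 4.0]])))  # [[0.6, 0.8]]
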
+ def compute_late_interaction_scores(Q, D):
+     """
+     MaxSim: sum_i(max_j(Q_i . D_j))
+     Args:
+         Q: Query embeddings [B_q, L_q, Dim]
+         D: Doc embeddings [B_d, L_d, Dim]
+     Note: If calculating loss with in-batch negatives, shapes might vary.
+     """
+     # (B, L_q, Dim) @ (B, Dim, L_d) -> (B, L_q, L_d)
+     # This assumes pairwise scoring (Query[i] vs Doc[i]).
+     sim_matrix = Q @ D.transpose(0, 2, 1)
+     max_scores = mx.max(sim_matrix, axis=-1)  # (B, L_q)
+     return mx.sum(max_scores, axis=-1)  # (B,)
+
+
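A toy MaxSim sketch with one query of 2 tokens against one document of 3 tokens (dim 2, per-token embeddings already normalized): each query token takes its best match over the document tokens, and the matches are summed.

Q = mx.array([[[1.0, 0.0], [0.0, 1.0]]])              # [1, 2, 2]
D = mx.array([[[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]])  # [1, 3, 2]
print(compute_late_interaction_scores(Q, D))  # [2.0]: both query tokens find a perfect match
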
+ def compute_similarity_and_loss(
+     config,
+     input_ids: mx.array,
+     embeddings: mx.array,
+     reference_embeddings: mx.array,
+     call_model: callable,
+     similarity_scores: Optional[mx.array],
+     negative_input_ids: Optional[mx.array] = None,
+     negative_attention_mask: Optional[mx.array] = None
+ ):
+     # MSE loss between computed similarities and target scores
+     if similarity_scores is not None:
+         assert reference_embeddings.shape[0] == input_ids.shape[0], "Number of references must match batch size for paired training"
+         assert similarity_scores.shape[0] == input_ids.shape[0], "Number of similarity scores must match batch size for paired training"
+         if config.use_late_interaction:
+             pairwise_sims = compute_late_interaction_scores(embeddings, reference_embeddings)
+         else:
+             # No matmul here, we only care about Query i vs Ref i
+             pairwise_sims = mx.sum(embeddings * reference_embeddings, axis=-1)
+
+         # Ensure scores match shape
+         if len(similarity_scores.shape) > 1:
+             similarity_scores = similarity_scores.flatten()
+
+         loss = nn.losses.mse_loss(pairwise_sims, similarity_scores)
+         similarities = pairwise_sims
+
+     # Cross-entropy loss [for triplet training with hard negatives]
+     else:
+         if config.use_late_interaction:
+             # Q: [B, L, D], C: [2B, L, D] (if negatives exist)
+             if negative_input_ids is not None:
+                 neg_outputs = call_model(negative_input_ids, negative_attention_mask)
+                 neg_embeddings = neg_outputs["embeddings"]
+                 candidates = mx.concatenate([reference_embeddings, neg_embeddings], axis=0)
+             else:
+                 candidates = reference_embeddings
+
+             # Manual broadcasting for late-interaction cross-batch scoring
+             Q_broad = embeddings[:, None, :, :]
+             C_broad = candidates[None, :, :, :].transpose(0, 1, 3, 2)
+
+             sim_matrix = Q_broad @ C_broad
+
+             # Max over doc length, sum over query length
+             scores = mx.sum(mx.max(sim_matrix, axis=-1), axis=-1)
+             similarities = scores  # [B, C]
+
+         else:
+             if negative_input_ids is not None:
+                 assert reference_embeddings.shape[0] == input_ids.shape[0], "Number of references must match batch size for paired training"
+                 assert negative_input_ids.shape[0] == input_ids.shape[0], "Number of negatives must match batch size for triplet training"
+                 # Embed negatives
+                 neg_outputs = call_model(
+                     input_ids=negative_input_ids,
+                     attention_mask=negative_attention_mask,
+                     return_dict=True
+                 )
+                 neg_embeddings = neg_outputs["embeddings"]
+
+                 # Stack candidates: [positives, negatives]
+                 candidates = mx.concatenate([reference_embeddings, neg_embeddings], axis=0)  # Shape: [2 * batch, hidden]
+
+             else:
+                 candidates = reference_embeddings
+
+             similarities = compute_similarity(embeddings, candidates)
+
+         scale = 20.0
+         scores = similarities * scale
+
+         labels = mx.arange(embeddings.shape[0])
+
+         loss = nn.losses.cross_entropy(scores, labels)
+
+     return similarities, loss
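
Finally, a rough sketch of driving the in-batch contrastive (cross-entropy) path with pre-computed bi-encoder embeddings. The SimpleNamespace config and the unused call_model below are illustrative stand-ins, not the package's actual training loop; with no hard negatives and similarity_scores=None, the other positives in the batch act as negatives for each query.

from types import SimpleNamespace

# Illustrative stand-in config (not the package's config class)
config = SimpleNamespace(use_late_interaction=False)

# Already-normalized query/positive embeddings for a batch of 3 (hidden = 4)
queries = normalize_embeddings(mx.random.normal(shape=(3, 4)))
positives = normalize_embeddings(mx.random.normal(shape=(3, 4)))
input_ids = mx.zeros((3, 8), dtype=mx.int32)  # only its batch dimension matters on this path

similarities, loss = compute_similarity_and_loss(
    config=config,
    input_ids=input_ids,
    embeddings=queries,
    reference_embeddings=positives,
    call_model=None,         # never called when no hard negatives are passed
    similarity_scores=None,  # None selects the cross-entropy path
)
print(similarities.shape)  # (3, 3): each query scored against every in-batch candidate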