mlx-raclate 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0
mlx_raclate/__init__.py
ADDED
@@ -0,0 +1 @@
+

mlx_raclate/models/__init__.py
ADDED
File without changes

mlx_raclate/models/base.py
ADDED
@@ -0,0 +1,225 @@
+import inspect
+from dataclasses import dataclass
+import mlx.core as mx
+import mlx.nn as nn
+from typing import Optional
+
+class RaclateBaseModel(nn.Module):
+    """Base class for Raclate models."""
+    def __init__(self):
+        super().__init__()
+
+    def get_hf_transformers_arch(self):
+        return self.hf_transformers_arch if hasattr(self, "hf_transformers_arch") else None
+
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
+        """
+        Resizes input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
+        """
+        old_embeddings = self.get_input_embeddings()
+        if old_embeddings is None:
+            raise ValueError("Model does not support get_input_embeddings")
+
+        if new_num_tokens is None:
+            return old_embeddings
+
+        old_num_tokens, old_embedding_dim = old_embeddings.weight.shape
+
+        if new_num_tokens == old_num_tokens:
+            return old_embeddings
+
+        # Create new embeddings
+        new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+
+        # Initialize new weights (e.g. Normal)
+        new_embeddings.weight = mx.random.normal(shape=(new_num_tokens, old_embedding_dim)) * 0.02
+
+        # We copy up to the min size to handle both expansion and shrinking (though usually expansion)
+        n = min(old_num_tokens, new_num_tokens)
+
+        # Combine old relevant weights with new random weights for the extension
+        combined_weight = mx.concatenate([
+            old_embeddings.weight[:n],
+            new_embeddings.weight[n:]
+        ], axis=0) if new_num_tokens > old_num_tokens else old_embeddings.weight[:n]
+
+        new_embeddings.weight = combined_weight
+
+        self.set_input_embeddings(new_embeddings)
+
+        # Update config if present
+        if hasattr(self, "config"):
+            self.config.vocab_size = new_num_tokens
+
+        return new_embeddings
+
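The example below is an editor-added usage sketch, not part of the wheel: it assumes a hypothetical ToyModel subclass that exposes get_input_embeddings/set_input_embeddings over a plain nn.Embedding, which is what resize_token_embeddings relies on (e.g. after adding special tokens to a tokenizer).

# Hypothetical subclass for illustration; real models in mlx_raclate.models
# wire these accessors to their own embedding layers.
class ToyModel(RaclateBaseModel):
    def __init__(self, vocab_size=10, dims=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dims)

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

model = ToyModel()
resized = model.resize_token_embeddings(12)  # e.g. two new special tokens
assert resized.weight.shape == (12, 8)       # old rows kept, new rows get a small random init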
+@dataclass
+class BaseModelArgs:
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
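An illustrative sketch of from_dict (ToyArgs and the dict contents are made up): it keeps only the keys that match the dataclass's constructor parameters, so an HF-style config dict with extra entries can be passed through unchanged.

@dataclass
class ToyArgs(BaseModelArgs):
    hidden_size: int = 64
    num_layers: int = 2

# "torch_dtype" is dropped because it is not a parameter of ToyArgs.__init__.
args = ToyArgs.from_dict({"hidden_size": 128, "num_layers": 4, "torch_dtype": "float32"})
assert (args.hidden_size, args.num_layers) == (128, 4)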
+def compute_similarity(query_embeddings: mx.array, reference_embeddings: mx.array) -> mx.array:
+    """Computes cosine similarity between query embeddings and reference embeddings.
+
+    Args:
+        query_embeddings: Shape [batch_size, hidden_size]
+            These are the embeddings we want to classify/compare - already normalized
+        reference_embeddings: Shape [num_references, hidden_size]
+            These are our label descriptions or comparison sentences - already normalized
+
+    Returns:
+        Similarity matrix of shape [batch_size, num_references]
+        Each row contains similarities between one query and all references
+    """
+    # Compute similarities - results in [batch_size, num_references]
+    # Each row contains similarities between one input and all references
+    similarities = mx.matmul(query_embeddings, reference_embeddings.T)
+
+    return similarities
+
+def mean_pooling(token_embeddings: mx.array, attention_mask: mx.array):
+    input_mask_expanded = mx.expand_dims(attention_mask, -1)
+    input_mask_expanded = mx.broadcast_to(
+        input_mask_expanded, token_embeddings.shape
+    ).astype(mx.float32)
+
+    sum_embeddings = mx.sum(token_embeddings * input_mask_expanded, axis=1)
+    sum_mask = mx.maximum(mx.sum(input_mask_expanded, axis=1), 1e-9)
+
+    return sum_embeddings / sum_mask
+
+
+def last_token_pooling(
+    last_hidden_states: mx.array, attention_mask: Optional[mx.array] = None
+) -> mx.array:
+    """
+    Last token pooling, compatible with MLX compilation/grad
+
+    Args:
+        last_hidden_states: Hidden states from the model, shape (batch_size, seq_len, hidden_size)
+        attention_mask: Attention mask, shape (batch_size, seq_len). If None, uses last position.
+
+    Returns:
+        Pooled embeddings, shape (batch_size, hidden_size)
+    """
+    if attention_mask is None:
+        return last_hidden_states[:, -1]
+
+    B, S, _ = last_hidden_states.shape
+    indices = mx.arange(S)
+
+    # Only keep the unpadded tokens
+    masked_indices = indices * attention_mask.astype(indices.dtype)
+
+    # Find the last valid index (max index) for each batch item
+    last_token_indices = masked_indices.max(axis=1)
+
+    batch_indices = mx.arange(B)
+
+    # Select specific [batch, token] pairs
+    return last_hidden_states[batch_indices, last_token_indices]
+
+
+def normalize_embeddings(embeddings, p=2, axis=-1, keepdims=True, eps=1e-9):
+    return embeddings / mx.maximum(
+        mx.linalg.norm(embeddings, ord=p, axis=axis, keepdims=keepdims), eps
+    )
+
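Together, the pooling and normalization helpers above form a small embedding pipeline. The sketch below is editor-added, using random tensors in place of real encoder outputs; only the shapes matter.

# Fake encoder output: batch of 2 sequences, length 5, hidden size 8.
hidden = mx.random.normal(shape=(2, 5, 8))
mask = mx.array([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

pooled = normalize_embeddings(mean_pooling(hidden, mask))    # (2, 8), unit L2 norm per row
refs = normalize_embeddings(mx.random.normal(shape=(3, 8)))  # e.g. 3 label descriptions
sims = compute_similarity(pooled, refs)                      # (2, 3) cosine similarities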
+def compute_late_interaction_scores(Q, D):
+    """
+    MaxSim: sum_i(max_j(Q_i . D_j))
+    Args:
+        Q: Query embeddings [B_q, L_q, Dim]
+        D: Doc embeddings [B_d, L_d, Dim]
+    Note: If calculating loss with in-batch negatives, shapes might vary.
+    """
+    # (B, L_q, Dim) @ (B, Dim, L_d) -> (B, L_q, L_d)
+    # This assumes pairwise (Query[i] vs Doc[i]).
+    sim_matrix = Q @ D.transpose(0, 2, 1)
+    max_scores = mx.max(sim_matrix, axis=-1) # (B, L_q)
+    return mx.sum(max_scores, axis=-1) # (B,)
+
+
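A tiny worked example of the MaxSim score (editor-added, hand-picked numbers): each query token takes its best-matching document token, and the per-token maxima are summed.

Q = mx.array([[[1.0, 0.0],
               [0.0, 1.0]]])   # 1 query, 2 tokens, dim 2
D = mx.array([[[0.9, 0.1],
               [0.0, 1.0],
               [0.5, 0.5]]])   # 1 doc, 3 tokens, dim 2

# Query token 1 best matches doc token 1 (0.9); query token 2 best matches doc token 2 (1.0).
score = compute_late_interaction_scores(Q, D)  # -> [1.9]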
+def compute_similarity_and_loss(
+    config,
+    input_ids: mx.array,
+    embeddings: mx.array,
+    reference_embeddings: mx.array,
+    call_model : callable,
+    similarity_scores: Optional[mx.array],
+    negative_input_ids: Optional[mx.array] = None,
+    negative_attention_mask: Optional[mx.array] = None
+):
+    # MSE loss between computed similarities and target scores
+    if similarity_scores is not None:
+        assert reference_embeddings.shape[0] == input_ids.shape[0], "Number of references must match batch size for paired training"
+        assert similarity_scores.shape[0] == input_ids.shape[0], "Number of similarity scores must match batch size for paired training"
+        if config.use_late_interaction:
+            pairwise_sims = compute_late_interaction_scores(embeddings, reference_embeddings)
+        else:
+            # No matmul here, we only care about Query i vs Ref i
+            pairwise_sims = mx.sum(embeddings * reference_embeddings, axis=-1)
+
+        # Ensure scores match shape
+        if len(similarity_scores.shape) > 1:
+            similarity_scores = similarity_scores.flatten()
+
+        loss = nn.losses.mse_loss(pairwise_sims, similarity_scores)
+        similarities = pairwise_sims
+
+    # Cross-entropy loss [for triplet training with hard negatives]
+    else:
+        if config.use_late_interaction:
+            # Q: [B, L, D], C: [2B, L, D] (if negatives exist)
+            if negative_input_ids is not None:
+                neg_outputs = call_model(negative_input_ids, negative_attention_mask)
+                neg_embeddings = neg_outputs["embeddings"]
+                candidates = mx.concatenate([reference_embeddings, neg_embeddings], axis=0)
+            else:
+                candidates = reference_embeddings
+
+            # Manual Broadcasting for Late Interaction cross-batch
+            Q_broad = embeddings[:, None, :, :]
+            C_broad = candidates[None, :, :, :].transpose(0, 1, 3, 2)
+
+            sim_matrix = Q_broad @ C_broad
+
+            # Max over Doc length, Sum over Query length
+            scores = mx.sum(mx.max(sim_matrix, axis=-1), axis=-1)
+            similarities = scores # [B, C]
+
+        else:
+            if negative_input_ids is not None:
+                assert reference_embeddings.shape[0] == input_ids.shape[0], "Number of references must match batch size for paired training"
+                assert negative_input_ids.shape[0] == input_ids.shape[0], "Number of negatives must match batch size for triplet training"
+                # Embed Negative
+                neg_outputs = call_model(
+                    input_ids=negative_input_ids,
+                    attention_mask=negative_attention_mask,
+                    return_dict=True
+                )
+                neg_embeddings = neg_outputs["embeddings"]
+
+                # Stack Candidates: [Positives, Negatives]
+                candidates = mx.concatenate([reference_embeddings, neg_embeddings], axis=0) # Shape: [2 * batch, hidden]
+
+            else:
+                candidates = reference_embeddings
+
+            similarities = compute_similarity(embeddings, candidates)
+
+        scale = 20.0
+        scores = similarities * scale
+
+        labels = mx.arange(embeddings.shape[0])
+
+        loss = nn.losses.cross_entropy(scores, labels)
+
+    return similarities, loss
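Finally, an editor-added sketch of the contrastive (in-batch negatives) path of compute_similarity_and_loss. The config stub, tensor shapes, and variable names are illustrative; call_model is only invoked when hard negatives are supplied, so a placeholder suffices here.

from types import SimpleNamespace

cfg = SimpleNamespace(use_late_interaction=False)

B, H = 4, 8
queries = normalize_embeddings(mx.random.normal(shape=(B, H)))    # anchor embeddings
positives = normalize_embeddings(mx.random.normal(shape=(B, H)))  # paired positive embeddings
input_ids = mx.zeros((B, 16), dtype=mx.int32)                     # only consulted by the shape assertions

similarities, loss = compute_similarity_and_loss(
    cfg,
    input_ids,
    queries,
    positives,
    call_model=None,         # not called unless negative_input_ids is given
    similarity_scores=None,  # None selects the cross-entropy (in-batch negatives) branch
)
# similarities: (4, 4); label i is the i-th reference, so the diagonal should dominate after training.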