mlx_raclate-0.1.0b1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0
mlx_raclate/models/lfm2.py

@@ -0,0 +1,671 @@
from functools import cache
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Literal

import mlx.core as mx
import mlx.nn as nn

from .base import (
    BaseModelArgs,
    last_token_pooling,
    mean_pooling,
    normalize_embeddings,
    compute_similarity_and_loss,
    RaclateBaseModel,
)

"""
No cache is used in this implementation, since the model is intended
for embedding and classification tasks.
"""

@dataclass
class ModelArgs(BaseModelArgs):
    architectures: List[str] = field(default_factory=lambda: ["Lfm2Model"])
    block_auto_adjust_ff_dim: bool = False
    block_dim: int = 1024
    block_ff_dim: int = 6656
    block_ffn_dim_multiplier: float = 1.0
    block_mlp_init_scale: Optional[float] = None
    block_multiple_of: int = 256
    block_norm_eps: float = 1e-5  # where to use this?
    block_use_swiglu: bool = True  # where to use this?
    block_use_xavier_init: bool = True  # where to use this?
    bos_token_id: int = 1
    conv_bias: bool = False
    conv_L_cache: int = 3
    conv_dim: int = 1024  # where to use this?
    conv_dim_out: int = 1024  # where to use this?
    conv_use_xavier_init: bool = True  # where to use this?
    eos_token_id: int = 7
    full_attn_idxs: Optional[List[int]] = None
    hidden_size: int = 1024
    initializer_range: Optional[float] = (
        0.02  # Only needed in case of initializing weights
    )
    layer_types: Optional[List[str]] = None
    max_position_embeddings: int = 128000
    model_type: str = "lfm2"
    norm_eps: float = 1e-05
    num_attention_heads: int = 16
    num_hidden_layers: int = 16
    num_key_value_heads: int = 8
    out_features: int = 128  # classifier output features
    pad_token_id: int = 0
    rope_theta: float = 1000000.0
    vocab_size: int = 65536

    ### pipeline args
    decoder_bias: bool = True
    classifier_dropout: float = 0.0
    classifier_bias: bool = False
    sparse_prediction: bool = True  ### True seems a more appropriate value for MLM
    sparse_pred_ignore_index: int = -100
    is_regression: Optional[bool] = None
    label2id: Optional[Dict[str, int]] = None
    id2label: Optional[Dict[int, str]] = None
    pipeline_config: Optional[Dict[str, Any]] = None  # for Sequence Classification
    use_late_interaction: bool = False

    @property
    def num_labels(self) -> int:
        """
        Number of labels is determined by:
        - For zero-shot classification: length of label_candidates
        - For regression or binary with sigmoid: 1
        - For classification: length of id2label mapping
        """
        if self.is_regression:
            return 1

        if self.pipeline_config and self.pipeline_config.get("binary_sigmoid", False):
            return 1

        if self.id2label is None:
            raise ValueError(
                "id2label mapping must be provided for categorical classification. "
                "For regression or binary classification with sigmoid output, "
                "set is_regression=True or binary_sigmoid=True in pipeline_config."
            )

        return len(self.id2label)
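
def _demo_num_labels():
    # Minimal sketch (not part of the released API) of how the num_labels
    # property above resolves for two common configurations; the label
    # values are illustrative.
    assert ModelArgs(is_regression=True).num_labels == 1
    assert ModelArgs(id2label={0: "neg", 1: "pos"}).num_labels == 2
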
def _sanitize_backbone(weights: Dict[str, Any]) -> Dict[str, Any]:
    """
    Standardizes keys for the LFM2 backbone.
    Prefixes generic keys with 'model.' and handles basic mapping.
    """
    sanitized = {}
    for k, v in weights.items():
        # Skip unrelated heads that might be in the checkpoint
        if any(x in k for x in ["lm_head", "classifier"]):
            # We don't automatically map these; specific models handle them if needed
            continue

        if "position_ids" in k:
            # Remove unused position_ids
            continue

        if "conv.weight" in k:
            # Torch depthwise kernels are (out, in/groups, L); MLX expects (out, L, in/groups)
            if v.shape[-1] > v.shape[1]:
                v = v.transpose(0, 2, 1)

        # Handle potential non-prefixed weights.
        # Not prefixing "\d+_Dense\.linear" enables further processing in
        # ModelForSentenceTransformers.
        if "Dense.linear" not in k and \
                not k.startswith("model.") and \
                not k.startswith("dense.") and \
                not k.startswith("score."):
            sanitized[f"model.{k}"] = v
        else:
            sanitized[k] = v

    return sanitized
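
def _demo_sanitize_backbone():
    # Minimal sketch of the key handling above: generic backbone keys gain a
    # "model." prefix while detached heads are dropped. The keys and shapes
    # here are illustrative, not from a real checkpoint.
    raw = {
        "embed_tokens.weight": mx.zeros((8, 4)),
        "lm_head.weight": mx.zeros((8, 4)),  # dropped
    }
    assert list(_sanitize_backbone(raw)) == ["model.embed_tokens.weight"]
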
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        self.head_dim = head_dim = args.hidden_size // n_heads

        self.scale = head_dim**-0.5

        self.q_layernorm = nn.RMSNorm(head_dim, eps=args.norm_eps)
        self.k_layernorm = nn.RMSNorm(head_dim, eps=args.norm_eps)

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        self.rope = nn.RoPE(
            self.head_dim,
            base=args.rope_theta,
            traditional=False,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = self.q_layernorm(queries.reshape(B, L, self.n_heads, -1)).transpose(
            0, 2, 1, 3
        )
        keys = self.k_layernorm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose(
            0, 2, 1, 3
        )
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        queries = self.rope(queries)
        keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out_proj(output)
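
# Shape walkthrough for Attention with the defaults above (hidden_size=1024,
# 16 query heads, 8 KV heads, head_dim=64); the grouped-query broadcast is
# handled inside mx.fast.scaled_dot_product_attention:
#   x:       (B, L, 1024)
#   queries: (B, 16, L, 64), keys/values: (B, 8, L, 64)
#   output:  (B, L, 1024) after transpose/reshape and out_proj
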
class ShortConv(nn.Module):
    def __init__(
        self,
        args: ModelArgs,
        layer_idx: int,
    ):
        super().__init__()
        self.args = args
        self.layer_idx = layer_idx
        self.L_cache = args.conv_L_cache
        self.bias = args.conv_bias

        self.conv = nn.Conv1d(
            in_channels=args.hidden_size,
            out_channels=args.hidden_size,
            kernel_size=self.L_cache,
            groups=args.hidden_size,  # depthwise
            bias=self.bias,
        )
        self.in_proj = nn.Linear(args.hidden_size, 3 * args.hidden_size, bias=self.bias)
        self.out_proj = nn.Linear(args.hidden_size, args.hidden_size, bias=self.bias)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ):
        # in_proj triples the width; split into two gates (B, C) and the value path x
        BCx = self.in_proj(x)
        B, C, x = mx.split(BCx, 3, axis=-1)
        Bx = B * x
        if mask is not None:
            Bx = mx.where(mask[..., None], Bx, 0)

        # Left-pad with L_cache - 1 zeros so the depthwise conv stays causal
        state = mx.zeros(
            (Bx.shape[0], self.L_cache - 1, self.args.hidden_size), dtype=Bx.dtype
        )

        Bx = mx.concatenate([state, Bx], axis=-2)
        conv_out = self.conv(Bx)

        y = C * conv_out
        return self.out_proj(y)
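
def _demo_shortconv_padding():
    # Minimal sketch of the causal padding arithmetic above: with a kernel of
    # L_cache = 3 and L_cache - 1 = 2 leading zeros, a valid (no-padding)
    # depthwise conv maps L + 2 frames back to L outputs. The channel count 4
    # is illustrative.
    conv = nn.Conv1d(4, 4, kernel_size=3, groups=4, bias=False)
    frames = mx.zeros((1, 5 + 2, 4))  # (batch, L + L_cache - 1, channels)
    assert conv(frames).shape == (1, 5, 4)
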
class MLP(nn.Module):
    def __init__(
        self,
        dim: int,
        ff_dim: int,
        multiple_of: int,
        auto_adjust_ff_dim: bool,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        if auto_adjust_ff_dim:
            ff_dim = int(2 * ff_dim / 3)
            if ffn_dim_multiplier is not None:
                ff_dim = int(ffn_dim_multiplier * ff_dim)
            ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, ff_dim, bias=False)
        self.w3 = nn.Linear(dim, ff_dim, bias=False)
        self.w2 = nn.Linear(ff_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        # SwiGLU: silu(w1(x)) gates w3(x), projected back down by w2
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))
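
# Worked example of the auto-adjustment above with the default block args
# (illustrative arithmetic): ff_dim = 6656 -> int(2 * 6656 / 3) = 4437,
# unchanged by the 1.0 multiplier, then rounded up to a multiple of 256:
# 256 * ((4437 + 255) // 256) = 4608.
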
class Lfm2DecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__()
        if args.full_attn_idxs:
            self.is_attention_layer = layer_idx in args.full_attn_idxs
        elif args.layer_types:
            self.is_attention_layer = args.layer_types[layer_idx] == "full_attention"
        else:
            raise ValueError("Either full_attn_idxs or layer_types must be provided in ModelArgs")

        if self.is_attention_layer:
            self.self_attn = Attention(args)
        else:
            self.conv = ShortConv(args, layer_idx)

        self.feed_forward = MLP(
            dim=args.block_dim,
            ff_dim=args.block_ff_dim,
            multiple_of=args.block_multiple_of,
            auto_adjust_ff_dim=args.block_auto_adjust_ff_dim,
            ffn_dim_multiplier=args.block_ffn_dim_multiplier,
        )

        self.operator_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)
        self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        # Pre-norm residual: token mixer (attention or short conv), then MLP
        if self.is_attention_layer:
            r = self.self_attn(self.operator_norm(x), mask=mask, cache=cache)
        else:
            r = self.conv(
                self.operator_norm(x),
                mask=mask,
                cache=cache,
            )
        h = x + r
        out = h + self.feed_forward(self.ffn_norm(h))
        return (out,)
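
def _demo_layer_selection():
    # Minimal sketch of the hybrid layer selection above; this layer_types
    # list is illustrative, not a released LFM2 configuration.
    args = ModelArgs(
        layer_types=["conv", "conv", "full_attention", "conv"],
        num_hidden_layers=4,
    )
    flags = [Lfm2DecoderLayer(args, i).is_attention_layer for i in range(4)]
    assert flags == [False, False, True, False]
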
class Lfm2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            Lfm2DecoderLayer(args, layer_idx=i) for i in range(args.num_hidden_layers)
        ]

        self.embedding_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)

        # Index of the first conv layer, whose cache slot is probed for an SSM
        # mask; both branches skip over any leading full-attention layers.
        self.conv_idx = 0
        if args.full_attn_idxs:
            for i in range(args.num_hidden_layers):
                if i in args.full_attn_idxs:
                    self.conv_idx += 1
                else:
                    break
        elif args.layer_types:
            for i in range(args.num_hidden_layers):
                if args.layer_types[i] == "full_attention":
                    self.conv_idx += 1
                else:
                    break
        else:
            raise ValueError("Either full_attn_idxs or layer_types must be provided in ModelArgs")

        self.hf_transformers_arch = "Lfm2Model"

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _update_attention_mask(self, attention_mask: mx.array, dtype=None):
        """
        Creates a causal mask and combines it with the padding mask.
        Callers always supply a padding mask (all ones when none was given).
        """
        B, L = attention_mask.shape

        causal_mask = mx.triu(mx.full((L, L), -1e9, dtype), k=1)

        # Reshape padding mask from (B, L) to (B, 1, 1, L) to be broadcastable
        padding_mask = attention_mask[:, None, None, :]
        additive_padding_mask = mx.where(padding_mask == 0, -1e9, 0.0).astype(dtype)

        causal_mask = causal_mask + additive_padding_mask

        return causal_mask.astype(dtype)

    def _create_ssm_mask(self, h, cache=None):
        if cache and hasattr(cache, "make_mask"):
            return cache.make_mask(h.shape[1])
        return None

    def __call__(
        self,
        input_ids: mx.array,
        attention_mask: Optional[mx.array] = None,
    ):
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = hidden_states.dtype

        cache = [None] * len(self.layers)

        attn_mask = self._update_attention_mask(attention_mask, dtype=model_dtype)
        conv_mask = self._create_ssm_mask(hidden_states, cache[self.conv_idx])

        for layer, c in zip(self.layers, cache):
            mask = attn_mask if layer.is_attention_layer else conv_mask
            layer_outputs = layer(hidden_states, mask, cache=c)
            hidden_states = layer_outputs[0]

        hidden_states = self.embedding_norm(hidden_states)

        return {
            "last_hidden_state": hidden_states,
        }
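
def _demo_additive_mask():
    # Minimal sketch of the additive mask built above: one length-3 sequence
    # whose last position is padding. The causal upper triangle and the padded
    # key column both receive -1e9.
    causal = mx.triu(mx.full((3, 3), -1e9, mx.float32), k=1)
    pad = mx.array([[1, 1, 0]])[:, None, None, :]
    mask = causal + mx.where(pad == 0, -1e9, 0.0).astype(mx.float32)
    assert mask.shape == (1, 1, 3, 3)
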
class Model(RaclateBaseModel):
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.model = Lfm2Model(config)

    def __call__(
        self,
        input_ids: mx.array,
        position_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ):
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones(
                (batch_size, seq_len),
                dtype=self.model.embed_tokens.weight.dtype,
            )

        out = self.model(input_ids, attention_mask)

        last_hidden_state = (
            out["last_hidden_state"] if isinstance(out, dict) else out[0]
        )

        # LFM2 is a causal model, so we use last token pooling for embeddings
        text_embeds = last_token_pooling(last_hidden_state, attention_mask)
        text_embeds = normalize_embeddings(text_embeds)

        if not return_dict:
            return (text_embeds, last_hidden_state)

        return {
            "embeddings": text_embeds,  # normalized embeddings
            "last_hidden_state": last_hidden_state,
        }

    def sanitize(self, weights):
        sanitized_weights = _sanitize_backbone(weights)
        sanitized = {}
        for k, v in sanitized_weights.items():
            if not k.startswith("model."):
                continue
            sanitized[k] = v
        return sanitized
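
def _demo_embedding_model():
    # Minimal smoke test for the embedding wrapper above, with random weights
    # and an illustrative four-layer hybrid config (not a released checkpoint).
    args = ModelArgs(
        layer_types=["conv", "conv", "full_attention", "conv"],
        num_hidden_layers=4,
    )
    model = Model(args)
    out = model(mx.array([[1, 42, 7]]))  # (batch=1, seq_len=3) token ids
    assert out["embeddings"].shape == (1, args.hidden_size)
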
class ModelForSentenceSimilarity(RaclateBaseModel):
    """
    Computes similarity scores between input sequences and reference sentences.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.model = Lfm2Model(config)
        self.dense = [
            nn.Linear(config.block_dim, config.out_features, bias=False),
        ]

    def _call_model(self, input_ids, attention_mask=None, return_dict=True):
        out = self.model(input_ids, attention_mask)
        last_hidden_state = (
            out["last_hidden_state"] if isinstance(out, dict) else out[0]
        )

        for dense in self.dense:
            last_hidden_state = dense(last_hidden_state)

        if self.config.use_late_interaction:
            # Keep unpooled token embeddings for ColBERT-style scoring;
            # mask padding tokens to avoid them affecting MaxSim
            text_embeds = normalize_embeddings(last_hidden_state)
            if attention_mask is not None:
                text_embeds = text_embeds * attention_mask[..., None]
        else:
            # Standard dense retrieval: mean pooling
            text_embeds = mean_pooling(last_hidden_state, attention_mask)
            text_embeds = normalize_embeddings(text_embeds)

        if not return_dict:
            return (text_embeds, last_hidden_state)

        return {
            "embeddings": text_embeds,  # normalized embeddings
            "last_hidden_state": last_hidden_state,
        }

    def __call__(
        self,
        input_ids,
        reference_input_ids: Optional[mx.array] = None,  # Shape: [num_references, seq_len]
        negative_input_ids: Optional[mx.array] = None,  # Shape: [num_negatives, seq_len]
        attention_mask: Optional[mx.array] = None,
        reference_attention_mask: Optional[mx.array] = None,
        negative_attention_mask: Optional[mx.array] = None,
        similarity_scores: Optional[mx.array] = None,  # Shape: [batch_size, num_references]
        position_ids: Optional[mx.array] = None,
        return_dict: Optional[bool] = True,
    ):
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones(
                (batch_size, seq_len),
                dtype=self.model.embed_tokens.weight.dtype,
            )

        # Get embeddings for input batch
        batch_outputs = self._call_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        embeddings = batch_outputs["embeddings"]  # [batch_size, hidden_size]

        loss = None
        similarities = None
        if reference_input_ids is not None:
            # Get embeddings for reference sentences
            ref_outputs = self._call_model(
                input_ids=reference_input_ids,
                attention_mask=reference_attention_mask,
                return_dict=True,
            )
            reference_embeddings = ref_outputs["embeddings"]  # [num_references, hidden_size]

            similarities, loss = compute_similarity_and_loss(
                self.config,
                input_ids,
                embeddings,
                reference_embeddings,
                self._call_model,
                similarity_scores,
                negative_input_ids,
                negative_attention_mask,
            )

        if not return_dict:
            return (loss, similarities, embeddings)

        return {
            "loss": loss,
            "similarities": similarities,  # [batch_size, num_references]
            "embeddings": embeddings,  # [batch_size, hidden_size]
        }

    def sanitize(self, weights):
        sanitized_weights = _sanitize_backbone(weights)
        sanitized = {}
        for k, v in sanitized_weights.items():
            if not k.startswith("model.") and not k.startswith("dense."):
                continue
            sanitized[k] = v
        return sanitized
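
def maxsim_score(q: mx.array, d: mx.array) -> mx.array:
    # Minimal sketch (not part of the released API) of ColBERT-style MaxSim
    # over the unpooled, normalized token embeddings produced when
    # use_late_interaction is set; q is (Lq, dim) and d is (Ld, dim), both
    # L2-normalized and padding-masked. For each query token take its
    # best-matching document token, then sum.
    return (q @ d.T).max(axis=-1).sum()
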
class ModelForSentenceTransformers(ModelForSentenceSimilarity):
    """
    Extends ModelForSentenceSimilarity to provide embeddings for input sequences.
    This class sanitizes typical sentence-transformers weights to align with the LFM2 model.
    """
    def __init__(self, config: ModelArgs):
        super().__init__(config)

    def sanitize(self, weights):
        """Convert sentence-transformers weights to the LFM2 format."""
        sanitized = _sanitize_backbone(weights)

        sanitized_weights = {}
        for k, v in sanitized.items():
            if "1_Dense.linear" in k:
                new_key = k.replace("1_Dense.linear", "dense.0")
                sanitized_weights[new_key] = v
            elif k.startswith("model.") or k.startswith("dense."):
                sanitized_weights[k] = v
            else:
                continue
        return sanitized_weights
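
def _demo_st_key_mapping():
    # Minimal sketch of the sentence-transformers key mapping above; the key
    # and weight shape are illustrative.
    st = ModelForSentenceTransformers(
        ModelArgs(layer_types=["conv"], num_hidden_layers=1)
    )
    mapped = st.sanitize({"1_Dense.linear.weight": mx.zeros((128, 1024))})
    assert list(mapped) == ["dense.0.weight"]
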
class ModelForSequenceClassification(RaclateBaseModel):
    """
    Computes sequence classification probabilities for input sequences.

    NOTE: regression and binary classification not tested.
    """
    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.num_labels = config.num_labels
        self.is_regression = config.is_regression

        self.model = Lfm2Model(config)

        # The HF SequenceClassification architectures typically add only a score layer
        self.score = nn.Linear(
            config.hidden_size,
            config.num_labels,
            bias=False,
        )

        # There is no HF transformers SequenceClassification architecture for LFM2

    def _process_outputs(self, logits: mx.array) -> mx.array:
        """Apply the appropriate activation function to the logits."""
        if self.is_regression:
            return logits  # No activation for regression
        elif self.num_labels == 1:
            return mx.sigmoid(logits)  # Binary classification
        else:
            # Using softmax for multi-class classification
            return mx.softmax(logits, axis=-1)

    def _compute_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
        """Compute the appropriate loss based on label characteristics."""
        if self.is_regression:
            return nn.losses.mse_loss(logits.squeeze(), labels.squeeze())
        elif self.num_labels == 1:
            return nn.losses.binary_cross_entropy(mx.sigmoid(logits), labels)
        else:
            return nn.losses.cross_entropy(
                logits.reshape(-1, self.num_labels),
                labels.reshape(-1),
            )

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,  ### need this?
        labels: Optional[mx.array] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Dict:
        if attention_mask is None:
            batch_size, seq_len = input_ids.shape
            attention_mask = mx.ones(
                (batch_size, seq_len),
                dtype=self.model.embed_tokens.weight.dtype,
            )

        outputs = self.model(
            input_ids,
            attention_mask,
        )
        last_hidden_state = (
            outputs["last_hidden_state"] if isinstance(outputs, dict) else outputs[0]
        )

        # Pooling for AR models such as LFM2 leverages the last token
        pooled = last_token_pooling(last_hidden_state, attention_mask)

        ### The HF architecture for SequenceClassification typically only has a score layer
        logits = self.score(pooled)

        processed_logits = self._process_outputs(logits)

        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels)

        if not return_dict:
            # outputs is a dict, so return the hidden state directly
            return (loss, processed_logits, last_hidden_state)

        return {
            "loss": loss,
            "probabilities": processed_logits,
            "hidden_states": outputs.get("hidden_states", None),
        }

    def sanitize(self, weights):
        sanitized_weights = _sanitize_backbone(weights)
        sanitized = {}
        for k, v in sanitized_weights.items():
            if not k.startswith("model.") and not k.startswith("score."):
                continue
            sanitized[k] = v
        return sanitized
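
def _demo_sequence_classification():
    # Minimal usage sketch for the classifier above, with random weights and
    # an illustrative two-layer hybrid config: softmax probabilities plus a
    # per-sample cross-entropy loss when labels are provided.
    cfg = ModelArgs(
        layer_types=["conv", "full_attention"],
        num_hidden_layers=2,
        id2label={0: "neg", 1: "pos"},
    )
    clf = ModelForSequenceClassification(cfg)
    res = clf(mx.array([[1, 42, 7]]), labels=mx.array([1]))
    assert res["probabilities"].shape == (1, 2)
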

# TokenClassification and MaskedLM are not implemented for now for AR models
# such as LFM2. Attempting to train the pretrained weights on those objectives
# would be catastrophic.