mlx-raclate 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,671 @@
1
+ from functools import cache
2
+ import re
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List, Optional, Any, Literal
5
+
6
+ import mlx.core as mx
7
+ import mlx.nn as nn
8
+
9
+ from .base import (
10
+ BaseModelArgs,
11
+ last_token_pooling,
12
+ mean_pooling,
13
+ normalize_embeddings,
14
+ compute_similarity_and_loss,
15
+ RaclateBaseModel,
16
+ )
17
+
18
+ """
19
+ No KV cache is used in this implementation, since the model is intended
20
+ for embedding and classification tasks rather than autoregressive generation.
21
+ """
22
+
23
+ @dataclass
24
+ class ModelArgs(BaseModelArgs):
25
+ architectures: List[str] = field(default_factory=lambda: ["Lfm2Model"])
26
+ block_auto_adjust_ff_dim: bool = False
27
+ block_dim: int = 1024
28
+ block_ff_dim: int = 6656
29
+ block_ffn_dim_multiplier: float = 1.0
30
+ block_mlp_init_scale: Optional[float] = None
31
+ block_multiple_of: int = 256
32
+ block_norm_eps: float = 1e-5 # where to use this?
33
+ block_use_swiglu: bool = True # where to use this?
34
+ block_use_xavier_init: bool = True # where to use this?
35
+ bos_token_id: int = 1
36
+ conv_bias: bool = False
37
+ conv_L_cache: int = 3
38
+ conv_dim: int = 1024 # where to use this?
39
+ conv_dim_out: int = 1024 # where to use this?
40
+ conv_use_xavier_init: bool = True # where to use this?
41
+ eos_token_id: int = 7
42
+ full_attn_idxs: Optional[List[int]] = None
43
+ hidden_size: int = 1024
44
+ initializer_range: Optional[float] = (
45
+ 0.02 # Only needed in case of initializing weights
46
+ )
47
+ layer_types: Optional[List[str]] = None
48
+ max_position_embeddings: int = 128000
49
+ model_type: str = "lfm2"
50
+ norm_eps: float = 1e-05
51
+ num_attention_heads: int = 16
52
+ num_hidden_layers: int = 16
53
+ num_key_value_heads: int = 8
54
+ out_features: int = 128 # output dim of the sentence-similarity dense head (see ModelForSentenceSimilarity)
55
+ pad_token_id: int = 0
56
+ rope_theta: float = 1000000.0
57
+ vocab_size: int = 65536
58
+
59
+ ### pipeline args
60
+ decoder_bias: bool = True
61
+ classifier_dropout: float = 0.0
62
+ classifier_bias: bool = False
63
+ sparse_prediction: bool = True ### True seems a more appropriate value for MLM
64
+ sparse_pred_ignore_index: int = -100
65
+ is_regression: Optional[bool] = None
66
+ label2id: Optional[Dict[str, int]] = None
67
+ id2label: Optional[Dict[int, str]] = None
68
+ pipeline_config: Optional[Dict[str, Any]] = None # for Sequence Classification
69
+ use_late_interaction: bool = False
70
+
71
+ @property
72
+ def num_labels(self) -> int:
73
+ """
74
+ Number of labels is determined by:
75
+ - For zero-shot classification: length of label_candidates
76
+ - For regression or binary with sigmoid: 1
77
+ - For classification: length of id2label mapping
78
+ """
79
+
80
+ if self.is_regression:
81
+ return 1
82
+
83
+ if self.pipeline_config and self.pipeline_config.get("binary_sigmoid", False):
84
+ return 1
85
+
86
+ if self.id2label is None:
87
+ raise ValueError(
88
+ "id2label mapping must be provided for categorical classification. "
89
+ "For regression or binary classification with sigmoid output, "
90
+ "set is_regression=True or binary_sigmoid=True in pipeline_config."
91
+ )
92
+
93
+ return len(self.id2label)
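+
+ # e.g. ModelArgs(id2label={0: "neg", 1: "pos"}).num_labels -> 2
+ # ModelArgs(is_regression=True).num_labels -> 1
+ # ModelArgs(pipeline_config={"binary_sigmoid": True}).num_labels -> 1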
94
+
95
+
96
+ def _sanitize_backbone(weights: Dict[str, Any]) -> Dict[str, Any]:
97
+ """
98
+ Standardizes keys for the LFM2 backbone.
99
+ Prefixes generic keys with 'model.' and handles basic mapping.
100
+ """
101
+ sanitized = {}
102
+ for k, v in weights.items():
103
+ # Skip unrelated heads that might be in the checkpoint
104
+ if any(x in k for x in ["lm_head", "classifier"]):
105
+ # We don't automatically map these; specific models handle them if needed
106
+ continue
107
+
108
+ if "position_ids" in k:
109
+ # Remove unused position_ids
110
+ continue
111
+
112
+ if "conv.weight" in k:
113
+ if v.shape[-1] > v.shape[1]:
114
+ v = v.transpose(0, 2, 1)
115
+
116
+ # Handle potential non-prefixed weights
117
+ # not prefixing keys matching "\d+_Dense\.linear" enables further processing in ModelForSentenceTransformers
118
+ if "Dense.linear" not in k and \
119
+ not k.startswith("model.") and \
120
+ not k.startswith("dense.") and \
121
+ not k.startswith("score.") :
122
+
123
+ new_key = f"model.{k}"
124
+
125
+ sanitized[new_key] = v
126
+ else:
127
+ sanitized[k] = v
128
+
129
+ return sanitized
130
+
131
+ class Attention(nn.Module):
132
+ def __init__(self, args: ModelArgs):
133
+ super().__init__()
134
+
135
+ dim = args.hidden_size
136
+ self.n_heads = n_heads = args.num_attention_heads
137
+ self.n_kv_heads = n_kv_heads = args.num_key_value_heads
138
+
139
+ self.head_dim = head_dim = args.hidden_size // n_heads
140
+
141
+ self.scale = head_dim**-0.5
142
+
143
+ self.q_layernorm = nn.RMSNorm(head_dim, eps=args.norm_eps)
144
+ self.k_layernorm = nn.RMSNorm(head_dim, eps=args.norm_eps)
145
+
146
+ self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
147
+ self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
148
+ self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
149
+ self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
150
+
151
+ self.rope = nn.RoPE(
152
+ self.head_dim,
153
+ base=args.rope_theta,
154
+ traditional=False,
155
+ )
156
+
157
+ def __call__(
158
+ self,
159
+ x: mx.array,
160
+ mask: Optional[mx.array] = None,
161
+ cache: Optional[Any] = None
162
+ ) -> mx.array:
163
+ B, L, D = x.shape
164
+
165
+ queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
166
+
167
+ queries = self.q_layernorm(queries.reshape(B, L, self.n_heads, -1)).transpose(
168
+ 0, 2, 1, 3
169
+ )
170
+ keys = self.k_layernorm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose(
171
+ 0, 2, 1, 3
172
+ )
173
+ values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
174
+
175
+ queries = self.rope(queries)
176
+ keys = self.rope(keys)
177
+
178
+ output = mx.fast.scaled_dot_product_attention(
179
+ queries, keys, values, scale=self.scale, mask=mask
180
+ )
181
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
182
+ return self.out_proj(output)
183
+
184
+ class ShortConv(nn.Module):
185
+ def __init__(
186
+ self,
187
+ args: ModelArgs,
188
+ layer_idx: int,
189
+ ):
190
+ super().__init__()
191
+ self.args = args
192
+ self.layer_idx = layer_idx
193
+ self.L_cache = args.conv_L_cache
194
+ self.bias = args.conv_bias
195
+
196
+ self.conv = nn.Conv1d(
197
+ in_channels=args.hidden_size,
198
+ out_channels=args.hidden_size,
199
+ kernel_size=self.L_cache,
200
+ groups=args.hidden_size,
201
+ bias=self.bias,
202
+ )
203
+ self.in_proj = nn.Linear(args.hidden_size, 3 * args.hidden_size, bias=self.bias)
204
+ self.out_proj = nn.Linear(args.hidden_size, args.hidden_size, bias=self.bias)
205
+
206
+ def __call__(
207
+ self,
208
+ x: mx.array,
209
+ mask: Optional[mx.array] = None,
210
+ cache: Optional[Any] = None
211
+ ):
212
+ BCx = self.in_proj(x)
213
+ B, C, x = mx.split(BCx, 3, axis=-1)
214
+ Bx = B * x
215
+ if mask is not None:
216
+ Bx = mx.where(mask[..., None], Bx, 0)
217
+
218
+ state = mx.zeros(
219
+ (Bx.shape[0], self.L_cache - 1, self.args.hidden_size), dtype=Bx.dtype
220
+ )
221
+
222
+ Bx = mx.concatenate([state, Bx], axis=-2)
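+ # Left-padding with L_cache - 1 zeros keeps the depthwise conv causal and
+ # yields an output with the same length as the input (kernel_size == L_cache).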
223
+ conv_out = self.conv(Bx)
224
+
225
+ y = C * conv_out
226
+ return self.out_proj(y)
227
+
228
+
229
+ class MLP(nn.Module):
230
+ def __init__(
231
+ self,
232
+ dim: int,
233
+ ff_dim: int,
234
+ multiple_of: int,
235
+ auto_adjust_ff_dim: bool,
236
+ ffn_dim_multiplier: Optional[float],
237
+ ):
238
+ super().__init__()
239
+ if auto_adjust_ff_dim:
240
+ ff_dim = int(2 * ff_dim / 3)
241
+ if ffn_dim_multiplier is not None:
242
+ ff_dim = int(ffn_dim_multiplier * ff_dim)
243
+ ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
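+ # e.g. with auto_adjust_ff_dim enabled and ff_dim=6656, ffn_dim_multiplier=1.0,
+ # multiple_of=256: int(2 * 6656 / 3) = 4437, rounded up to 18 * 256 = 4608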
244
+
245
+ self.w1 = nn.Linear(dim, ff_dim, bias=False)
246
+ self.w3 = nn.Linear(dim, ff_dim, bias=False)
247
+ self.w2 = nn.Linear(ff_dim, dim, bias=False)
248
+
249
+ def __call__(self, x) -> mx.array:
250
+ return self.w2(nn.silu(self.w1(x)) * self.w3(x))
251
+
252
+
253
+ class Lfm2DecoderLayer(nn.Module):
254
+ def __init__(self, args: ModelArgs, layer_idx: int):
255
+ super().__init__()
256
+ if args.full_attn_idxs:
257
+ self.is_attention_layer = layer_idx in args.full_attn_idxs
258
+ elif args.layer_types:
259
+ self.is_attention_layer = args.layer_types[layer_idx] == "full_attention"
260
+ else:
261
+ raise ValueError("Either full_attn_idxs or layer_types must be provided in ModelArgs")
262
+
263
+ if self.is_attention_layer:
264
+ self.self_attn = Attention(args)
265
+ else:
266
+ self.conv = ShortConv(args, layer_idx)
267
+
268
+ self.feed_forward = MLP(
269
+ dim=args.block_dim,
270
+ ff_dim=args.block_ff_dim,
271
+ multiple_of=args.block_multiple_of,
272
+ auto_adjust_ff_dim=args.block_auto_adjust_ff_dim,
273
+ ffn_dim_multiplier=args.block_ffn_dim_multiplier,
274
+ )
275
+
276
+ self.operator_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)
277
+ self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)
278
+
279
+ def __call__(
280
+ self,
281
+ x: mx.array,
282
+ mask: Optional[mx.array] = None,
283
+ cache: Optional[Any] = None,
284
+ ) -> mx.array:
285
+
286
+ if self.is_attention_layer:
287
+ r = self.self_attn(self.operator_norm(x), mask=mask, cache=cache)
288
+ else:
289
+ r = self.conv(
290
+ self.operator_norm(x),
291
+ mask=mask,
292
+ cache=cache,
293
+ )
294
+ h = x + r
295
+ out = h + self.feed_forward(self.ffn_norm(h))
296
+ return (out,)
297
+
298
+ class Lfm2Model(nn.Module):
299
+ def __init__(self, args: ModelArgs):
300
+ super().__init__()
301
+ self.args = args
302
+ self.vocab_size = args.vocab_size
303
+ self.num_hidden_layers = args.num_hidden_layers
304
+ self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
305
+ self.layers = [
306
+ Lfm2DecoderLayer(args, layer_idx=i) for i in range(args.num_hidden_layers)
307
+ ]
308
+
309
+ self.embedding_norm = nn.RMSNorm(args.hidden_size, eps=args.norm_eps)
310
+
311
+ self.conv_idx = 0 # index of the first conv (ShortConv) layer; used to pick the cache for the SSM mask
312
+ if args.full_attn_idxs:
313
+ for i in range(args.num_hidden_layers):
314
+ if i in args.full_attn_idxs:
315
+ self.conv_idx += 1
316
+ else:
317
+ break
318
+ elif args.layer_types:
319
+ for i in range(args.num_hidden_layers):
320
+ if args.layer_types[i] == "full_attention":
321
+ self.conv_idx += 1
322
+ else:
323
+ break
324
+ else:
325
+ raise ValueError("Either full_attn_idxs or layer_types must be provided in ModelArgs")
326
+
327
+ self.hf_transformers_arch = "Lfm2Model"
328
+
329
+ def get_input_embeddings(self) -> nn.Embedding:
330
+ return self.embed_tokens
331
+
332
+ def set_input_embeddings(self, value):
333
+ self.embed_tokens = value
334
+
335
+ def _update_attention_mask(self, attention_mask: mx.array, dtype=None):
336
+ """
337
+ Creates an additive causal mask and combines it with the padding mask.
338
+ The padding mask is required here; callers build an all-ones mask when none is given.
339
+ """
340
+
341
+ B, L = attention_mask.shape
342
+
343
+ causal_mask = mx.triu(mx.full((L, L), -1e9, dtype), k=1)
344
+
345
+ # Reshape padding mask from (B, L) to (B, 1, 1, L) to be broadcastable
346
+ padding_mask = attention_mask[:, None, None, :]
347
+ additive_padding_mask = mx.where(padding_mask == 0, -1e9, 0.0).astype(dtype)
348
+
349
+ causal_mask = causal_mask + additive_padding_mask
350
+
351
+ return causal_mask.astype(dtype)
352
+
353
+ def _create_ssm_mask(self, h, cache=None):
354
+ if cache and hasattr(cache, "make_mask"):
355
+ return cache.make_mask(h.shape[1])
356
+ return None
357
+
358
+ def __call__(
359
+ self,
360
+ input_ids: mx.array,
361
+ attention_mask: Optional[mx.array] = None,
362
+ ):
363
+
364
+ hidden_states = self.embed_tokens(input_ids)
365
+ model_dtype = hidden_states.dtype
366
+
367
+ cache = [None] * len(self.layers)
368
+
369
+ attn_mask = self._update_attention_mask(attention_mask, dtype=model_dtype)
370
+ conv_mask = self._create_ssm_mask(hidden_states, cache[self.conv_idx])
371
+
372
+ for layer, c in zip(self.layers, cache):
373
+ mask = attn_mask if layer.is_attention_layer else conv_mask
374
+ layer_outputs = layer(hidden_states, mask, cache=c)
375
+ hidden_states = layer_outputs[0]
376
+
377
+ hidden_states = self.embedding_norm(hidden_states)
378
+
379
+ return {
380
+ "last_hidden_state": hidden_states,
381
+ }
382
+
383
+
384
+ class Model(RaclateBaseModel):
385
+ def __init__(self, config: ModelArgs):
386
+ super().__init__()
387
+ self.config = config
388
+ self.model_type = config.model_type
389
+ self.model = Lfm2Model(config)
390
+
391
+ def __call__(
392
+ self,
393
+ input_ids: mx.array,
394
+ position_ids: Optional[mx.array] = None,
395
+ attention_mask: Optional[mx.array] = None,
396
+ output_hidden_states: Optional[bool] = False,
397
+ return_dict: Optional[bool] = True,
398
+ ):
399
+
400
+ if attention_mask is None:
401
+ batch_size, seq_len = input_ids.shape
402
+ attention_mask = mx.ones(
403
+ (batch_size, seq_len),
404
+ dtype=self.model.embed_tokens.weight.dtype,
405
+ )
406
+
407
+ out = self.model(input_ids, attention_mask)
408
+
409
+ last_hidden_state = (
410
+ out["last_hidden_state"] if isinstance(out, dict) else out[0]
411
+ )
412
+
413
+ # LFM2 is a causal model, so we use last token pooling for embeddings
414
+ text_embeds = last_token_pooling(last_hidden_state, attention_mask)
415
+ text_embeds = normalize_embeddings(text_embeds)
416
+
417
+ if not return_dict:
418
+ return (text_embeds, last_hidden_state)
419
+
420
+ return {
421
+ "embeddings": text_embeds, # normalized embeddings
422
+ "last_hidden_state": last_hidden_state,
423
+ }
424
+
425
+
426
+ def sanitize(self, weights):
427
+ sanitized_weights = _sanitize_backbone(weights)
428
+ sanitized = {}
429
+ for k, v in sanitized_weights.items():
430
+ if not k.startswith("model."):
431
+ continue
432
+ sanitized[k] = v
433
+ return sanitized
434
+
435
+
436
+ class ModelForSentenceSimilarity(RaclateBaseModel):
437
+ """
438
+ Computes similarity scores between input sequences and reference sentences.
439
+ """
440
+ def __init__(self, config: ModelArgs):
441
+ super().__init__()
442
+ self.config = config
443
+ self.model_type = config.model_type
444
+ self.model = Lfm2Model(config)
445
+ self.dense = [
446
+ nn.Linear(config.block_dim, config.out_features, bias=False),
447
+ ]
448
+
449
+ def _call_model(self, input_ids, attention_mask=None, return_dict=True):
450
+ out = self.model(input_ids, attention_mask)
451
+ last_hidden_state = (
452
+ out["last_hidden_state"] if isinstance(out, dict) else out[0]
453
+ )
454
+
455
+ for dense in self.dense:
456
+ last_hidden_state = dense(last_hidden_state)
457
+
458
+ # text_embeds = normalize_embeddings(last_hidden_state)
459
+ if self.config.use_late_interaction:
460
+ text_embeds = normalize_embeddings(last_hidden_state)
461
+ # Keep unpooled for ColBERT style
462
+ # Mask padding tokens to avoid them affecting MaxSim
463
+ if attention_mask is not None:
464
+ text_embeds = text_embeds * attention_mask[..., None]
465
+ else:
466
+ # Standard dense retrieval: Mean Pooling
467
+ text_embeds = mean_pooling(last_hidden_state, attention_mask)
468
+ text_embeds = normalize_embeddings(text_embeds)
469
+
470
+
471
+ if not return_dict:
472
+ return (text_embeds, last_hidden_state)
473
+
474
+ return {
475
+ "embeddings": text_embeds, # normalized embeddings
476
+ "last_hidden_state": last_hidden_state,
477
+ }
478
+
479
+ def __call__(
480
+ self,
481
+ input_ids,
482
+ reference_input_ids : Optional[mx.array] = None, # Shape: [num_references, seq_len]
483
+ negative_input_ids : Optional[mx.array] = None, # Shape: [num_negatives, seq_len]
484
+ attention_mask: Optional[mx.array] = None,
485
+ reference_attention_mask: Optional[mx.array] = None,
486
+ negative_attention_mask: Optional[mx.array] = None,
487
+ similarity_scores: Optional[mx.array] = None, # Shape: [batch_size, num_references]
488
+ position_ids: Optional[mx.array] = None,
489
+ return_dict: Optional[bool] = True,
490
+ ):
491
+ if attention_mask is None:
492
+ batch_size, seq_len = input_ids.shape
493
+ attention_mask = mx.ones(
494
+ (batch_size, seq_len),
495
+ dtype=self.model.embed_tokens.weight.dtype,
496
+ )
497
+
498
+ # Get embeddings for input batch
499
+ batch_outputs = self._call_model(
500
+ input_ids=input_ids,
501
+ attention_mask=attention_mask,
502
+ return_dict=True
503
+ )
504
+ embeddings = batch_outputs["embeddings"] # [batch_size, hidden_size]
505
+
506
+ loss = None
507
+ similarities = None
508
+ if reference_input_ids is not None:
509
+
510
+ # Get embeddings for reference sentences
511
+ ref_outputs = self._call_model(
512
+ input_ids=reference_input_ids,
513
+ attention_mask=reference_attention_mask,
514
+ return_dict=True
515
+ )
516
+ reference_embeddings = ref_outputs["embeddings"] # [num_references, hidden_size]
517
+
518
+ similarities, loss = compute_similarity_and_loss(
519
+ self.config,
520
+ input_ids,
521
+ embeddings,
522
+ reference_embeddings,
523
+ self._call_model,
524
+ similarity_scores,
525
+ negative_input_ids,
526
+ negative_attention_mask
527
+ )
528
+
529
+ if not return_dict:
530
+ return (loss, similarities, embeddings)
531
+
532
+ return {
533
+ "loss": loss,
534
+ "similarities": similarities, # [batch_size, num_references]
535
+ "embeddings": embeddings, # [batch_size, hidden_size]
536
+ }
537
+
538
+ def sanitize(self, weights):
539
+ sanitized_weights = _sanitize_backbone(weights)
540
+ sanitized = {}
541
+ for k, v in sanitized_weights.items():
542
+ if not k.startswith("model.") and not k.startswith("dense."):
543
+ continue
544
+ sanitized[k] = v
545
+ return sanitized
546
+
547
+ class ModelForSentenceTransformers(ModelForSentenceSimilarity):
548
+ """
549
+ Extends ModelForSentenceSimilarity to provide embeddings for input sequences.
550
+ This class sanitizes typical sentence-transformers weights to align with the LFM2 model.
551
+ """
552
+ def __init__(self, config: ModelArgs):
553
+ super().__init__(config)
554
+
555
+ def sanitize(self, weights):
556
+ """Convert sentence transformer weights to T5Gemma format."""
557
+ sanitized = _sanitize_backbone(weights)
558
+
559
+ sanitized_weights = {}
560
+ for k, v in sanitized.items():
561
+ if "1_Dense.linear" in k:
562
+ new_key = k.replace("1_Dense.linear", "dense.0")
563
+ sanitized_weights[new_key] = v
564
+ elif k.startswith("model.") or k.startswith("dense."):
565
+ sanitized_weights[k] = v
566
+ else:
567
+ continue
568
+ return sanitized_weights
569
+
570
+ class ModelForSequenceClassification(RaclateBaseModel):
571
+ """
572
+ Computes sequence classification probabilities for input sequences.
573
+
574
+ NOTE : regression and binary classification not tested.
575
+ """
576
+ def __init__(self, config: ModelArgs):
577
+ super().__init__()
578
+ self.config = config
579
+ self.num_labels = config.num_labels
580
+ self.is_regression = config.is_regression
581
+
582
+ self.model = Lfm2Model(config)
583
+
584
+ # The usual HF SequenceClassification architecture only adds a single score (Linear) layer on top of the backbone
585
+ self.score = nn.Linear(
586
+ config.hidden_size,
587
+ config.num_labels,
588
+ bias=False
589
+ )
590
+
591
+ # HF transformers does not define a SequenceClassification architecture for LFM2
592
+
593
+ def _process_outputs(self, logits: mx.array) -> mx.array:
594
+ """Apply the appropriate activation function to the logits."""
595
+ if self.is_regression:
596
+ return logits # No activation for regression
597
+ elif self.num_labels == 1:
598
+ return mx.sigmoid(logits) # Binary classification
599
+ else:
600
+ # Using softmax for multi-class classification
601
+ return mx.softmax(logits, axis=-1)
602
+
603
+ def _compute_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
604
+ """Compute the appropriate loss based on label characteristics."""
605
+ if self.is_regression:
606
+ return nn.losses.mse_loss(logits.squeeze(), labels.squeeze())
607
+ elif self.num_labels == 1:
608
+ return nn.losses.binary_cross_entropy(logits, labels) # expects raw logits (with_logits=True by default)
609
+ else:
610
+ return nn.losses.cross_entropy(
611
+ logits.reshape(-1, self.num_labels),
612
+ labels.reshape(-1)
613
+ )
614
+
615
+ def __call__(
616
+ self,
617
+ input_ids,
618
+ attention_mask: Optional[mx.array] = None,
619
+ position_ids: Optional[mx.array] = None, ### need this?
620
+ labels: Optional[mx.array] = None,
621
+ output_hidden_states: Optional[bool] = False,
622
+ return_dict: Optional[bool] = True,
623
+ ) -> Dict:
624
+ if attention_mask is None:
625
+ batch_size, seq_len = input_ids.shape
626
+ attention_mask = mx.ones(
627
+ (batch_size, seq_len),
628
+ dtype=self.model.embed_tokens.weight.dtype,
629
+ )
630
+
631
+ outputs = self.model(
632
+ input_ids,
633
+ attention_mask
634
+ )
635
+ last_hidden_state = (
636
+ outputs["last_hidden_state"] if isinstance(outputs, dict) else outputs[0]
637
+ )
638
+
639
+ # pooling for AR models such as LFM2 leverages the last token
640
+ pooled = last_token_pooling(last_hidden_state, attention_mask)
641
+
642
+ ### The HF architecture for SequenceClassification typically only has a score layer
643
+ logits = self.score(pooled)
644
+
645
+ processed_logits = self._process_outputs(logits)
646
+
647
+ loss = None
648
+ if labels is not None :
649
+ loss = self._compute_loss(logits, labels)
650
+
651
+ if not return_dict:
652
+ return (loss, processed_logits, outputs.get("hidden_states", None))
653
+
654
+ return {
655
+ "loss": loss,
656
+ "probabilities": processed_logits,
657
+ "hidden_states": outputs.get("hidden_states", None),
658
+ }
659
+
660
+ def sanitize(self, weights):
661
+ sanitized_weights = _sanitize_backbone(weights)
662
+ sanitized = {}
663
+ for k, v in sanitized_weights.items():
664
+ if not k.startswith("model.") and not k.startswith("score."):
665
+ continue
666
+ sanitized[k] = v
667
+ return sanitized
668
+
669
+
670
+ # TokenClassification and MaskedLM are not implemented for now for AR models such as LFM2.
671
+ # Attempting to train the pretrained weights on those objectives would be catastrophic.
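
As a rough, illustrative usage sketch (not part of the packaged file): the embedding `Model` above could be exercised as below. The import path `mlx_raclate.models.lfm2` is an assumption, the 4-layer `layer_types` layout is a toy configuration rather than the shipped LFM2 layout, weights are randomly initialized, and the pooling helpers imported from `.base` are assumed to behave as their names suggest.

```python
import mlx.core as mx

# Assumed import path; adjust to wherever this module lives inside mlx-raclate.
from mlx_raclate.models.lfm2 import ModelArgs, Model

# Toy configuration: 4 small layers, attention only in the last one (illustrative,
# not the shipped LFM2 layout); real use would load a converted checkpoint instead.
args = ModelArgs(
    hidden_size=64,
    block_dim=64,
    block_ff_dim=128,
    num_hidden_layers=4,
    num_attention_heads=4,
    num_key_value_heads=2,
    vocab_size=1000,
    layer_types=["conv", "conv", "conv", "full_attention"],
)
model = Model(args)  # randomly initialized weights

input_ids = mx.array([[5, 17, 42, 7]])      # (batch=1, seq_len=4) token ids
attention_mask = mx.ones(input_ids.shape)   # no padding in this toy batch
out = model(input_ids, attention_mask=attention_mask)
print(out["embeddings"].shape)              # (1, 64): normalized last-token embedding
```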