mlx-raclate 0.1.0b1__py3-none-any.whl

@@ -0,0 +1,857 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional, Union, Any, Literal
3
+ import re
4
+ from functools import partial
5
+
6
+ import mlx.core as mx
7
+ import mlx.nn as nn
8
+
9
+ from .base import (
10
+ BaseModelArgs,
11
+ mean_pooling,
12
+ normalize_embeddings,
13
+ compute_similarity_and_loss,
14
+ RaclateBaseModel,
15
+ )
16
+
17
+ @dataclass
18
+ class ModelArgs(BaseModelArgs):
19
+ architectures: List[str] = field(default_factory=lambda: ["T5GemmaForConditionalGeneration"])
20
+ attention_bias: Optional[bool] = False
21
+ attention_dropout: Optional[float] = 0.0
22
+ attn_logit_softcapping: Optional[float] = None # Not supported with sdpa
23
+ bos_token_id: Optional[int] = None
24
+ dropout_rate: float = 0.0
25
+ eos_token_id: Optional[List[int]] = None
26
+ final_logit_softcapping: Optional[float] = None # Not supported with sdpa
27
+ head_dim: int = 64
28
+ hidden_activation: Optional[str] = "gelu_pytorch_tanh"
29
+ hidden_size: int = 768
30
+ initializer_range: Optional[float] = 0.02 # Only needed in case of initializing weights
31
+ intermediate_size: int = 2048
32
+ is_causal: bool = False
33
+ is_encoder_decoder: bool = False
34
+ layer_types: List[str] = field(default_factory=list)
35
+ max_position_embeddings: int = 8192
36
+ model_type: str = "t5gemma"
37
+ num_attention_heads: int = 12
38
+ num_hidden_layers: int = 12
39
+ num_key_value_heads: int = 12
40
+ query_pre_attn_scalar: float = 64
41
+ rms_norm_eps: float = 1.0e-6
42
+ rope_theta: float = 10000.0
43
+ rope_traditional: bool = False
44
+ sliding_window: int = 4096
45
+ vocab_size: int = 256000
46
+
47
+
48
+ ### pipeline args
49
+ decoder_bias: bool = True
50
+ classifier_dropout_rate: float = 0.0
51
+ classifier_bias: bool = False
52
+ norm_bias: bool = False
53
+ norm_eps: float = 1e-05
54
+ sparse_prediction: bool = True  # restrict the MLM loss to masked tokens
55
+ sparse_pred_ignore_index: int = -100
56
+ is_regression: Optional[bool] = None
57
+ label2id: Optional[Dict[str, int]] = None
58
+ id2label: Optional[Dict[int, str]] = None
59
+ pipeline_config: Optional[Dict[str, Any]] = None # for Sequence Classification
60
+ use_late_interaction: bool = False
61
+
62
+ @property
63
+ def num_labels(self) -> int:
64
+ """
65
+ Number of labels is determined by:
66
+ - For regression or binary classification with sigmoid output: 1
67
+ - For categorical classification: length of the id2label mapping
69
+ """
70
+
71
+ if self.is_regression:
72
+ return 1
73
+
74
+ if self.pipeline_config and self.pipeline_config.get("binary_sigmoid", False):
75
+ return 1
76
+
77
+ if self.id2label is None:
78
+ raise ValueError(
79
+ "id2label mapping must be provided for categorical classification. "
80
+ "For regression or binary classification with sigmoid output, "
81
+ "set is_regression=True or binary_sigmoid=True in pipeline_config."
82
+ )
83
+
84
+ return len(self.id2label)
85
+
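+ # Illustration of the resolution order above (hypothetical values, not part of the
+ # original module): a config carrying an id2label mapping with three entries yields
+ # num_labels == 3, while is_regression=True or binary_sigmoid=True short-circuits to 1.
+ #
+ #     args = ModelArgs(id2label={0: "neg", 1: "neu", 2: "pos"},
+ #                      label2id={"neg": 0, "neu": 1, "pos": 2})
+ #     assert args.num_labels == 3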
86
+ def _sanitize_backbone(weights: Dict[str, Any]) -> Dict[str, Any]:
87
+ """
88
+ Standardizes keys for the T5Gemma Encoder Backbone.
89
+ Drops Decoder weights and handles basic mapping.
90
+ """
91
+ sanitized_weights = {}
92
+ for k, v in weights.items():
93
+ # Skip unused buffers or position IDs if present
94
+ if "position_ids" in k or "rotary_emb.inv_freq" in k:
95
+ continue
96
+
97
+ # Skip decoder weights
98
+ if k.startswith("model.decoder."):
99
+ continue
100
+
101
+ # if the model uses shared embeddings, map them correctly
102
+ if k == "shared.weight":
103
+ if "encoder.embed_tokens.weight" not in weights:
104
+ sanitized_weights["model.embed_tokens.weight"] = v
105
+ continue
106
+
107
+ # Process Encoder weights
108
+ if k.startswith("model.encoder."):
109
+ new_k = k.replace("model.encoder.", "model.")
110
+ # Prefix bare keys with "model.", but leave "score" (classification)
111
+ # and "head" / "decoder" (masked LM) keys untouched
112
+ elif not k.startswith("model.") and not k.startswith("score.") and not k.startswith("head.") and not k.startswith("decoder."):
113
+ new_k = f"model.{k}"
114
+ else:
115
+ new_k = k
116
+
117
+ sanitized_weights[new_k] = v
118
+
119
+ return sanitized_weights
120
+
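+ # Sketch of the key remapping performed above (key names are illustrative):
+ #   "model.decoder.layers.0.self_attn.q_proj.weight" -> dropped
+ #   "model.encoder.layers.0.self_attn.q_proj.weight" -> "model.layers.0.self_attn.q_proj.weight"
+ #   "shared.weight"                                  -> "model.embed_tokens.weight" (only if no encoder embedding is present)
+ #   "score.out_proj.weight" / "head.dense.weight"    -> kept unchanged for the task heads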
121
+ class Attention(nn.Module):
122
+ def __init__(self, args: ModelArgs, layer_idx: int):
123
+ super().__init__()
124
+
125
+ dim = args.hidden_size
126
+ self.n_heads = n_heads = args.num_attention_heads
127
+ self.n_kv_heads = n_kv_heads = args.num_key_value_heads
128
+ self.repeats = n_heads // n_kv_heads
129
+ self.head_dim = head_dim = args.head_dim
130
+ self.layer_idx = layer_idx
131
+
132
+ self.scale = args.query_pre_attn_scalar**-0.5
133
+
134
+ self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
135
+ self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
136
+ self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
137
+ self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
138
+
139
+ layer_type = args.layer_types[layer_idx] if args.layer_types else None
140
+ self.is_sliding = layer_type == "sliding_window"
141
+
142
+ base = args.rope_theta
145
+
146
+ self.rope = nn.RoPE(
147
+ head_dim,
148
+ traditional=args.rope_traditional,
149
+ base=base,
150
+ )
151
+
152
+ # softcapping support
153
+ self.attn_logit_softcapping = args.attn_logit_softcapping
154
+
155
+
156
+ def __call__(
157
+ self,
158
+ x: mx.array,
159
+ mask: Optional[mx.array] = None
160
+ ) -> mx.array:
161
+ B, L, _ = x.shape
162
+ queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
163
+ queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
164
+ keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
165
+ values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
166
+
167
+ queries = self.rope(queries)
168
+ keys = self.rope(keys)
169
+
170
+ if self.attn_logit_softcapping is None:
171
+ output = mx.fast.scaled_dot_product_attention(
172
+ queries, keys, values, scale=self.scale, mask=mask
173
+ )
174
+ else:
175
+ queries = queries * self.scale
176
+
177
+ attn_weights = mx.matmul(queries, keys.transpose(0, 1, 3, 2))
178
+ cap = self.attn_logit_softcapping
179
+ attn_weights = mx.tanh(attn_weights / cap) * cap
180
+
181
+ if mask is not None:
182
+ attn_weights = attn_weights + mask
183
+
184
+ attn_weights = mx.softmax(attn_weights, axis=-1)
185
+ output = mx.matmul(attn_weights, values)
186
+
187
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
188
+ return self.o_proj(output)
189
+
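+ # Soft-capping note: cap * tanh(logits / cap) squashes attention logits into (-cap, cap).
+ # With cap = 50 (an illustrative value), a raw logit of 200 maps to ~49.97 while a logit
+ # of 1.0 stays at ~1.0, which is why the fused mx.fast.scaled_dot_product_attention path
+ # is only taken when no cap is configured.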
190
+ class RMSNorm(nn.Module):
191
+ def __init__(self, dims: int, eps: float = 1e-5):
192
+ super().__init__()
193
+ self.weight = mx.ones((dims,))
194
+ self.eps = eps
195
+
196
+ def __call__(self, x):
197
+ return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
198
+
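+ # Note: Gemma-family checkpoints store the RMSNorm scale as an offset from one, so the
+ # effective scale passed to mx.fast.rms_norm above is (1.0 + weight).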
199
+ class MLP(nn.Module):
200
+ def __init__(self, dim, hidden_dim):
201
+ super().__init__()
202
+ self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
203
+ self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
204
+ self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
205
+
206
+ def __call__(self, x) -> mx.array:
207
+ return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
208
+
209
+ class TransformerBlock(nn.Module):
210
+ def __init__(self, args: ModelArgs, layer_idx: int):
211
+ super().__init__()
212
+ self.num_attention_heads = args.num_attention_heads
213
+ self.hidden_size = args.hidden_size
214
+ self.self_attn = Attention(args, layer_idx)
215
+ self.mlp = MLP(args.hidden_size, args.intermediate_size)
216
+ self.pre_self_attn_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
217
+ self.post_self_attn_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
218
+ self.pre_feedforward_layernorm = RMSNorm(
219
+ args.hidden_size, eps=args.rms_norm_eps
220
+ )
221
+ self.post_feedforward_layernorm = RMSNorm(
222
+ args.hidden_size, eps=args.rms_norm_eps
223
+ )
224
+ self.dropout = nn.Dropout(args.dropout_rate)
225
+
226
+ def __call__(
227
+ self,
228
+ x: mx.array,
229
+ mask: Optional[mx.array] = None
230
+ ) -> mx.array:
231
+ r = x
232
+ h = self.self_attn(self.pre_self_attn_layernorm(x), mask)
233
+ h = self.post_self_attn_layernorm(h)
234
+ h = r + self.dropout(h)
235
+ r = h
236
+ h = self.mlp(self.pre_feedforward_layernorm(h))
237
+ out = self.post_feedforward_layernorm(h)
238
+ out = r + self.dropout(out)
239
+ return (out,)
240
+
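+ # Note: the block uses Gemma2-style "sandwich" normalization: RMSNorm is applied both
+ # before and after each sub-layer (attention and MLP), and the residual is added after
+ # the post-norm and dropout.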
241
+ class T5GemmaEncoder(nn.Module):
242
+ def __init__(self, config: ModelArgs):
243
+ super().__init__()
244
+ self.config = config
245
+ self.vocab_size = config.vocab_size
246
+ self.num_hidden_layers = config.num_hidden_layers
247
+ assert self.vocab_size > 0
248
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
249
+ self.layers = [
250
+ TransformerBlock(config, layer_idx=i)
251
+ for i in range(config.num_hidden_layers)
252
+ ]
253
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
254
+ self.dropout = nn.Dropout(config.dropout_rate)
255
+
256
+ def get_input_embeddings(self) -> nn.Embedding:
257
+ return self.embed_tokens
258
+
259
+ def set_input_embeddings(self, value):
260
+ self.embed_tokens = value
261
+
262
+ def _update_attention_mask(self, attention_mask: Optional[mx.array] = None, dtype=None):
263
+ """
264
+ Builds additive global and sliding-window attention masks and combines them with the padding mask.
265
+ """
266
+
267
+ B, L = attention_mask.shape
268
+ window_size = self.config.sliding_window
269
+ indices = mx.arange(L)
270
+ row = indices[:, None]
271
+ col = row.T
272
+
273
+ if not self.config.is_causal:
274
+ mask_base = mx.zeros((L, L), dtype=mx.bool_) # All False (visible)
275
+
276
+ # Sliding Window Logic for Bidirectional:
277
+ # Valid if abs(row - col) < window
278
+ # Mask if distance >= window
279
+ dist = mx.abs(row - col)
280
+ mask_window_violator = dist >= window_size
281
+
282
+ # In practice the T5Gemma encoder is always bidirectional; the causal branch
283
+ # below is kept only for consistency with other models.
284
+ else:
285
+ # Causal: Standard triangular mask
286
+ mask_future = col > row
287
+ mask_base = mask_future
288
+
289
+ # Sliding Window Logic for Causal:
290
+ # Valid if row - col < window (and not future)
291
+ # Mask if (row - col) >= window
292
+ mask_past = (row - col) >= window_size
293
+ mask_window_violator = mask_past
294
+
295
+ global_mask = mx.where(mask_base, -1e9, 0.0).astype(dtype)
296
+ sliding_mask_bool = mask_base | mask_window_violator
297
+ sliding_mask = mx.where(sliding_mask_bool, -1e9, 0.0).astype(dtype)
298
+
299
+ # Padding Mask
300
+ if attention_mask is not None:
301
+ # Reshape padding mask from (B, L) to (B, 1, 1, L) to be broadcastable
302
+ padding_mask = attention_mask[:, None, None, :]
303
+ additive_padding = mx.where(padding_mask == 0, -1e9, 0.0).astype(dtype)
304
+
305
+ global_mask = global_mask + additive_padding
306
+ sliding_mask = sliding_mask + additive_padding
307
+
308
+ return global_mask, sliding_mask
309
+
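+ # Illustrative shapes (hypothetical sizes): with attention_mask of shape (2, 6) and
+ # sliding_window = 4, both returned masks are additive float masks broadcastable to
+ # (2, 1, 6, 6). global_mask only carries -1e9 on padded key positions, while
+ # sliding_mask additionally carries -1e9 wherever |row - col| >= 4 (bidirectional case).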
310
+ def __call__(
311
+ self,
312
+ input_ids: mx.array,
313
+ attention_mask: Optional[mx.array] = None,
314
+ output_hidden_states: Optional[bool] = False,
315
+ position_ids: Optional[mx.array] = None,
316
+ return_dict: Optional[bool] = True
317
+ ):
318
+ hidden_states = self.embed_tokens(input_ids)
319
+ model_dtype = hidden_states.dtype
320
+
321
+ # Gemma-style embedding scaling by sqrt(hidden_size)
322
+ hidden_states *= mx.array(self.config.hidden_size**0.5, model_dtype)
323
+
324
+ hidden_states = self.dropout(hidden_states)
325
+
326
+ global_mask, sliding_window_mask = self._update_attention_mask(
327
+ attention_mask,
328
+ dtype=model_dtype
329
+ )
330
+
331
+ for i, layer in enumerate(self.layers):
332
+ layer_types = self.config.layer_types
+ is_global = (not layer_types) or layer_types[i] == "full_attention"
333
+ layer_mask = global_mask if is_global else sliding_window_mask
334
+ layer_outputs = layer(hidden_states, layer_mask)
335
+ hidden_states = layer_outputs[0]
336
+
337
+ hidden_states = self.norm(hidden_states)
338
+ hidden_states = self.dropout(hidden_states)
339
+
340
+ return {
341
+ "last_hidden_state": hidden_states,
342
+ }
343
+
344
+
345
+ class T5GemmaClassificationHead(nn.Module):
346
+ def __init__(self, config: ModelArgs):
347
+ super().__init__()
348
+ self.config = config
349
+ self.dropout = nn.Dropout(p=config.classifier_dropout_rate)
350
+ self.out_proj = nn.Linear(
351
+ config.hidden_size, config.num_labels
352
+ )
353
+ self.soft_cap = config.final_logit_softcapping
354
+
355
+ def __call__(self, hidden_states: mx.array) -> mx.array:
356
+ logits = self.out_proj(self.dropout(hidden_states))
357
+ if self.soft_cap is not None:
358
+ logits = mx.tanh(logits / self.soft_cap) * self.soft_cap
359
+ return logits
360
+
361
+ class T5GemmaPredictionHead(nn.Module):
362
+ def __init__(self, config: ModelArgs):
363
+ super().__init__()
364
+ self.config = config
365
+ self.dense = nn.Linear(
366
+ config.hidden_size, config.hidden_size, bias=False
367
+ )
368
+ self.act = nn.GELU()
369
+ self.layer_norm = nn.LayerNorm(config.hidden_size, bias=config.norm_bias, eps=config.norm_eps)
370
+
371
+
372
+ def __call__(self, hidden_states: mx.array) -> mx.array:
373
+ return self.layer_norm(self.act(self.dense(hidden_states)))
374
+
375
+
376
+ class Model(RaclateBaseModel):
377
+ def __init__(self, config: ModelArgs):
378
+ super().__init__()
379
+ self.config = config
380
+ self.model_type = config.model_type
381
+ self.model = T5GemmaEncoder(config)
382
+
383
+ # transformer architecture name for compatibility
384
+ self.hf_transformers_arch = ""
385
+
386
+ def __call__(
387
+ self,
388
+ input_ids: mx.array,
389
+ position_ids: Optional[mx.array] = None,
390
+ attention_mask: Optional[mx.array] = None,
391
+ output_hidden_states: Optional[bool] = False,
392
+ return_dict: Optional[bool] = True,
393
+ ):
394
+
395
+ if attention_mask is None:
396
+ batch_size, seq_len = input_ids.shape
397
+ attention_mask = mx.ones(
398
+ (batch_size, seq_len),
399
+ dtype=self.model.embed_tokens.weight.dtype,
400
+ )
401
+
402
+ out = self.model(input_ids, attention_mask)
403
+ last_hidden_state = (
404
+ out["last_hidden_state"] if isinstance(out, dict) else out[0]
405
+ )
406
+
407
+ # normalized features
408
+ text_embeds = mean_pooling(last_hidden_state, attention_mask)
409
+ text_embeds = normalize_embeddings(text_embeds)
410
+
411
+ if not return_dict:
412
+ return (text_embeds, last_hidden_state)
413
+
414
+ return {
415
+ "embeddings": text_embeds, # normalized embeddings
416
+ "last_hidden_state": last_hidden_state,
417
+ }
418
+
419
+ def sanitize(self, weights):
420
+ sanitized = _sanitize_backbone(weights)
421
+
422
+ sanitized_weights = {}
423
+ for k, v in sanitized.items():
424
+ if not k.startswith("model."):
425
+ continue
426
+ sanitized_weights[k] = v
427
+
428
+ return sanitized_weights
429
+
430
+
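+ # Minimal usage sketch (hedged: `hf_weights`, `input_ids` and a checkpoint-derived config
+ # with populated layer_types are assumed; loading utilities are not part of this module):
+ #
+ #     model = Model(config)
+ #     model.load_weights(list(model.sanitize(hf_weights).items()))
+ #     out = model(input_ids)             # input_ids: mx.array of shape (B, L)
+ #     sentence_vecs = out["embeddings"]  # mean-pooled, L2-normalized, shape (B, hidden_size)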
431
+ class ModelForSentenceSimilarity(RaclateBaseModel):
432
+ """
433
+ Computes similarity scores between input sequences and reference sentences.
434
+ """
435
+ def __init__(self, config: ModelArgs):
436
+ super().__init__()
437
+ self.config = config
438
+ self.model_type = config.model_type
439
+ self.model = T5GemmaEncoder(config)
440
+
441
+ def _call_model(
442
+ self,
443
+ input_ids: mx.array,
444
+ position_ids: Optional[mx.array] = None,
445
+ attention_mask: Optional[mx.array] = None,
446
+ output_hidden_states: Optional[bool] = False,
447
+ return_dict: Optional[bool] = True,
448
+ ):
449
+ out = self.model(input_ids, attention_mask)
450
+ last_hidden_state = (
451
+ out["last_hidden_state"] if isinstance(out, dict) else out[0]
452
+ )
453
+
454
+ # text_embeds = normalize_embeddings(last_hidden_state)
455
+ if self.config.use_late_interaction:
456
+ text_embeds = normalize_embeddings(last_hidden_state)
457
+ # Keep unpooled for ColBERT style
458
+ # Mask padding tokens to avoid them affecting MaxSim
459
+ if attention_mask is not None:
460
+ text_embeds = text_embeds * attention_mask[..., None]
461
+ else:
462
+ # Standard dense retrieval: Mean Pooling
463
+ text_embeds = mean_pooling(last_hidden_state, attention_mask)
464
+ text_embeds = normalize_embeddings(text_embeds)
465
+
466
+ if not return_dict:
467
+ return (text_embeds, last_hidden_state)
468
+
469
+ return {
470
+ "embeddings": text_embeds, # normalized embeddings
471
+ "last_hidden_state": last_hidden_state,
472
+ }
473
+
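+ # Hedged sketch of ColBERT-style MaxSim scoring over the unpooled, normalized token
+ # embeddings returned when use_late_interaction=True (the helper below is illustrative,
+ # not part of this module):
+ #
+ #     def maxsim_score(query_emb: mx.array, doc_emb: mx.array) -> mx.array:
+ #         # query_emb: (Lq, D), doc_emb: (Ld, D); padded rows are already zeroed out
+ #         sim = query_emb @ doc_emb.T            # (Lq, Ld) cosine similarities
+ #         return mx.sum(mx.max(sim, axis=-1))    # best doc token per query token, summed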
474
+ def __call__(
475
+ self,
476
+ input_ids,
477
+ reference_input_ids : Optional[mx.array] = None, # Shape: [num_references, seq_len]
478
+ negative_input_ids : Optional[mx.array] = None, # Shape: [num_negatives, seq_len]
479
+ attention_mask: Optional[mx.array] = None,
480
+ reference_attention_mask: Optional[mx.array] = None,
481
+ negative_attention_mask: Optional[mx.array] = None,
482
+ similarity_scores: Optional[mx.array] = None, # Shape: [batch_size, num_references]
483
+ position_ids: Optional[mx.array] = None,
484
+ return_dict: Optional[bool] = True,
485
+ ):
486
+ if attention_mask is None:
487
+ batch_size, seq_len = input_ids.shape
488
+ attention_mask = mx.ones(
489
+ (batch_size, seq_len),
490
+ dtype=self.model.embed_tokens.weight.dtype,
491
+ )
492
+ # Get embeddings for input batch
493
+ batch_outputs = self._call_model(
494
+ input_ids=input_ids,
495
+ attention_mask=attention_mask,
496
+ position_ids=position_ids,
497
+ return_dict=True
498
+ )
499
+ embeddings = batch_outputs["embeddings"] # [batch_size, hidden_size]
500
+
501
+ loss = None
502
+ similarities = None
503
+ if reference_input_ids is not None:
504
+
505
+ # Get embeddings for reference sentences
506
+ ref_outputs = self._call_model(
507
+ input_ids=reference_input_ids,
508
+ attention_mask=reference_attention_mask,
509
+ position_ids=position_ids,  # accepted for API compatibility; unused by the encoder
510
+ return_dict=True
511
+ )
512
+ reference_embeddings = ref_outputs["embeddings"] # [num_references, hidden_size]
513
+
514
+ similarities, loss = compute_similarity_and_loss(
515
+ self.config,
516
+ input_ids,
517
+ embeddings,
518
+ reference_embeddings,
519
+ self._call_model,
520
+ similarity_scores,
521
+ negative_input_ids,
522
+ negative_attention_mask,
523
+ )
524
+
525
+ if not return_dict:
526
+ return (loss, similarities, embeddings)
527
+
528
+ return {
529
+ "loss": loss,
530
+ "similarities": similarities, # [batch_size, num_references]
531
+ "embeddings": embeddings, # [batch_size, hidden_size]
532
+ }
533
+
534
+ def sanitize(self, weights):
535
+ sanitized = _sanitize_backbone(weights)
536
+
537
+ sanitized_weights = {}
538
+ for k, v in sanitized.items():
539
+ if not k.startswith("model."):
540
+ continue
541
+ sanitized_weights[k] = v
542
+
543
+ return sanitized_weights
544
+
545
+ class ModelForSentenceTransformers(ModelForSentenceSimilarity):
546
+ """
547
+ Extends ModelForSentenceSimilarity to provide embeddings for input sequences.
548
+ This class sanitizes typical sentence transformers weights to align with the T5Gemma model.
549
+ """
550
+ def __init__(self, config: ModelArgs):
551
+ super().__init__(config)
552
+
553
+
554
+ class ModelForSequenceClassification(RaclateBaseModel):
555
+ """
556
+ Computes sequence classification probabilities for input sequences.
557
+ Sanitization aligns the weights with HF's T5GemmaForSequenceClassification architecture.
558
+
559
+ NOTE: regression and binary classification are untested.
560
+ """
561
+ def __init__(self, config: ModelArgs):
562
+ super().__init__()
563
+ self.config = config
564
+ self.num_labels = config.num_labels
565
+ self.is_regression = config.is_regression
566
+
567
+ self.model = T5GemmaEncoder(config)
568
+
569
+ ### The HF T5GemmaForSequenceClassification architecture has no separate
570
+ ### prediction head or dropout before the classifier
571
+ ### and uses 'score' as the final layer name
572
+ # self.head = T5GemmaPredictionHead(config)
573
+ # self.drop = nn.Dropout(p=config.classifier_dropout_rate)
574
+
575
+ self.score = T5GemmaClassificationHead(config)
576
+
577
+ self.hf_transformers_arch = "T5GemmaForSequenceClassification"
578
+
579
+ def _process_outputs(self, logits: mx.array) -> mx.array:
580
+ """Apply the appropriate activation function to the logits."""
581
+ if self.is_regression:
582
+ return logits # No activation for regression
583
+ elif self.num_labels == 1:
584
+ return mx.sigmoid(logits) # Binary classification
585
+ else:
586
+ # Using softmax for multi-class classification
587
+ return mx.softmax(logits, axis=-1)
588
+
589
+ def _compute_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
590
+ """Compute the appropriate loss based on label characteristics."""
591
+ if self.is_regression:
592
+ return nn.losses.mse_loss(logits.squeeze(), labels.squeeze())
593
+ elif self.num_labels == 1:
594
+ # binary_cross_entropy expects raw logits with the default with_logits=True
+ return nn.losses.binary_cross_entropy(logits.squeeze(-1), labels, with_logits=True)
595
+ else:
596
+ return nn.losses.cross_entropy(
597
+ logits.reshape(-1, self.num_labels),
598
+ labels.reshape(-1),
+ reduction="mean"
599
+ )
600
+
601
+ def __call__(
602
+ self,
603
+ input_ids,
604
+ attention_mask: Optional[mx.array] = None,
605
+ position_ids: Optional[mx.array] = None,  # accepted for API compatibility; unused by the encoder
606
+ labels: Optional[mx.array] = None,
607
+ output_hidden_states: Optional[bool] = False,
608
+ return_dict: Optional[bool] = True,
609
+ ) -> Dict:
610
+ if attention_mask is None:
611
+ batch_size, seq_len = input_ids.shape
612
+ attention_mask = mx.ones(
613
+ (batch_size, seq_len),
614
+ dtype=self.model.embed_tokens.weight.dtype,
615
+ )
616
+
617
+ outputs = self.model(
618
+ input_ids,
619
+ attention_mask,
620
+ position_ids=position_ids,
621
+ output_hidden_states=output_hidden_states,
622
+ return_dict=return_dict
623
+ )
624
+
625
+ last_hidden_state = (
626
+ outputs["last_hidden_state"] if isinstance(outputs, dict) else outputs[0]
627
+ )
628
+
629
+ # mean-pooled sequence representation
630
+ text_embeds = mean_pooling(last_hidden_state, attention_mask)
631
+
632
+ # classification head, named 'score' to match T5GemmaForSequenceClassification
633
+ logits = self.score(text_embeds)
634
+
635
+ processed_logits = self._process_outputs(logits)
636
+
637
+ loss = None
638
+ if labels is not None:
639
+ loss = self._compute_loss(logits, labels)
640
+
641
+ if not return_dict:
642
+ return [loss, processed_logits, outputs[1:]]
643
+
644
+ return {
645
+ "loss": loss,
646
+ "probabilities": processed_logits,
647
+ "hidden_states": outputs.get("hidden_states", None),
648
+ }
649
+
650
+ def sanitize(self, weights):
651
+
652
+ sanitized = _sanitize_backbone(weights)
653
+
654
+ sanitized_weights = {}
655
+ for k, v in sanitized.items():
656
+ if not k.startswith("model.") and not k.startswith("score."):
657
+ continue
658
+ sanitized_weights[k] = v
659
+
660
+ return sanitized_weights
661
+
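+ # Hedged usage sketch (config values illustrative; a checkpoint-derived config with
+ # populated layer_types is assumed):
+ #
+ #     args = ModelArgs(id2label={0: "negative", 1: "positive"},
+ #                      label2id={"negative": 0, "positive": 1})
+ #     clf = ModelForSequenceClassification(args)
+ #     probs = clf(input_ids, attention_mask=attention_mask)["probabilities"]  # shape (B, 2)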
662
+ class ModelForMaskedLM(RaclateBaseModel):
663
+ """
664
+ Computes masked language modeling (MLM) loss for input sequences.
665
+ """
666
+ def __init__(self, config: ModelArgs):
667
+ super().__init__()
668
+ self.config = config
669
+ if config.is_causal:
670
+ raise ValueError("ModelForMaskedLM requires bidirectional attention.")
671
+ self.model = T5GemmaEncoder(config)
672
+ self.head = T5GemmaPredictionHead(config)
673
+ self.decoder = nn.Linear(
674
+ config.hidden_size, config.vocab_size, bias=config.decoder_bias
675
+ )
676
+
677
+ # transformers has no MaskedLM class for T5Gemma
678
+
679
+ # Call tie_weights() so the decoder starts out sharing the embedding weights;
680
+ # weights loaded afterwards overwrite this unless sanitize() supplies a tied decoder.weight.
681
+ self.tie_weights()
682
+
683
+ def tie_weights(self):
684
+ self.decoder.weight = self.model.embed_tokens.weight
685
+
686
+ def get_input_embeddings(self):
687
+ return self.model.get_input_embeddings()
688
+
689
+ def get_output_embeddings(self):
690
+ return self.decoder
691
+
692
+ def set_input_embeddings(self, value):
693
+ self.model.set_input_embeddings(value)
694
+ self.tie_weights() # Re-tie weights after setting new embeddings
695
+
696
+ def set_output_embeddings(self, new_embeddings):
697
+ self.decoder = new_embeddings
698
+ self.tie_weights() # Re-tie weights after setting new decoder
699
+
700
+ def __call__(
701
+ self,
702
+ input_ids,
703
+ attention_mask: Optional[mx.array] = None,
704
+ labels: Optional[mx.array] = None,
705
+ position_ids: Optional[mx.array] = None,
706
+ output_hidden_states: Optional[bool] = None,
707
+ return_dict: Optional[bool] = True,
708
+ ) -> Dict:
709
+
710
+ if attention_mask is None:
711
+ batch_size, seq_len = input_ids.shape
712
+ attention_mask = mx.ones((batch_size, seq_len)) ### updated via _update_attention_mask() in the model
713
+
714
+ outputs = self.model(
715
+ input_ids=input_ids,
716
+ attention_mask=attention_mask,
717
+ position_ids=position_ids,
718
+ output_hidden_states=output_hidden_states,
719
+ return_dict=return_dict,
720
+ )
721
+
722
+ last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]
723
+ logits = self.head(last_hidden_state)
724
+ logits = self.decoder(logits)
725
+
726
+ loss = None
727
+ if self.training and labels is not None:
728
+ if getattr(self.config, "sparse_prediction", False):
729
+ # Flatten labels and predictions
730
+ flat_labels = labels.reshape(-1)
731
+ flat_predictions = logits.reshape(-1, logits.shape[-1])
732
+
733
+ # Filter out non-masked tokens
734
+ ignore_index = getattr(self.config, "sparse_pred_ignore_index", -100)
735
+ mask_tokens = flat_labels != ignore_index
736
+
737
+ # Only compute loss on masked tokens
738
+ masked_predictions = flat_predictions[mask_tokens]
739
+ masked_labels = flat_labels[mask_tokens]
740
+
741
+ loss = nn.losses.cross_entropy(
742
+ masked_predictions,
743
+ masked_labels,
744
+ reduction='mean'
745
+ )
746
+ else:
747
+ # Standard loss computation on all tokens
748
+ loss = nn.losses.cross_entropy(
749
+ logits.reshape(-1, logits.shape[-1]),
750
+ labels.reshape(-1),
751
+ reduction='mean'
752
+ )
753
+
754
+ if not return_dict:
755
+ return [loss, logits, outputs[1:]]
756
+
757
+ return {
758
+ "loss": loss,
759
+ "logits": logits,
760
+ "hidden_states": outputs.get("hidden_states", None),
761
+ }
762
+
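+ # Hedged sketch of label construction for the sparse-prediction path above (names are
+ # illustrative): labels mirror input_ids but hold sparse_pred_ignore_index (-100)
+ # everywhere except the masked positions, so cross-entropy only sees masked tokens.
+ #
+ #     labels = mx.where(masked_positions, original_input_ids, mx.array(-100))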
763
+ def sanitize(self, weights):
764
+
765
+ sanitized_weights = _sanitize_backbone(weights)
766
+
767
+ # Specific adjustments for MLM
768
+ final_weights = {}
769
+ for k, v in sanitized_weights.items():
770
+ if not k.startswith("model.") and not k.startswith("head.") and not k.startswith("decoder."):
771
+ continue
772
+
773
+ # Handle Weight Tying for loading:
774
+ if k == "model.embed_tokens.weight" and "decoder.weight" not in weights:
775
+ final_weights["decoder.weight"] = v
776
+
777
+ final_weights[k] = v
778
+
779
+ return final_weights
780
+
781
+
782
+ class ModelForTokenClassification(RaclateBaseModel):
783
+ """
784
+ Computes token classification probabilities for input sequences.
785
+
786
+ NOTE: untested for now
787
+ """
788
+ def __init__(self, config: ModelArgs):
789
+ super().__init__()
790
+ self.config = config
791
+ if config.is_causal:
792
+ raise ValueError("ModelForTokenClassification requires bidirectional attention.")
793
+ self.num_labels = config.num_labels
794
+
795
+ self.model = T5GemmaEncoder(config)
796
+ self.score = T5GemmaClassificationHead(config)
797
+
798
+ self.hf_transformers_arch = "T5GemmaForTokenClassification"
799
+
800
+ def __call__(
801
+ self,
802
+ input_ids,
803
+ attention_mask: Optional[mx.array] = None,
804
+ position_ids: Optional[mx.array] = None,
805
+ labels: Optional[mx.array] = None,
806
+ output_hidden_states: Optional[bool] = None,
807
+ return_dict: Optional[bool] = True,
808
+ ) -> Dict:
809
+ if attention_mask is None:
810
+ batch_size, seq_len = input_ids.shape
811
+ attention_mask = mx.ones((batch_size, seq_len))
812
+
813
+ outputs = self.model(
814
+ input_ids=input_ids,
815
+ attention_mask=attention_mask,
816
+ position_ids=position_ids,
817
+ output_hidden_states=output_hidden_states,
818
+ return_dict=return_dict,
819
+ )
820
+
821
+ last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]
822
+
823
+ # Apply the classification head (dropout + linear projection) to every token
824
+
825
+ logits = self.score(last_hidden_state)
826
+
827
+ # Process logits for inference
828
+ processed_logits = mx.softmax(logits, axis=-1)
829
+
830
+ loss = None
831
+ if labels is not None:
832
+ # Compute token classification loss
833
+ loss = nn.losses.cross_entropy(
834
+ logits.reshape(-1, self.num_labels),
835
+ labels.reshape(-1),
+ reduction="mean"
836
+ )
837
+
838
+ if not return_dict:
839
+ return [loss, processed_logits, outputs[1:]]
840
+
841
+ return {
842
+ "loss": loss,
843
+ "probabilities": processed_logits,
844
+ "hidden_states": outputs.get("hidden_states", None),
845
+ }
846
+
847
+ def sanitize(self, weights):
848
+
849
+ sanitized = _sanitize_backbone(weights)
850
+
851
+ sanitized_weights = {}
852
+ for k, v in sanitized.items():
853
+ if not k.startswith("model.") and not k.startswith("score."):
854
+ continue
855
+ sanitized_weights[k] = v
856
+
857
+ return sanitized_weights
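+ # Hedged usage note: the per-token "probabilities" above have shape (B, L, num_labels);
+ # positions where attention_mask == 0 should be ignored when decoding tag sequences.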