mlx-raclate 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,913 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional, Any, Literal
3
+ import re
4
+ from functools import partial
5
+
6
+ import mlx.core as mx
7
+ import mlx.nn as nn
8
+
9
+ from .base import (
10
+ BaseModelArgs,
11
+ mean_pooling,
12
+ last_token_pooling,
13
+ normalize_embeddings,
14
+ compute_similarity_and_loss,
15
+ RaclateBaseModel,
16
+ )
17
+
18
+ @dataclass
19
+ class ModelArgs(BaseModelArgs):
20
+ architectures: List[str] = field(default_factory=lambda: ["Gemma3TextModel"])
21
+ attention_bias: Optional[bool] = False
22
+ attention_dropout: Optional[float] = 0.0
23
+ attn_logit_softcapping: Optional[float] = None # Not supported with sdpa
24
+ bos_token_id: Optional[int] = None
25
+ eos_token_id: Optional[int] = None
26
+ final_logit_softcapping: Optional[float] = None # Not supported with sdpa
27
+ hidden_activation: Optional[str] = "gelu_pytorch_tanh"
28
+ hidden_size: int = 1152
29
+ intermediate_size: int = 6912
30
+ initializer_range: Optional[float] = (
31
+ 0.02 # Only needed in case of initializing weights
32
+ )
33
+ head_dim: int = 256
34
+ layer_types: List[str] = field(default_factory=list)
35
+ max_position_embeddings: int = 2048
36
+ model_type: str = "gemma3_text"
37
+ num_attention_heads: int = 4
38
+ num_hidden_layers: int = 26
39
+ num_key_value_heads: int = 1
40
+ query_pre_attn_scalar: float = 256
41
+ rms_norm_eps: float = 1.0e-6
42
+ rope_global_base_freq: Optional[float] = None
43
+ rope_local_base_freq: float = 10000.0
44
+ rope_theta: Optional[float] = None
45
+ rope_traditional: bool = False
46
+ sliding_window: int = 512
47
+ _sliding_window_pattern: Optional[int] = None
48
+ sliding_window_pattern: Optional[int] = None
49
+ use_bidirectional_attn: bool = False
50
+ use_bidirectional_attention: bool = False
51
+ vocab_size: int = 262144
52
+
53
+ ### Defaults
54
+ default_sliding_pattern: int = 6
55
+ default_global_rope_freq: float = 1000000.0
56
+
57
+ ### pipeline args
58
+ decoder_bias: bool = True
59
+ classifier_dropout: float = 0.0
60
+ classifier_bias: bool = False
61
+ sparse_prediction: bool = True ### True seems a more appropriate value for MLM
62
+ sparse_pred_ignore_index: int = -100
63
+ is_regression: Optional[bool] = None
64
+ label2id: Optional[Dict[str, int]] = None
65
+ id2label: Optional[Dict[int, str]] = None
66
+ pipeline_config: Optional[Dict[str, Any]] = None # for Sequence Classification
67
+ use_late_interaction: bool = False
68
+
69
+ @property
70
+ def sliding_pattern(self) -> int:
71
+ if self.sliding_window_pattern is not None:
72
+ return self.sliding_window_pattern
73
+ if self._sliding_window_pattern is not None:
74
+ return self._sliding_window_pattern
75
+ return self.default_sliding_pattern
76
+
77
+ @property
78
+ def rope_global_freq(self) -> float:
79
+ if self.rope_global_base_freq is not None:
80
+ return self.rope_global_base_freq
81
+ if self.rope_theta is not None:
82
+ return self.rope_theta
83
+ return self.default_global_rope_freq
84
+
85
+ @property
86
+ def is_causal(self) -> bool:
87
+ return not self.use_bidirectional_attn and not self.use_bidirectional_attention
88
+
89
+ @property
90
+ def num_labels(self) -> int:
91
+ """
92
+ Number of labels is determined by:
93
+ - For zero-shot classification: length of label_candidates
94
+ - For regression or binary with sigmoid: 1
95
+ - For classification: length of id2label mapping
96
+ """
97
+
98
+ if self.is_regression:
99
+ return 1
100
+
101
+ if self.pipeline_config and self.pipeline_config.get("binary_sigmoid", False):
102
+ return 1
103
+
104
+ if self.id2label is None:
105
+ raise ValueError(
106
+ "id2label mapping must be provided for categorical classification. "
107
+ "For regression or binary classification with sigmoid output, "
108
+ "set is_regression=True or binary_sigmoid=True in pipeline_config."
109
+ )
110
+
111
+ return len(self.id2label)
112
+
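The num_labels property above resolves the classifier output width from the pipeline settings. A minimal sketch of the three branches, assuming BaseModelArgs adds no required constructor fields (hypothetical values):

# Sketch only: hypothetical configs illustrating how num_labels is resolved.
regression_cfg = ModelArgs(is_regression=True)
print(regression_cfg.num_labels)  # 1

binary_cfg = ModelArgs(pipeline_config={"binary_sigmoid": True})
print(binary_cfg.num_labels)      # 1

multiclass_cfg = ModelArgs(id2label={0: "neg", 1: "neu", 2: "pos"})
print(multiclass_cfg.num_labels)  # 3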
113
+ def _sanitize_backbone(weights: Dict[str, Any]) -> Dict[str, Any]:
114
+ """
115
+ Standardizes keys for the Gemma3 Backbone.
116
+ Prefixes generic keys with 'model.' and handles basic mapping.
117
+ """
118
+ sanitized = {}
119
+ for k, v in weights.items():
120
+ # Skip unrelated heads that might be in the checkpoint
121
+ if any(x in k for x in ["lm_head", "classifier"]):
122
+ # We don't automatically map these; specific models handle them if needed
123
+ continue
124
+
125
+ # Map generic 'layers' to 'model.layers' if not already present
126
+ if "Dense.linear" not in k and \
127
+ not k.startswith("model.") and \
128
+ not k.startswith("dense.") and \
129
+ not k.startswith("score.") and \
130
+ not k.startswith("head.") and \
131
+ not k.startswith("decoder."):
132
+
133
+ new_key = f"model.{k}"
134
+
135
+ sanitized[new_key] = v
136
+ else:
137
+ sanitized[k] = v
138
+
139
+ return sanitized
140
+
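A quick sketch of the key standardization performed by _sanitize_backbone, using made-up keys:

# Sketch with made-up keys: generic backbone keys are prefixed with "model.",
# lm_head/classifier keys are dropped, and already-prefixed keys pass through.
dummy = {
    "layers.0.self_attn.q_proj.weight": 1,
    "embed_tokens.weight": 2,
    "model.norm.weight": 3,
    "lm_head.weight": 4,   # dropped
    "score.weight": 5,     # kept as-is (handled by specific models)
}
print(_sanitize_backbone(dummy))
# {'model.layers.0.self_attn.q_proj.weight': 1,
#  'model.embed_tokens.weight': 2,
#  'model.norm.weight': 3,
#  'score.weight': 5}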
141
+ class Attention(nn.Module):
142
+ def __init__(self, args: ModelArgs, layer_idx: int):
143
+ super().__init__()
144
+
145
+ dim = args.hidden_size
146
+ self.n_heads = n_heads = args.num_attention_heads
147
+ self.n_kv_heads = n_kv_heads = args.num_key_value_heads
148
+ self.repeats = n_heads // n_kv_heads
149
+ self.head_dim = head_dim = args.head_dim
150
+ self.layer_idx = layer_idx
151
+
152
+ self.scale = args.query_pre_attn_scalar**-0.5
153
+
154
+ self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
155
+ self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
156
+ self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
157
+ self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
158
+
159
+ self.q_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
160
+ self.k_norm = RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
161
+
162
+ layer_type = args.layer_types[layer_idx] if args.layer_types else None
163
+ if not layer_type:
164
+ if (layer_idx + 1) % args.sliding_pattern == 0:
165
+ layer_type = "full_attention"
166
+ else:
167
+ layer_type = "sliding_window"
168
+ self.is_sliding = layer_type == "sliding_window"
169
+
170
+ base = (
171
+ args.rope_local_base_freq
172
+ if self.is_sliding
173
+ else args.rope_global_freq
174
+ )
175
+
176
+ self.rope = nn.RoPE(
177
+ head_dim,
178
+ traditional=args.rope_traditional,
179
+ base=base,
180
+ )
181
+
182
+ # Add softcapping support
183
+ self.attn_logit_softcapping = args.attn_logit_softcapping
184
+
185
+
186
+ def __call__(
187
+ self,
188
+ x: mx.array,
189
+ mask: Optional[mx.array] = None
190
+ ) -> mx.array:
191
+ B, L, _ = x.shape
192
+ queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
193
+ queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
194
+
195
+ keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
196
+ values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
197
+
198
+ queries = self.q_norm(queries)
199
+ keys = self.k_norm(keys)
200
+
201
+ queries = self.rope(queries)
202
+ keys = self.rope(keys)
203
+
204
+ if self.attn_logit_softcapping is None:
205
+ output = mx.fast.scaled_dot_product_attention(
206
+ queries, keys, values, scale=self.scale, mask=mask
207
+ )
208
+ else:
209
+ raise NotImplementedError("Softcapping attention not supported with sdpa.")
210
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
211
+ return self.o_proj(output)
212
+
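Under the default sliding pattern of 6, the layer-type fallback above marks every sixth layer (1-indexed) as global attention and gives the rest the sliding window with the local RoPE base. A small sketch of that mapping:

# Sketch: which layers become global under the default pattern of 6.
pattern = 6
for layer_idx in range(12):
    layer_type = "full_attention" if (layer_idx + 1) % pattern == 0 else "sliding_window"
    print(layer_idx, layer_type)
# layers 5 and 11 come out as "full_attention", all others as "sliding_window"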
213
+ class RMSNorm(nn.Module):
214
+ def __init__(self, dims: int, eps: float = 1e-5):
215
+ super().__init__()
216
+ self.weight = mx.ones((dims,))
217
+ self.eps = eps
218
+
219
+ def __call__(self, x):
220
+ return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
221
+
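Gemma stores RMSNorm weights as an offset from 1, which is why 1.0 + self.weight is passed to mx.fast.rms_norm. A numerical sketch of the equivalence (an offset weight of zero gives unit scale):

import mlx.core as mx

x = mx.array([[1.0, 2.0, 3.0]])
w = mx.zeros((3,))  # offset weights
manual = x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + 1e-6) * (1.0 + w)
fast = mx.fast.rms_norm(x, 1.0 + w, 1e-6)
print(mx.allclose(manual, fast))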
222
+ class MLP(nn.Module):
223
+ def __init__(self, dim, hidden_dim):
224
+ super().__init__()
225
+ self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
226
+ self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
227
+ self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
228
+
229
+ def __call__(self, x) -> mx.array:
230
+ return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
231
+
232
+ @partial(mx.compile, shapeless=True)
233
+ def clip_residual(x, y):
234
+ if x.dtype != mx.float16:
235
+ return x + y
236
+ bound = mx.finfo(mx.float16).max
237
+ return mx.clip(x.astype(mx.float32) + y.astype(mx.float32), -bound, bound).astype(
238
+ mx.float16
239
+ )
240
+
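clip_residual performs the residual addition in float32 and clips back into the float16 range to avoid overflow. A quick sketch using the compiled function above:

import mlx.core as mx

big = mx.array([60000.0], dtype=mx.float16)
naive = big + big                # overflows to inf in float16
safe = clip_residual(big, big)   # clipped to the float16 max (65504)
print(naive.item(), safe.item())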
241
+ class TransformerBlock(nn.Module):
242
+ def __init__(self, args: ModelArgs, layer_idx: int):
243
+ super().__init__()
244
+ self.num_attention_heads = args.num_attention_heads
245
+ self.hidden_size = args.hidden_size
246
+ self.self_attn = Attention(args, layer_idx)
247
+ self.mlp = MLP(args.hidden_size, args.intermediate_size)
248
+ self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
249
+ self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
250
+ self.pre_feedforward_layernorm = RMSNorm(
251
+ args.hidden_size, eps=args.rms_norm_eps
252
+ )
253
+ self.post_feedforward_layernorm = RMSNorm(
254
+ args.hidden_size, eps=args.rms_norm_eps
255
+ )
256
+
257
+ def __call__(
258
+ self,
259
+ x: mx.array,
260
+ mask: Optional[mx.array] = None
261
+ ) -> mx.array:
262
+ r = self.self_attn(self.input_layernorm(x), mask)
263
+ h = clip_residual(x, self.post_attention_layernorm(r))
264
+ r = self.mlp(self.pre_feedforward_layernorm(h))
265
+ out = clip_residual(h, self.post_feedforward_layernorm(r))
266
+ return (out,)
267
+
268
+
269
+ class Gemma3Model(nn.Module):
270
+ def __init__(self, config: ModelArgs):
271
+ super().__init__()
272
+ self.config = config
273
+ self.vocab_size = config.vocab_size
274
+ self.num_hidden_layers = config.num_hidden_layers
275
+ assert self.vocab_size > 0
276
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
277
+ self.layers = [
278
+ TransformerBlock(config, layer_idx=i)
279
+ for i in range(config.num_hidden_layers)
280
+ ]
281
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
282
+
283
+ def get_input_embeddings(self) -> nn.Embedding:
284
+ return self.embed_tokens
285
+
286
+ def set_input_embeddings(self, value):
287
+ self.embed_tokens = value
288
+
289
+ def _update_attention_mask(self, attention_mask: Optional[mx.array] = None, dtype=None):
290
+ """
291
+ Builds additive global and sliding-window attention masks (causal or bidirectional) and combines them with the padding mask.
292
+ """
293
+
294
+ B, L = attention_mask.shape
295
+ window_size = self.config.sliding_window
296
+ indices = mx.arange(L)
297
+ row = indices[:, None]
298
+ col = row.T
299
+
300
+ if not self.config.is_causal:
301
+ mask_base = mx.zeros((L, L), dtype=mx.bool_) # All False (visible)
302
+
303
+ # Sliding Window Logic for Bidirectional:
304
+ # Valid if abs(row - col) < window
305
+ # Mask if distance >= window
306
+ dist = mx.abs(row - col)
307
+ mask_window_violator = dist >= window_size
308
+
309
+ else:
310
+ # Causal: Standard triangular mask
311
+ mask_future = col > row
312
+ mask_base = mask_future
313
+
314
+ # Sliding Window Logic for Causal:
315
+ # Valid if row - col < window (and not future)
316
+ # Mask if (row - col) >= window
317
+ mask_past = (row - col) >= window_size
318
+ mask_window_violator = mask_past
319
+
320
+ global_mask = mx.where(mask_base, -1e9, 0.0).astype(dtype)
321
+ sliding_mask_bool = mask_base | mask_window_violator
322
+ sliding_mask = mx.where(sliding_mask_bool, -1e9, 0.0).astype(dtype)
323
+
324
+ # Padding Mask
325
+ if attention_mask is not None:
326
+ # Reshape padding mask from (B, L) to (B, 1, 1, L) to be broadcastable
327
+ padding_mask = attention_mask[:, None, None, :]
328
+ additive_padding = mx.where(padding_mask == 0, -1e9, 0.0).astype(dtype)
329
+
330
+ global_mask = global_mask + additive_padding
331
+ sliding_mask = sliding_mask + additive_padding
332
+
333
+ return global_mask, sliding_mask
334
+
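The mask arithmetic above can be checked on a tiny example. A standalone sketch of the same formulas for the bidirectional case, with window 2 and one padded position:

import mlx.core as mx

L, window = 4, 2
idx = mx.arange(L)
row, col = idx[:, None], idx[None, :]

# Bidirectional base: everything visible; the sliding mask hides |row - col| >= window.
sliding_violator = mx.abs(row - col) >= window
sliding = mx.where(sliding_violator, -1e9, 0.0)

# Padding: last token masked out, broadcast over the key axis.
attention_mask = mx.array([[1, 1, 1, 0]])
padding = mx.where(attention_mask[:, None, None, :] == 0, -1e9, 0.0)
print((sliding + padding)[0, 0])  # rows show which keys each query may attend to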
335
+ def __call__(
336
+ self,
337
+ input_ids: mx.array,
338
+ attention_mask: Optional[mx.array] = None,
339
+ output_hidden_states: Optional[bool] = False,
340
+ position_ids: Optional[mx.array] = None,
341
+ return_dict: Optional[bool] = True
342
+ ):
343
+ hidden_states = self.embed_tokens(input_ids)
344
+ model_dtype = hidden_states.dtype
345
+
346
+ # normalizer
347
+ hidden_states *= mx.array(self.config.hidden_size**0.5, model_dtype)
348
+
349
+ global_mask, sliding_window_mask = self._update_attention_mask(
350
+ attention_mask,
351
+ dtype=model_dtype
352
+ )
353
+
354
+ for i, layer in enumerate(self.layers):
355
+ if self.config.layer_types:
356
+ is_global = self.config.layer_types[i] == "full_attention"
357
+ else:
358
+ # Fallback to pattern
359
+ is_global = (i + 1) % self.config.sliding_pattern == 0
360
+ layer_mask = global_mask if is_global else sliding_window_mask
361
+ layer_outputs = layer(hidden_states, layer_mask)
362
+ hidden_states = layer_outputs[0]
363
+
364
+ hidden_states = self.norm(hidden_states)
365
+
366
+ return {
367
+ "last_hidden_state": hidden_states,
368
+ }
369
+
370
+
371
+ class Model(RaclateBaseModel):
372
+ def __init__(self, config: ModelArgs):
373
+ super().__init__()
374
+ self.config = config
375
+ self.model_type = config.model_type
376
+ self.model = Gemma3Model(config)
377
+ self.dense = [
378
+ nn.Linear(config.hidden_size, config.hidden_size * 4, bias=False),
379
+ nn.Linear(config.hidden_size * 4, config.hidden_size, bias=False),
380
+ ]
381
+
382
+ def __call__(
383
+ self,
384
+ input_ids: mx.array,
385
+ position_ids: Optional[mx.array] = None,
386
+ attention_mask: Optional[mx.array] = None,
387
+ output_hidden_states: Optional[bool] = False,
388
+ return_dict: Optional[bool] = True,
389
+ ):
390
+
391
+ if attention_mask is None:
392
+ batch_size, seq_len = input_ids.shape
393
+ attention_mask = mx.ones(
394
+ (batch_size, seq_len),
395
+ dtype=self.model.embed_tokens.weight.dtype,
396
+ )
397
+
398
+ out = self.model(input_ids, attention_mask)
399
+ last_hidden_state = (
400
+ out["last_hidden_state"] if isinstance(out, dict) else out[0]
401
+ )
402
+
403
+ # normalized features
404
+ if not self.config.is_causal:
405
+ text_embeds = mean_pooling(last_hidden_state, attention_mask)
406
+ else:
407
+ text_embeds = last_token_pooling(last_hidden_state, attention_mask)
408
+
409
+ for layer in self.dense:
410
+ text_embeds = layer(text_embeds)
411
+
412
+ text_embeds = normalize_embeddings(text_embeds)
413
+
414
+ if not return_dict:
415
+ return (text_embeds, last_hidden_state)
416
+
417
+ return {
418
+ "embeddings": text_embeds, # normalized embeddings
419
+ "last_hidden_state": last_hidden_state,
420
+ }
421
+
422
+ def sanitize(self, weights):
423
+
424
+ sanitized_weights = _sanitize_backbone(weights)
425
+
426
+ # Keep only backbone weights; drop SentenceTransformer projection / head keys
427
+ final_weights = {}
428
+ for k, v in sanitized_weights.items():
429
+
430
+ if not k.startswith("model."):
431
+ continue
432
+ final_weights[k] = v
433
+
434
+ return final_weights
435
+
436
+
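A hypothetical end-to-end sketch of the embedding path with random weights and made-up token ids, assuming this module's classes are in scope and a tiny config (the package's real loading entry point is not shown here):

import mlx.core as mx

config = ModelArgs(num_hidden_layers=2, hidden_size=64, intermediate_size=128,
                   head_dim=16, num_attention_heads=4, num_key_value_heads=1,
                   vocab_size=1000, use_bidirectional_attention=True)
model = Model(config)
input_ids = mx.array([[5, 17, 42, 7]])
out = model(input_ids)
print(out["embeddings"].shape)         # (1, 64), L2-normalized
print(out["last_hidden_state"].shape)  # (1, 4, 64)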
437
+ class ModelForSentenceSimilarity(RaclateBaseModel):
438
+ """
439
+ Computes similarity scores between input sequences and reference sentences.
440
+ """
441
+ def __init__(self, config: ModelArgs):
442
+ super().__init__()
443
+ self.config = config
444
+ self.model_type = config.model_type
445
+ self.model = Gemma3Model(config)
446
+ self.dense = [
447
+ nn.Linear(config.hidden_size, config.hidden_size * 4, bias=False),
448
+ nn.Linear(config.hidden_size * 4, config.hidden_size, bias=False),
449
+ ]
450
+
451
+ def _call_model(self, input_ids, attention_mask=None, return_dict=True):
452
+ out = self.model(input_ids, attention_mask)
453
+ last_hidden_state = (
454
+ out["last_hidden_state"] if isinstance(out, dict) else out[0]
455
+ )
456
+
457
+ # text_embeds = normalize_embeddings(last_hidden_state)
458
+ if self.config.use_late_interaction:
459
+ for dense in self.dense:
460
+ last_hidden_state = dense(last_hidden_state)
461
+
462
+ text_embeds = normalize_embeddings(last_hidden_state)
463
+ # Keep unpooled for ColBERT style
464
+ # Mask padding tokens to avoid them affecting MaxSim
465
+ if attention_mask is not None:
466
+ text_embeds = text_embeds * attention_mask[..., None]
467
+ else:
468
+ # Standard dense retrieval: Mean Pooling
469
+ if not self.config.is_causal:
470
+ text_embeds = mean_pooling(last_hidden_state, attention_mask)
471
+ else:
472
+ text_embeds = last_token_pooling(last_hidden_state, attention_mask)
473
+
474
+ for layer in self.dense:
475
+ text_embeds = layer(text_embeds)
476
+
477
+ text_embeds = normalize_embeddings(text_embeds)
478
+
479
+
480
+ if not return_dict:
481
+ return (text_embeds, last_hidden_state)
482
+
483
+ return {
484
+ "embeddings": text_embeds, # normalized embeddings
485
+ "last_hidden_state": last_hidden_state,
486
+ }
487
+
488
+ def __call__(
489
+ self,
490
+ input_ids,
491
+ reference_input_ids : Optional[mx.array] = None, # Shape: [num_references, seq_len]
492
+ negative_input_ids : Optional[mx.array] = None, # Shape: [num_negatives, seq_len]
493
+ attention_mask: Optional[mx.array] = None,
494
+ reference_attention_mask: Optional[mx.array] = None,
495
+ negative_attention_mask: Optional[mx.array] = None,
496
+ similarity_scores: Optional[mx.array] = None, # Shape: [batch_size, num_references]
497
+ position_ids: Optional[mx.array] = None,
498
+ return_dict: Optional[bool] = True,
499
+ ):
500
+ if attention_mask is None:
501
+ batch_size, seq_len = input_ids.shape
502
+ attention_mask = mx.ones(
503
+ (batch_size, seq_len),
504
+ dtype=self.model.embed_tokens.weight.dtype,
505
+ )
506
+
507
+ # Get embeddings for input batch
508
+ batch_outputs = self._call_model(
509
+ input_ids=input_ids,
510
+ attention_mask=attention_mask,
511
+ return_dict=True
512
+ )
513
+ embeddings = batch_outputs["embeddings"] # [batch_size, hidden_size]
514
+
515
+ loss = None
516
+ similarities = None
517
+
518
+ if reference_input_ids is not None:
519
+
520
+ # Get embeddings for reference sentences
521
+ ref_outputs = self._call_model(
522
+ input_ids=reference_input_ids,
523
+ attention_mask=reference_attention_mask,
524
+ return_dict=True
525
+ )
526
+ reference_embeddings = ref_outputs["embeddings"] # [num_references, hidden_size]
527
+
528
+ # Compute similarities and loss
529
+ similarities, loss = compute_similarity_and_loss(
530
+ self.config,
531
+ input_ids=input_ids,
532
+ embeddings=embeddings,
533
+ reference_embeddings=reference_embeddings,
534
+ call_model=self._call_model,
535
+ similarity_scores=similarity_scores,
536
+ negative_input_ids=negative_input_ids,
537
+ negative_attention_mask=negative_attention_mask
538
+ )
539
+
540
+ if not return_dict:
541
+ return (loss, similarities, embeddings)
542
+
543
+ return {
544
+ "loss": loss,
545
+ "similarities": similarities, # [batch_size, num_references]
546
+ "embeddings": embeddings, # [batch_size, hidden_size]
547
+ }
548
+
549
+ def sanitize(self, weights):
550
+
551
+ sanitized_weights = _sanitize_backbone(weights)
552
+
553
+ # Handle SentenceTransformer specific keys
554
+ final_weights = {}
555
+ for k, v in sanitized_weights.items():
556
+
557
+ if k.startswith("dense.") or k.startswith("model."):
558
+ final_weights[k] = v
559
+ elif re.search(r"\d+_Dense", k):
560
+ key_id = "0" if v.shape[0] > v.shape[1] else "1"
561
+ new_key = re.sub(r"\d+_Dense\.linear", f"dense.{key_id}", k)
562
+ final_weights[new_key] = v
563
+ else:
564
+ continue
565
+
566
+ return final_weights
567
+
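The regex branch above remaps SentenceTransformer-style N_Dense.linear keys onto this module's dense list, deciding up- versus down-projection by weight shape. A sketch with illustrative shapes matching the default hidden size:

import re

# (out, in) of (4H, H) maps to dense.0 (up-projection); (H, 4H) maps to dense.1.
for k, shape in [("2_Dense.linear.weight", (4608, 1152)),
                 ("3_Dense.linear.weight", (1152, 4608))]:
    key_id = "0" if shape[0] > shape[1] else "1"
    print(re.sub(r"\d+_Dense\.linear", f"dense.{key_id}", k))
# dense.0.weight
# dense.1.weight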
568
+ class ModelForSentenceTransformers(ModelForSentenceSimilarity):
569
+ """
570
+ Extends ModelForSentenceSimilarity to provide embeddings for input sequences.
571
+ This class sanitizes typical sentence transformers weights to align with the Gemma3 model.
572
+ """
573
+ def __init__(self, config: ModelArgs):
574
+ super().__init__(config)
575
+
576
+
577
+
578
+ class Gemma3PredictionHead(nn.Module):
579
+ def __init__(self, config: ModelArgs):
580
+ super().__init__()
581
+ self.config = config
582
+ self.dense = nn.Linear(
583
+ config.hidden_size, config.hidden_size, config.classifier_bias
584
+ )
585
+ self.act = nn.GELU(approx="precise")
586
+ self.norm = nn.RMSNorm(
587
+ config.hidden_size, eps=config.rms_norm_eps
588
+ )
589
+
590
+ self.soft_cap = config.final_logit_softcapping
591
+
592
+ def __call__(self, hidden_states: mx.array) -> mx.array:
593
+ logits = self.norm(self.act(self.dense(hidden_states)))
594
+ if self.soft_cap is not None:
595
+ logits = mx.tanh(logits / self.soft_cap) * self.soft_cap
596
+ return logits
597
+
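Logit soft-capping squashes values smoothly into [-soft_cap, soft_cap] via tanh, matching the formula above. A numerical sketch:

import mlx.core as mx

soft_cap = 30.0
logits = mx.array([-100.0, -10.0, 0.0, 10.0, 100.0])
capped = mx.tanh(logits / soft_cap) * soft_cap
print(capped)  # values squashed smoothly into (-30, 30)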
598
+ class ModelForSequenceClassification(RaclateBaseModel):
599
+ """
600
+ Computes sequence classification probabilities for input sequences.
601
+ Sanitization aligns backbone checkpoint weights with HF's Gemma3ForSequenceClassification architecture.
602
+
603
+ NOTE : regression and binary classification not tested.
604
+ """
605
+ def __init__(self, config: ModelArgs):
606
+ super().__init__()
607
+ self.config = config
608
+ self.num_labels = config.num_labels
609
+ self.is_regression = config.is_regression
610
+
611
+ self.model = Gemma3Model(config)
612
+
613
+ ### The HF architecture Gemma3ForSequenceClassification
614
+ ### does not have head and drop
615
+ #### and uses 'score' as the final layer name
616
+ # self.head = Gemma3PredictionHead(config)
617
+ # self.drop = nn.Dropout(p=config.classifier_dropout)
618
+
619
+ self.score = nn.Linear(
620
+ config.hidden_size,
621
+ config.num_labels,
622
+ bias=False
623
+ )
624
+
625
+ self.hf_transformers_arch = "Gemma3ForSequenceClassification"
626
+
627
+ def _process_outputs(self, logits: mx.array) -> mx.array:
628
+ """Apply the appropriate activation function to the logits."""
629
+ if self.is_regression:
630
+ return logits # No activation for regression
631
+ elif self.num_labels == 1:
632
+ return mx.sigmoid(logits) # Binary classification
633
+ else:
634
+ # Using softmax for multi-class classification
635
+ return mx.softmax(logits, axis=-1)
636
+
637
+ def _compute_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
638
+ """Compute the appropriate loss based on label characteristics."""
639
+ if self.is_regression:
640
+ return nn.losses.mse_loss(logits.squeeze(), labels.squeeze())
641
+ elif self.num_labels == 1:
642
+ return nn.losses.binary_cross_entropy(logits, labels)  # expects raw logits (with_logits=True by default)
643
+ else:
644
+ return nn.losses.cross_entropy(
645
+ logits.reshape(-1, self.num_labels),
646
+ labels.reshape(-1)
647
+ )
648
+
649
+ def __call__(
650
+ self,
651
+ input_ids,
652
+ attention_mask: Optional[mx.array] = None,
653
+ position_ids: Optional[mx.array] = None, ### need this?
654
+ labels: Optional[mx.array] = None,
655
+ output_hidden_states: Optional[bool] = False,
656
+ return_dict: Optional[bool] = True,
657
+ ) -> Dict:
658
+ if attention_mask is None:
659
+ batch_size, seq_len = input_ids.shape
660
+ attention_mask = mx.ones(
661
+ (batch_size, seq_len),
662
+ dtype=self.model.embed_tokens.weight.dtype,
663
+ )
664
+
665
+ outputs = self.model(
666
+ input_ids,
667
+ attention_mask,
668
+ position_ids=position_ids,
669
+ output_hidden_states=output_hidden_states,
670
+ return_dict=return_dict
671
+ )
672
+
673
+ last_hidden_state = (
674
+ outputs["last_hidden_state"] if isinstance(outputs, dict) else outputs[0]
675
+ )
676
+
677
+ # normalized features
678
+ if not self.config.is_causal:
679
+ text_embeds = mean_pooling(last_hidden_state, attention_mask)
680
+ else:
681
+ text_embeds = last_token_pooling(last_hidden_state, attention_mask)
682
+
683
+ ### The HF architecture Gemma3ForSequenceClassification
684
+ logits = self.score(text_embeds)
685
+
686
+ processed_logits = self._process_outputs(logits)
687
+
688
+ loss = None
689
+ if labels is not None :
690
+ loss = self._compute_loss(logits, labels)
691
+
692
+ if not return_dict:
693
+ return [loss, processed_logits, outputs[1:]]
694
+
695
+ return {
696
+ "loss": loss,
697
+ "probabilities": processed_logits,
698
+ "hidden_states": outputs.get("hidden_states", None),
699
+ }
700
+
701
+ def sanitize(self, weights):
702
+
703
+ sanitized_weights = _sanitize_backbone(weights)
704
+
705
+ # Filter out embedding-checkpoint keys (e.g. EmbeddingGemma dense projections) that this head does not use
706
+ final_weights = {}
707
+ for k, v in sanitized_weights.items():
708
+ if not k.startswith("model.") and not k.startswith("score."):
709
+ continue
710
+ final_weights[k] = v
711
+
712
+ return final_weights
713
+
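A hypothetical classification sketch with random weights and made-up labels, assuming this module's classes are in scope and a tiny config (an illustration of the call signature, not a recommended configuration):

import mlx.core as mx

config = ModelArgs(num_hidden_layers=2, hidden_size=64, intermediate_size=128,
                   head_dim=16, vocab_size=1000,
                   id2label={0: "neg", 1: "neu", 2: "pos"},
                   use_bidirectional_attention=True)
clf = ModelForSequenceClassification(config)
out = clf(mx.array([[3, 14, 15, 9]]), labels=mx.array([2]))
print(out["probabilities"].shape)  # (1, 3)
print(out["loss"])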
714
+ class ModelForMaskedLM(RaclateBaseModel):
715
+ """
716
+ Computes masked language modeling (MLM) loss for input sequences.
717
+ """
718
+ def __init__(self, config : ModelArgs):
719
+ super().__init__()
720
+ self.config = config
721
+ if config.is_causal:
722
+ raise ValueError("ModelForMaskedLM requires bidirectional attention.")
723
+ self.model = Gemma3Model(config)
724
+ self.head = Gemma3PredictionHead(config)
725
+ self.decoder = nn.Linear(
726
+ config.hidden_size, config.vocab_size, bias=config.decoder_bias
727
+ )
728
+
729
+ # transformers has no MaskedLM class for Gemma3
730
+
731
+ # We explicitly call tie_weights to ensure logic is set up,
732
+ # though standard loading overwrites this unless sanitized correctly.
733
+ self.tie_weights()
734
+
735
+ def tie_weights(self):
736
+ self.decoder.weight = self.model.embed_tokens.weight
737
+
738
+ def get_input_embeddings(self):
739
+ return self.model.get_input_embeddings()
740
+
741
+ def get_output_embeddings(self):
742
+ return self.decoder
743
+
744
+ def set_input_embeddings(self, value):
745
+ self.model.set_input_embeddings(value)
746
+ self.tie_weights() # Re-tie weights after setting new embeddings
747
+
748
+ def set_output_embeddings(self, new_embeddings):
749
+ self.decoder = new_embeddings
750
+ self.tie_weights() # Re-tie weights after setting new decoder
751
+
752
+ def __call__(
753
+ self,
754
+ input_ids,
755
+ attention_mask: Optional[mx.array] = None,
756
+ labels: Optional[mx.array] = None,
757
+ position_ids: Optional[mx.array] = None,
758
+ output_hidden_states: Optional[bool] = None,
759
+ return_dict: Optional[bool] = True,
760
+ ) -> Dict:
761
+
762
+ if attention_mask is None:
763
+ batch_size, seq_len = input_ids.shape
764
+ attention_mask = mx.ones((batch_size, seq_len)) ### updated via _update_attention_mask() in the model
765
+
766
+ outputs = self.model(
767
+ input_ids=input_ids,
768
+ attention_mask=attention_mask,
769
+ position_ids=position_ids,
770
+ output_hidden_states=output_hidden_states,
771
+ return_dict=return_dict,
772
+ )
773
+
774
+ last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]
775
+ logits = self.head(last_hidden_state)
776
+ logits = self.decoder(logits)
777
+
778
+ loss = None
779
+ if self.training and labels is not None :
780
+ if getattr(self.config, "sparse_prediction", False):
781
+ # Flatten labels and predictions
782
+ flat_labels = labels.reshape(-1)
783
+ flat_predictions = logits.reshape(-1, logits.shape[-1])
784
+
785
+ # Filter out non-masked tokens
786
+ ignore_index = getattr(self.config, "sparse_pred_ignore_index", -100)
787
+ mask_tokens = flat_labels != ignore_index
788
+
789
+ # Only compute loss on masked tokens
790
+ masked_predictions = flat_predictions[mask_tokens]
791
+ masked_labels = flat_labels[mask_tokens]
792
+
793
+ loss = nn.losses.cross_entropy(
794
+ masked_predictions,
795
+ masked_labels,
796
+ reduction='mean'
797
+ )
798
+ else:
799
+ # Standard loss computation on all tokens
800
+ loss = nn.losses.cross_entropy(
801
+ logits.reshape(-1, logits.shape[-1]),
802
+ labels.reshape(-1),
803
+ reduction='mean'
804
+ )
805
+
806
+ if not return_dict:
807
+ return [loss, logits, outputs[1:]]
808
+
809
+ return {
810
+ "loss": loss,
811
+ "logits": logits,
812
+ "hidden_states": outputs.get("hidden_states", None),
813
+ }
814
+
815
+ def sanitize(self, weights):
816
+
817
+ sanitized_weights = _sanitize_backbone(weights)
818
+
819
+ # Specific adjustments for MLM
820
+ final_weights = {}
821
+ for k, v in sanitized_weights.items():
822
+ # Filter unwanted Dense layers from embedding checkpoints
823
+ # and/or score layers from classification checkpoints
824
+
825
+ if not k.startswith("model.") and \
826
+ not k.startswith("head.") and \
827
+ not k.startswith("decoder."):
828
+ continue
829
+
830
+ # Handle Weight Tying for loading:
831
+ if k == "model.embed_tokens.weight" and "decoder.weight" not in weights:
832
+ final_weights["decoder.weight"] = v
833
+
834
+ final_weights[k] = v
835
+
836
+ return final_weights
837
+
838
+
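The sparse-prediction branch above only scores positions whose label differs from sparse_pred_ignore_index. A sketch of that filtering step with made-up logits and labels, mirroring the boolean-mask indexing used in the loss:

import mlx.core as mx
import mlx.nn as nn

vocab = 10
logits = mx.random.normal((1, 4, vocab))
labels = mx.array([[-100, 3, -100, 7]])  # -100 marks unmasked positions

flat_labels = labels.reshape(-1)
flat_logits = logits.reshape(-1, vocab)
keep = flat_labels != -100               # boolean mask, as in the loss above
loss = nn.losses.cross_entropy(flat_logits[keep], flat_labels[keep], reduction="mean")
print(loss)                              # scalar averaged over the 2 masked tokens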
839
+ class ModelForTokenClassification(RaclateBaseModel):
840
+ """
841
+ Computes token classification probabilities for input sequences.
842
+
843
+ NOTE: untested for now
844
+ """
845
+ def __init__(self, config: ModelArgs):
846
+ super().__init__()
847
+ self.config = config
848
+ if config.is_causal:
849
+ raise ValueError("ModelForTokenClassification requires bidirectional attention.")
850
+ self.num_labels = config.num_labels
851
+
852
+ self.model = Gemma3Model(config)
853
+ self.drop = nn.Dropout(p=config.classifier_dropout)
854
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
855
+
856
+ # transformers does not have TokenClassification class for Gemma3
857
+
858
+ def __call__(
859
+ self,
860
+ input_ids,
861
+ attention_mask: Optional[mx.array] = None,
862
+ position_ids: Optional[mx.array] = None,
863
+ labels: Optional[mx.array] = None,
864
+ output_hidden_states: Optional[bool] = None,
865
+ return_dict: Optional[bool] = True,
866
+ ) -> Dict:
867
+ if attention_mask is None:
868
+ batch_size, seq_len = input_ids.shape
869
+ attention_mask = mx.ones((batch_size, seq_len))
870
+
871
+ outputs = self.model(
872
+ input_ids=input_ids,
873
+ attention_mask=attention_mask,
874
+ position_ids=position_ids,
875
+ output_hidden_states=output_hidden_states,
876
+ return_dict=return_dict,
877
+ )
878
+
879
+ last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]
880
+
881
+ # Apply dropout and the classification layer to each token position
882
+ logits = self.score(self.drop(last_hidden_state))
883
+
884
+ # Process logits for inference
885
+ processed_logits = mx.softmax(logits, axis=-1)
886
+
887
+ loss = None
888
+ if labels is not None:
889
+ # Compute token classification loss
890
+ loss = nn.losses.cross_entropy(
891
+ logits.reshape(-1, self.num_labels),
892
+ labels.reshape(-1)
893
+ )
894
+
895
+ if not return_dict:
896
+ return [loss, processed_logits, outputs[1:]]
897
+
898
+ return {
899
+ "loss": loss,
900
+ "probabilities": processed_logits,
901
+ "hidden_states": outputs.get("hidden_states", None),
902
+ }
903
+
904
+ def sanitize(self, weights):
905
+ sanitized_weights = _sanitize_backbone(weights)
906
+
907
+ final_weights = {}
908
+ for k, v in sanitized_weights.items():
909
+ if not k.startswith("model.") and not k.startswith("score."):
910
+ continue
911
+ final_weights[k] = v
912
+
913
+ return final_weights