mlx-raclate 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,900 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional, Dict, Literal, Any, List
3
+
4
+ import mlx.core as mx
5
+ import mlx.nn as nn
6
+
7
+ from .base import (
8
+ BaseModelArgs,
9
+ RaclateBaseModel,
10
+ compute_similarity_and_loss,
11
+ mean_pooling,
12
+ normalize_embeddings
13
+ )
14
+
15
+ """ NOTE : This implementation of ModernBERT excludes all features related to Flash Attention 2, padded/unpadded handling"""
16
+
17
+ @dataclass
18
+ class ModelArgs(BaseModelArgs):
19
+ architectures: List[str] = field(default_factory=lambda: ["ModernBertModel"])
20
+ attention_bias: bool = False
21
+ attention_dropout: float = 0.0
22
+ bos_token_id: int = 50281
23
+ cls_token_id: int = 50281
24
+ embedding_dropout: float = 0.0
25
+ eos_token_id: int = 50282
26
+ global_attn_every_n_layers: int = 3
27
+ global_rope_theta: float = 160000.0
28
+ hidden_size: int = 768
29
+ initializer_range: float = 0.02
30
+ initializer_cutoff_factor: float = 2.0 # relevant for MLX?
31
+ intermediate_size: int = 1152
32
+ local_attention: int = 128
33
+ local_rope_theta: float = 10000.0
34
+ max_position_embeddings: int = 8192
35
+ mlp_bias: bool = False
36
+ mlp_dropout: float = 0.0
37
+ model_type: str = "modernbert"
38
+ norm_bias: bool = False
39
+ norm_eps: float = 1e-05
40
+ num_attention_heads: int = 12
41
+ num_hidden_layers: int = 22
42
+ output_hidden_states: bool = False
43
+ pad_token_id: int = 50283
44
+ sep_token_id: int = 50282
45
+ vocab_size: int = 50368
46
+
47
+ ### pipeline args
48
+ decoder_bias: bool = True
49
+ classifier_pooling: Literal["cls", "mean"] = "cls"
50
+ classifier_dropout: float = 0.0
51
+ classifier_bias: bool = False
52
+ sparse_prediction: bool = True ### True seems a more appropriate value for MLM
53
+ sparse_pred_ignore_index: int = -100
54
+ is_regression: Optional[bool] = None
55
+ label2id: Optional[Dict[str, int]] = None
56
+ id2label: Optional[Dict[int, str]] = None
57
+ pipeline_config: Optional[Dict[str, Any]] = None # for Sequence Classification
58
+ use_late_interaction: bool = False
59
+
60
+ @property
61
+ def num_labels(self) -> int:
62
+ """
63
+ Number of labels is determined by:
64
+ - For zero-shot classification: length of label_candidates
65
+ - For regression or binary with sigmoid: 1
66
+ - For classification: length of id2label mapping
67
+ """
68
+
69
+ if self.is_regression:
70
+ return 1
71
+
72
+ if self.pipeline_config and self.pipeline_config.get("binary_sigmoid", False):
73
+ return 1
74
+
75
+ if self.id2label is None:
76
+ raise ValueError(
77
+ "id2label mapping must be provided for categorical classification. "
78
+ "For regression or binary classification with sigmoid output, "
79
+ "set is_regression=True or binary_sigmoid=True in pipeline_config."
80
+ )
81
+
82
+ return len(self.id2label)
83
+
84
+
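+ # Illustrative sketch (comments only, not executed): how ModelArgs.num_labels
+ # resolves for a few configurations. The label names below are made up.
+ #
+ #   ModelArgs(is_regression=True).num_labels                          # -> 1
+ #   ModelArgs(pipeline_config={"binary_sigmoid": True}).num_labels    # -> 1
+ #   ModelArgs(id2label={0: "neg", 1: "neu", 2: "pos"}).num_labels     # -> 3
+ #   ModelArgs().num_labels                                            # -> raises ValueError (no id2label)
+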
85
+ class ModernBertEmbeddings(nn.Module):
86
+ """
87
+ Token embeddings followed by LayerNorm and dropout; positional information is applied via RoPE inside the attention layers, so no position embeddings are added here.
88
+ """
89
+ def __init__(self, config: ModelArgs):
90
+ super().__init__()
91
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
92
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
93
+ self.drop = nn.Dropout(p=config.embedding_dropout)
94
+
95
+ def __call__(self, input_ids):
96
+ embeddings = self.tok_embeddings(input_ids)
97
+ embeddings = self.norm(embeddings)
98
+ embeddings = self.drop(embeddings)
99
+ return embeddings
100
+
101
+
102
+ class ModernBertMLP(nn.Module):
103
+ """Applies the GLU at the end of each ModernBERT layer.
104
+
105
+ Compared to the default BERT architecture, this block replaces `BertIntermediate`
106
+ and `BertSelfOutput` with a single module that has similar functionality.
107
+ """
108
+ def __init__(self, config: ModelArgs):
109
+ super().__init__()
110
+ self.config = config
111
+ self.Wi = nn.Linear(config.hidden_size, config.intermediate_size *2, bias=config.mlp_bias)
112
+ self.act = nn.GELU()
113
+ self.drop = nn.Dropout(p=config.mlp_dropout)
114
+ self.Wo = nn.Linear(int(config.intermediate_size), config.hidden_size, bias=config.mlp_bias)
115
+
116
+ def __call__(self, hidden_states):
117
+ x = self.Wi(hidden_states)
118
+
119
+ split_dim = x.shape[-1] // 2
120
+ inputs, gate = x[:, :, :split_dim], x[:, :, split_dim:] # GLU gating: https://arxiv.org/pdf/2002.05202v1
121
+ return self.Wo(self.drop(self.act(inputs) * gate))
122
+
123
+
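+ # Shape sketch for the GLU block above (comments only; default ModelArgs):
+ # Wi projects hidden_size=768 -> 2*intermediate_size=2304, the output is split
+ # in half along the last axis into a value and a gate, and Wo maps the gated
+ # product (1152) back to hidden_size=768.
+ #
+ #   x = mx.zeros((2, 16, 768))           # [batch, seq, hidden]
+ #   ModernBertMLP(ModelArgs())(x).shape  # -> (2, 16, 768)
+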
124
+ class ModernBertAttention(nn.Module):
125
+ """Performs multi-headed self attention on a batch of unpadded sequences.
126
+ For now, only supports the Scaled Dot-Product Attention (SDPA) implementation.
127
+ """
128
+ def __init__(self, config: ModelArgs, layer_id: Optional[int] = None):
129
+ super().__init__()
130
+ self.config = config
131
+ self.layer_id = layer_id
132
+
133
+ if config.hidden_size % config.num_attention_heads != 0:
134
+ raise ValueError(
135
+ f"hidden_size ({config.hidden_size}) must be divisible by num_attention_heads ({config.num_attention_heads})"
136
+ )
137
+
138
+ self.attention_dropout = config.attention_dropout
139
+ self.num_heads = config.num_attention_heads
140
+ self.head_dim = config.hidden_size // config.num_attention_heads
141
+ self.all_head_size = self.head_dim * self.num_heads
142
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
143
+
144
+ if layer_id % config.global_attn_every_n_layers != 0:
145
+ self.local_attention = (config.local_attention // 2, config.local_attention // 2)
146
+ else:
147
+ self.local_attention = (-1, -1)
148
+
149
+ rope_theta = config.global_rope_theta
150
+ if self.local_attention != (-1, -1) and config.local_rope_theta is not None:
151
+ rope_theta = config.local_rope_theta
152
+
153
+ self.rotary_emb = nn.RoPE(dims=self.head_dim, base=rope_theta)
154
+
155
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
156
+ self.out_drop = nn.Dropout(p=config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
157
+ self.pruned_heads = set()
158
+
159
+ def __call__(
160
+ self,
161
+ hidden_states,
162
+ attention_mask = None,
163
+ sliding_window_mask = None,
164
+ **kwargs
165
+ ):
166
+ qkv = self.Wqkv(hidden_states)
167
+ bs = hidden_states.shape[0]
168
+ qkv = mx.reshape(qkv, (bs, -1, 3, self.num_heads, self.head_dim))
169
+
170
+ # Get attention outputs using SDPA
171
+ qkv = mx.transpose(
172
+ qkv, [0, 3, 2, 1, 4]
173
+ ) # [batch_size, nheads, 3, seqlen, headdim]
174
+ query, key, value = mx.split(
175
+ qkv, indices_or_sections=3, axis=2
176
+ ) # each [batch_size, nheads, 1, seqlen, headdim]
177
+ query = query.squeeze(2) # [batch_size, nheads, seqlen, headdim]
178
+ key = key.squeeze(2) # [batch_size, nheads, seqlen, headdim]
179
+ value = value.squeeze(2) # [batch_size, nheads, seqlen, headdim]
180
+
181
+ # Applying rotary embeddings
182
+ query = self.rotary_emb(query)
183
+ key = self.rotary_emb(key)
184
+
185
+ # Handling local attention if needed
186
+ if self.local_attention != (-1, -1):
187
+ attention_mask = sliding_window_mask
188
+
189
+ # Computing attention using MLX's SDPA
190
+ scale = query.shape[-1] ** -0.5
191
+ attn_output = mx.fast.scaled_dot_product_attention(
192
+ query, key, value,
193
+ scale=scale,
194
+ mask=attention_mask
195
+ )
196
+
197
+ # Reshape attention output back to [batch_size, seqlen, all_head_size]
198
+ attn_output = mx.transpose(attn_output, [0, 2, 1, 3])
199
+ attn_output = mx.reshape(attn_output, (bs, -1, self.all_head_size))
200
+
201
+ # Applying output projection and dropout
202
+ hidden_states = self.Wo(attn_output)
203
+ hidden_states = self.out_drop(hidden_states)
204
+
205
+ return (hidden_states,)
206
+
207
+
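+ # Sketch of the attention layout above (comments only; default ModelArgs):
+ # Wqkv maps hidden_size=768 -> 3*768, reshaped to [batch, seq, 3, 12 heads, 64]
+ # and then transposed/split into query/key/value of shape [batch, 12, seq, 64].
+ # With global_attn_every_n_layers=3 and local_attention=128:
+ #
+ #   ModernBertAttention(ModelArgs(), layer_id=0).local_attention  # -> (-1, -1)  global
+ #   ModernBertAttention(ModelArgs(), layer_id=1).local_attention  # -> (64, 64)  sliding window
+ #   ModernBertAttention(ModelArgs(), layer_id=3).local_attention  # -> (-1, -1)  global
+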
208
+ class ModernBertEncoderLayer(nn.Module):
209
+ def __init__(self, config: ModelArgs, layer_id: Optional[int] = None):
210
+ super().__init__()
211
+ self.config = config
212
+ if layer_id == 0:
213
+ self.attn_norm = nn.Identity()
214
+ else:
215
+ self.attn_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
216
+ self.attn = ModernBertAttention(config=config, layer_id=layer_id)
217
+ self.mlp = ModernBertMLP(config)
218
+ self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
219
+
220
+ def __call__(
221
+ self,
222
+ hidden_states ,
223
+ attention_mask =None,
224
+ sliding_window_mask = None,
225
+ position_ids = None,
226
+ ):
227
+ normalized_hidden_states = self.attn_norm(hidden_states)
228
+ attention_output = self.attn(
229
+ normalized_hidden_states,
230
+ attention_mask=attention_mask,
231
+ sliding_window_mask=sliding_window_mask,
232
+ position_ids=position_ids,
233
+ )
234
+ hidden_states = hidden_states + attention_output[0]
235
+ mlp_output = self.mlp(self.mlp_norm(hidden_states))
236
+ hidden_states = hidden_states + mlp_output
237
+
238
+ return (hidden_states,)
239
+
240
+
241
+ class ModernBertModel(nn.Module):
242
+ def __init__(self, config: ModelArgs):
243
+ super().__init__()
244
+ self.config = config
245
+ self.embeddings = ModernBertEmbeddings(config)
246
+ self.layers = [
247
+ ModernBertEncoderLayer(config, i) for i in range(config.num_hidden_layers)
248
+ ]
249
+ self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
250
+ self.gradient_checkpointing = False
251
+
252
+ def get_input_embeddings(self) -> ModernBertEmbeddings:
253
+ return self.embeddings.tok_embeddings
254
+
255
+ def set_input_embeddings(self, value):
256
+ self.embeddings.tok_embeddings = value
257
+
258
+ def _update_attention_mask(self, attention_mask, model_dtype): #TODO: move to base.py ??
259
+
260
+ batch_size, seq_len = attention_mask.shape
261
+ neg_inf = -1e4
262
+
263
+ additive_mask = mx.where(attention_mask == 1, 0.0, neg_inf)
264
+ additive_mask = additive_mask[:, None, None, :] # (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
265
+
266
+ # Broadcast the additive padding mask for global (full) attention
267
+ global_attention_mask = mx.broadcast_to(additive_mask, (batch_size, 1, seq_len, seq_len))
268
+
269
+ # Create position indices for sliding window
270
+ rows = mx.arange(seq_len)
271
+ rows = rows[None, :] # (1, seq_len)
272
+ # Calculate position-wise distances
273
+ distance = mx.abs(rows - rows.T) # (seq_len, seq_len)
274
+
275
+ # Create sliding window mask using mx.where
276
+ window_mask = mx.where(
277
+ distance <= (self.config.local_attention // 2),
278
+ mx.ones_like(distance),
279
+ mx.zeros_like(distance)
280
+ )
281
+
282
+ # Expand dimensions using None indexing
283
+ window_mask = window_mask[None, None, :, :] # (1, 1, seq_len, seq_len)
284
+
285
+ # Broadcast to match batch size
286
+ window_mask = mx.broadcast_to(window_mask, global_attention_mask.shape)
287
+
288
+ # Create sliding window attention mask
289
+ # Replace non-window positions with large negative value
290
+ sliding_window_mask = mx.where(
291
+ window_mask,
292
+ global_attention_mask,
293
+ neg_inf # scalar broadcasts against global_attention_mask; equivalently neg_inf * mx.ones_like(global_attention_mask)
294
+ )
295
+
296
+ # Convert to model_dtype for scaled_dot_product_attention
297
+ global_attention_mask = global_attention_mask.astype(model_dtype)
298
+ sliding_window_mask = sliding_window_mask.astype(model_dtype)
299
+
300
+ return global_attention_mask, sliding_window_mask
301
+
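+ # Example of the two masks returned above (comments only; `model` denotes a
+ # ModernBertModel instance): with seq_len=4 and no padding, every entry of the
+ # global mask is 0.0, while the sliding-window mask keeps 0.0 only where
+ # |i - j| <= local_attention // 2 (= 64 by default) and is -1e4 elsewhere.
+ # A padded position gets a -1e4 column in both masks:
+ #
+ #   mask = mx.array([[1, 1, 1, 0]])                        # last token is padding
+ #   g, s = model._update_attention_mask(mask, mx.float32)  # shapes (1, 1, 4, 4)
+ #   g[0, 0, 0]  # -> [0.0, 0.0, 0.0, -1e4]
+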
302
+ def __call__(
303
+ self,
304
+ input_ids,
305
+ attention_mask = None, # (batch_size, seq_len) see below
306
+ sliding_window_mask = None,
307
+ position_ids = None,
308
+ output_hidden_states = False,
309
+ return_dict = True,
310
+ ):
311
+ output_hidden_states = (
312
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
313
+ )
314
+
315
+ batch_size, seq_len = input_ids.shape[:2]
316
+
317
+ if attention_mask is None:
318
+ attention_mask = mx.ones((batch_size, seq_len)) ### updated with _update_attention_mask() below
319
+
320
+ hidden_states = self.embeddings(input_ids)
321
+ model_dtype = hidden_states.dtype
322
+
323
+ # get attention mask and sliding window mask
324
+ attention_mask, sliding_window_mask = self._update_attention_mask(
325
+ attention_mask=attention_mask,
326
+ model_dtype=model_dtype
327
+ )
328
+
329
+ all_hidden_states = () if output_hidden_states else None
330
+
331
+ for encoder_layer in self.layers:
332
+ if output_hidden_states:
333
+ all_hidden_states = all_hidden_states + (hidden_states,)
334
+
335
+ layer_outputs = encoder_layer(
336
+ hidden_states,
337
+ attention_mask=attention_mask,
338
+ sliding_window_mask=sliding_window_mask,
339
+ position_ids=position_ids,
340
+ )
341
+
342
+ hidden_states = layer_outputs[0]
343
+
344
+ hidden_states = self.final_norm(hidden_states)
345
+
346
+ if not return_dict:
347
+ return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
348
+ return {
349
+ "last_hidden_state": hidden_states,
350
+ "hidden_states": all_hidden_states,
351
+ }
352
+
353
+
354
+ ### below are the classes for specific pipelines
355
+ class Model(RaclateBaseModel):
356
+ """
357
+ Computes embeddings for input sequences using a ModernBERT model.
358
+
359
+ Note: sanitization is a workaround to align with the other models here when loading weights
360
+ exported with the masked-LM config from HF (the original ModernBERT checkpoint).
361
+ """
362
+ def __init__(self, config: ModelArgs):
363
+ super().__init__()
364
+ self.config = config
365
+ self.model = ModernBertModel(config)
366
+
367
+ # no hf_transformers_arch is set: the plain embedding model has no transformers pipeline counterpart
368
+
369
+ def __call__(
370
+ self,
371
+ input_ids : mx.array,
372
+ attention_mask: Optional[mx.array] = None,
373
+ position_ids: Optional[mx.array] = None,
374
+ output_hidden_states: Optional[bool] = None,
375
+ return_dict: Optional[bool] = True,
376
+ ):
377
+
378
+ if attention_mask is None:
379
+ batch_size, seq_len = input_ids.shape
380
+ attention_mask = mx.ones(
381
+ (batch_size, seq_len),
382
+ dtype=self.model.embeddings.tok_embeddings.weight.dtype) ### updated via _update_attention_mask() in the model
383
+
384
+ # Get embeddings and encoder outputs as before
385
+ encoder_outputs = self.model(
386
+ input_ids,
387
+ attention_mask=attention_mask,
388
+ position_ids=position_ids,
389
+ output_hidden_states=output_hidden_states,
390
+ return_dict=return_dict,
391
+ )
392
+ last_hidden_state = encoder_outputs["last_hidden_state"] if isinstance(encoder_outputs, dict) else encoder_outputs[0]
393
+
394
+ # Pooling based on config
395
+ if self.config.classifier_pooling == "cls":
396
+ pooled = last_hidden_state[:, 0]
397
+ elif self.config.classifier_pooling == "mean":
398
+ pooled = mean_pooling(last_hidden_state, attention_mask)
399
+
400
+ text_embeds = normalize_embeddings(pooled)
401
+
402
+ if not return_dict:
403
+ return (text_embeds, last_hidden_state)
404
+
405
+ return {
406
+ "embeddings": text_embeds, # normalized embeddings
407
+ "last_hidden_state": last_hidden_state,
408
+ }
409
+
410
+ def sanitize(self, weights):
411
+ sanitized_weights = {}
412
+ for k, v in weights.items():
413
+ if "position_ids" in k:
414
+ # Remove unused position_ids
415
+ continue
416
+ if k in ["head.norm.weight", "head.dense.weight", "decoder.bias"]:
417
+ continue
418
+ else:
419
+ sanitized_weights[k] = v
420
+ return sanitized_weights
421
+
422
+
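+ # Usage sketch (comments only): the non-special token ids below are placeholders,
+ # and tokenization / weight loading is assumed to happen elsewhere in the package.
+ #
+ #   config = ModelArgs(classifier_pooling="mean")
+ #   model = Model(config)
+ #   input_ids = mx.array([[50281, 2023, 310, 247, 1071, 50282]])  # [1, seq_len]
+ #   out = model(input_ids)
+ #   out["embeddings"].shape         # -> (1, 768), L2-normalized
+ #   out["last_hidden_state"].shape  # -> (1, 6, 768)
+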
423
+ class ModelForSentenceSimilarity(RaclateBaseModel):
424
+ """
425
+ Handles:
426
+ 1. Inference: Generates embeddings and similarity scores (cosine similarity or MaxSim if late interaction is used).
427
+ 2. Training (Standard): (Sentence1, Sentence2, Score) -> MSE/Cosine Loss.
428
+ 3. Training (Triplets): (Anchor, Positive, Negative) -> MNRL with Hard Negatives (Cross-entropy Loss).
429
+ """
430
+ def __init__(self, config : ModelArgs):
431
+ super().__init__()
432
+ self.config = config
433
+ self.model = ModernBertModel(config)
434
+
435
+ def _call_model(
436
+ self,
437
+ input_ids: mx.array,
438
+ position_ids: Optional[mx.array] = None,
439
+ attention_mask: Optional[mx.array] = None,
440
+ output_hidden_states: Optional[bool] = False,
441
+ return_dict: Optional[bool] = True,
442
+ ):
443
+ out = self.model(input_ids, attention_mask)
444
+ last_hidden_state = (
445
+ out["last_hidden_state"] if isinstance(out, dict) else out[0]
446
+ )
447
+
448
+ # text_embeds = normalize_embeddings(last_hidden_state)
449
+ if self.config.use_late_interaction:
450
+ text_embeds = normalize_embeddings(last_hidden_state)
451
+ # Keep unpooled for ColBERT style
452
+ # Mask padding tokens to avoid them affecting MaxSim
453
+ if attention_mask is not None:
454
+ text_embeds = text_embeds * attention_mask[..., None]
455
+ else:
456
+ # Pooling based on config
457
+ if self.config.classifier_pooling == "cls":
458
+ pooled = last_hidden_state[:, 0]
459
+ elif self.config.classifier_pooling == "mean":
460
+ pooled = mean_pooling(last_hidden_state, attention_mask)
461
+ text_embeds = normalize_embeddings(pooled)
462
+
463
+ if not return_dict:
464
+ return (text_embeds, last_hidden_state)
465
+
466
+ return {
467
+ "embeddings": text_embeds, # normalized embeddings
468
+ "last_hidden_state": last_hidden_state,
469
+ }
470
+
471
+ def __call__(
472
+ self,
473
+ input_ids,
474
+ reference_input_ids : Optional[mx.array] = None, # Shape: [num_references, seq_len]
475
+ negative_input_ids : Optional[mx.array] = None, # Shape: [num_negatives, seq_len]
476
+ attention_mask: Optional[mx.array] = None,
477
+ reference_attention_mask: Optional[mx.array] = None,
478
+ negative_attention_mask: Optional[mx.array] = None,
479
+ similarity_scores: Optional[mx.array] = None, # Shape: [batch_size, num_references]
480
+ position_ids: Optional[mx.array] = None,
481
+ return_dict: Optional[bool] = True,
482
+ ):
483
+
484
+ if attention_mask is None:
485
+ batch_size, seq_len = input_ids.shape
486
+ attention_mask = mx.ones(
487
+ (batch_size, seq_len),
488
+ dtype=self.model.embeddings.tok_embeddings.weight.dtype) ### updated via _update_attention_mask() in the model
489
+
490
+ # Get embeddings for input batch
491
+ batch_outputs = self._call_model(
492
+ input_ids=input_ids,
493
+ attention_mask=attention_mask,
494
+ position_ids=position_ids,
495
+ return_dict=True
496
+ )
497
+ embeddings = batch_outputs["embeddings"] # [batch_size, hidden_size]
498
+
499
+ loss = None
500
+ similarities = None
501
+ if reference_input_ids is not None:
502
+
503
+ # Get embeddings for reference sentences
504
+ ref_outputs = self._call_model(
505
+ input_ids=reference_input_ids,
506
+ attention_mask=reference_attention_mask,
507
+ position_ids=position_ids, ### ?
508
+ return_dict=True
509
+ )
510
+ reference_embeddings = ref_outputs["embeddings"] # [num_references, hidden_size]
511
+
512
+ similarities, loss = compute_similarity_and_loss(
513
+ self.config,
514
+ input_ids,
515
+ embeddings,
516
+ reference_embeddings,
517
+ self._call_model,
518
+ similarity_scores,
519
+ negative_input_ids,
520
+ negative_attention_mask,
521
+ )
522
+
523
+ if not return_dict:
524
+ return (loss, similarities, embeddings)
525
+
526
+ return {
527
+ "loss": loss,
528
+ "similarities": similarities, # [batch_size, num_references]
529
+ "embeddings": embeddings, # [batch_size, hidden_size]
530
+ }
531
+
532
+ def sanitize(self, weights):
533
+ sanitized_weights = {}
534
+ for k, v in weights.items():
535
+ if "position_ids" in k:
536
+ # Remove unused position_ids
537
+ continue
538
+ if not k.startswith("model."):
539
+ continue
540
+ else:
541
+ sanitized_weights[k] = v
542
+ return sanitized_weights
543
+
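+ # Usage sketch (comments only): an inference-style call with one query batch and
+ # two reference sentences; ids are placeholders and weights are assumed loaded.
+ #
+ #   model = ModelForSentenceSimilarity(ModelArgs())
+ #   query_ids = mx.array([[50281, 100, 200, 50282]])            # [1, seq_len]
+ #   ref_ids = mx.array([[50281, 100, 201, 50282],
+ #                       [50281, 300, 400, 50282]])              # [2, seq_len]
+ #   out = model(query_ids, reference_input_ids=ref_ids)
+ #   out["similarities"].shape  # -> (1, 2), one score per reference (see compute_similarity_and_loss)
+ #   out["embeddings"].shape    # -> (1, 768)
+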
544
+ class ModelForSentenceTransformers(ModelForSentenceSimilarity):
545
+ """
546
+ Extends ModelForSentenceSimilarity.
547
+ Handles:
548
+ 1. Inference: Generates embeddings and similarity scores (cosine similarity or MaxSim if late interaction is used).
549
+ 2. Training (Standard): (Sentence1, Sentence2, Score) -> MSE/Cosine Loss.
550
+ 3. Training (Triplets): (Anchor, Positive, Negative) -> MNRL with Hard Negatives (Cross-entropy Loss).
551
+ This class sanitizes typical sentence transformers weights to align with the ModernBERT model.
552
+ """
553
+ def __init__(self, config: ModelArgs):
554
+ super().__init__(config)
555
+
556
+ def sanitize(self, weights):
557
+ """Convert sentence transformer weights to ModernBERT format."""
558
+ sanitized_weights = {}
559
+
560
+ for k, v in weights.items():
561
+ if "position_ids" in k:
562
+ # Remove unused position_ids
563
+ continue
564
+ if not k.startswith("model."):
565
+ new_key = "model." + k
566
+ else:
567
+ new_key = k
568
+ sanitized_weights[new_key] = v
569
+ return sanitized_weights
570
+
571
+
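+ # Sketch of the key mapping performed by sanitize() above (comments only; w1, w2
+ # and pid stand for weight arrays): sentence-transformers checkpoints store the
+ # encoder weights without the "model." prefix used here, so keys are prefixed
+ # and position_ids buffers are dropped.
+ #
+ #   weights = {"embeddings.tok_embeddings.weight": w1,
+ #              "layers.0.attn.Wqkv.weight": w2,
+ #              "embeddings.position_ids": pid}
+ #   ModelForSentenceTransformers(ModelArgs()).sanitize(weights)
+ #   # -> {"model.embeddings.tok_embeddings.weight": w1,
+ #   #     "model.layers.0.attn.Wqkv.weight": w2}
+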
572
+ class ModernBertPredictionHead(nn.Module):
573
+ def __init__(self, config : ModelArgs):
574
+ super().__init__()
575
+ self.dense = nn.Linear(
576
+ config.hidden_size, config.hidden_size, bias=False
577
+ ) ### current HF checkpoint does not have bias for the dense layer
578
+ self.act = nn.GELU()
579
+ self.norm = nn.LayerNorm(
580
+ config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
581
+ )
582
+
583
+ def __call__(self, hidden_states):
584
+ return self.norm(self.act(self.dense(hidden_states)))
585
+
586
+
587
+ class ModelForMaskedLM(RaclateBaseModel):
588
+ """
589
+ Computes masked language modeling (MLM) loss for input sequences.
590
+ """
591
+ def __init__(self, config : ModelArgs):
592
+ super().__init__()
593
+ self.config = config
594
+ self.model = ModernBertModel(config)
595
+ self.head = ModernBertPredictionHead(config) ## no bias for this in the current HF checkpoint
596
+ self.decoder = nn.Linear(
597
+ config.hidden_size, config.vocab_size, bias=config.decoder_bias
598
+ )
599
+
600
+ # transformer architecture name for compatibility
601
+ self.hf_transformers_arch = "ModernBertForMaskedLM"
602
+
603
+ # Tie weights ### tying via assignment is not preserved when loading weights, so sanitize() also copies the embedding weights into the decoder
604
+ self.tie_weights()
605
+
606
+ def tie_weights(self):
607
+ embedding_layer = self.model.get_input_embeddings()
608
+ self.decoder.weight = embedding_layer.weight
609
+
610
+ def get_input_embeddings(self):
611
+ return self.model.get_input_embeddings()
612
+
613
+ def get_output_embeddings(self):
614
+ return self.decoder
615
+
616
+ def set_input_embeddings(self, value):
617
+ self.model.set_input_embeddings(value)
618
+ self.tie_weights() # Re-tie weights after setting new embeddings
619
+
620
+ def set_output_embeddings(self, new_embeddings):
621
+ self.decoder = new_embeddings
622
+ self.tie_weights() # Re-tie weights after setting new decoder
623
+
624
+ def __call__(
625
+ self,
626
+ input_ids,
627
+ attention_mask: Optional[mx.array] = None,
628
+ labels: Optional[mx.array] = None,
629
+ position_ids: Optional[mx.array] = None,
630
+ output_hidden_states: Optional[bool] = None,
631
+ return_dict: Optional[bool] = True,
632
+ ) -> Dict:
633
+
634
+ if attention_mask is None:
635
+ batch_size, seq_len = input_ids.shape
636
+ attention_mask = mx.ones((batch_size, seq_len)) ### updated via _update_attention_mask() in the model
637
+
638
+ outputs = self.model(
639
+ input_ids=input_ids,
640
+ attention_mask=attention_mask,
641
+ position_ids=position_ids,
642
+ output_hidden_states=output_hidden_states,
643
+ return_dict=return_dict,
644
+ )
645
+
646
+ last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]
647
+ logits = self.head(last_hidden_state)
648
+ logits = self.decoder(logits)
649
+
650
+ loss = None
651
+ if self.training and labels is not None:
652
+ if getattr(self.config, "sparse_prediction", False):
653
+ # Flatten labels and predictions
654
+ flat_labels = labels.reshape(-1)
655
+ flat_predictions = logits.reshape(-1, logits.shape[-1])
656
+
657
+ # Filter out non-masked tokens
658
+ ignore_index = getattr(self.config, "sparse_pred_ignore_index", -100)
659
+ mask_tokens = flat_labels != ignore_index
660
+
661
+ # Only compute loss on masked tokens
662
+ masked_predictions = flat_predictions[mask_tokens]
663
+ masked_labels = flat_labels[mask_tokens]
664
+
665
+ loss = nn.losses.cross_entropy(
666
+ masked_predictions,
667
+ masked_labels,
668
+ reduction='mean'
669
+ )
670
+ else:
671
+ # Standard loss computation on all tokens
672
+ loss = nn.losses.cross_entropy(
673
+ logits.reshape(-1, logits.shape[-1]),
674
+ labels.reshape(-1),
675
+ reduction='mean'
676
+ )
677
+
678
+ if not return_dict:
679
+ return [loss, logits, outputs[1:]]
680
+
681
+ return {
682
+ "loss": loss,
683
+ "logits": logits,
684
+ "hidden_states": outputs.get("hidden_states", None),
685
+ }
686
+
687
+ def sanitize(self, weights):
688
+ sanitized_weights = {}
689
+ for k, v in weights.items():
690
+ if "position_ids" in k:
691
+ # Remove unused position_ids
692
+ continue
693
+ if k == "model.embeddings.tok_embeddings.weight":
694
+ ### working around the weight-tying issue by duplicating the embedding weights. TODO: improve this
695
+ sanitized_weights["decoder.weight"] = v
696
+ sanitized_weights[k] = v
697
+ else:
698
+ sanitized_weights[k] = v
699
+ return sanitized_weights
700
+
701
+
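+ # Training sketch for the sparse MLM loss above (comments only): labels mirror
+ # input_ids but carry -100 (sparse_pred_ignore_index) everywhere except at the
+ # masked positions, so cross-entropy is computed on masked tokens only. The
+ # mask token id below is a placeholder.
+ #
+ #   model = ModelForMaskedLM(ModelArgs())
+ #   model.train()
+ #   input_ids = mx.array([[50281, 100, 50284, 300, 50282]])  # 50284 = [MASK] placeholder
+ #   labels    = mx.array([[-100, -100,  200, -100, -100]])   # only position 2 contributes
+ #   out = model(input_ids, labels=labels)
+ #   out["loss"]          # scalar MLM loss over the single masked position
+ #   out["logits"].shape  # -> (1, 5, 50368)
+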
702
+ class ModelForSequenceClassification(RaclateBaseModel):
703
+ """
704
+ Computes sequence classification probabilities for input sequences.
705
+ Sanitization aligns typical BERT weights with the ModernBERT model.
706
+
707
+ NOTE : binary classification not tested.
708
+ """
709
+ def __init__(self, config: ModelArgs):
710
+ super().__init__()
711
+ self.config = config
712
+ self.num_labels = config.num_labels
713
+ self.is_regression = config.is_regression
714
+
715
+ self.model = ModernBertModel(config)
716
+ self.head = ModernBertPredictionHead(config)
717
+ self.drop = nn.Dropout(p=config.classifier_dropout)
718
+ self.classifier = nn.Linear(
719
+ config.hidden_size,
720
+ config.num_labels,
721
+ )
722
+
723
+ # transformer architecture name for compatibility
724
+ self.hf_transformers_arch = "ModernBertForSequenceClassification"
725
+
726
+ def _process_outputs(self, logits: mx.array) -> mx.array:
727
+ """Apply the appropriate activation function to the logits."""
728
+ if self.is_regression:
729
+ return logits # No activation for regression
730
+ elif self.num_labels == 1:
731
+ return mx.sigmoid(logits) # Binary classification
732
+ else:
733
+ # Using softmax for multi-class classification
734
+ return mx.softmax(logits, axis=-1)
735
+
736
+ def _compute_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
737
+ """Compute the appropriate loss based on label characteristics."""
738
+ if self.is_regression:
739
+ return nn.losses.mse_loss(logits.squeeze(), labels.squeeze())
740
+ elif self.num_labels == 1:
741
+ return nn.losses.binary_cross_entropy(logits.squeeze(), labels.squeeze(), reduction='mean')  # raw logits: with_logits defaults to True
742
+ else:
743
+ return nn.losses.cross_entropy(
744
+ logits.reshape(-1, self.num_labels),
745
+ labels.reshape(-1), reduction='mean'
746
+ )
747
+
748
+ def __call__(
749
+ self,
750
+ input_ids,
751
+ attention_mask: Optional[mx.array] = None,
752
+ position_ids: Optional[mx.array] = None, ### need this?
753
+ labels: Optional[mx.array] = None,
754
+ output_hidden_states: Optional[bool] = False,
755
+ return_dict: Optional[bool] = True,
756
+ ) -> Dict:
757
+
758
+ if attention_mask is None:
759
+ batch_size, seq_len = input_ids.shape
760
+ attention_mask = mx.ones((batch_size, seq_len))
761
+
762
+ outputs = self.model(
763
+ input_ids=input_ids,
764
+ attention_mask=attention_mask,
765
+ position_ids=position_ids,
766
+ output_hidden_states=output_hidden_states,
767
+ return_dict=return_dict,
768
+ )
769
+
770
+ last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]
771
+
772
+ # Pooling strategy
773
+ if self.config.classifier_pooling == "cls":
774
+ pooled = last_hidden_state[:, 0]
775
+ elif self.config.classifier_pooling == "mean":
776
+ pooled = mean_pooling(last_hidden_state, attention_mask)
777
+
778
+ # Apply head, dropout and classifier
779
+ pooled = self.head(pooled)
780
+ pooled = self.drop(pooled)
781
+ logits = self.classifier(pooled)
782
+
783
+ # Process logits for inference
784
+ processed_logits = self._process_outputs(logits)
785
+
786
+ loss = None
787
+ if labels is not None:
788
+ loss = self._compute_loss(logits, labels)
789
+
790
+ if not return_dict:
791
+ return [loss, processed_logits, outputs[1:]]
792
+
793
+ return {
794
+ "loss": loss,
795
+ "probabilities": processed_logits,
796
+ "hidden_states": outputs.get("hidden_states", None),
797
+ }
798
+
799
+ def sanitize(self, weights):
800
+ sanitized_weights = {}
801
+ for k, v in weights.items():
802
+ if "position_ids" in k:
803
+ # Remove unused position_ids
804
+ continue
805
+ if k in ["decoder.bias"]:
806
+ ### drop the MLM decoder bias; it has no counterpart in this classification head
807
+ continue
808
+ elif k.startswith("bert"):
809
+ # Handle legacy BERT naming if needed
810
+ new_k = k.replace("bert.", "model.")
811
+ sanitized_weights[new_k] = v
812
+ else:
813
+ sanitized_weights[k] = v
814
+ return sanitized_weights
815
+
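+ # Usage sketch (comments only): a 3-way classifier config; ids are placeholders
+ # and trained weights are assumed to be loaded via the surrounding pipeline.
+ #
+ #   config = ModelArgs(id2label={0: "neg", 1: "neu", 2: "pos"},
+ #                      label2id={"neg": 0, "neu": 1, "pos": 2})
+ #   clf = ModelForSequenceClassification(config)
+ #   out = clf(mx.array([[50281, 100, 200, 50282]]))
+ #   out["probabilities"].shape  # -> (1, 3), softmax over the 3 labels
+ #   out["loss"]                 # -> None (no labels passed)
+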
816
+ class ModelForTokenClassification(RaclateBaseModel):
817
+ """
818
+ Computes token classification probabilities for input sequences.
819
+
820
+ NOTE: untested for now
821
+ TODO : https://huggingface.co/disham993/electrical-ner-ModernBERT-base
822
+ """
823
+ def __init__(self, config: ModelArgs):
824
+ super().__init__()
825
+ self.config = config
826
+ self.num_labels = config.num_labels
827
+
828
+ self.model = ModernBertModel(config)
829
+ self.head = ModernBertPredictionHead(config)
830
+ self.drop = nn.Dropout(p=config.classifier_dropout)
831
+ self.classifier = nn.Linear(
832
+ config.hidden_size,
833
+ config.num_labels,
834
+ # bias=config.classifier_bias
835
+ )
836
+
837
+ # transformer architecture name for compatibility
838
+ self.hf_transformers_arch = "ModernBertForTokenClassification"
839
+
840
+
841
+ def __call__(
842
+ self,
843
+ input_ids,
844
+ attention_mask: Optional[mx.array] = None,
845
+ position_ids: Optional[mx.array] = None,
846
+ labels: Optional[mx.array] = None,
847
+ output_hidden_states: Optional[bool] = None,
848
+ return_dict: Optional[bool] = True,
849
+ ) -> Dict:
850
+ if attention_mask is None:
851
+ batch_size, seq_len = input_ids.shape
852
+ attention_mask = mx.ones((batch_size, seq_len))
853
+
854
+ outputs = self.model(
855
+ input_ids=input_ids,
856
+ attention_mask=attention_mask,
857
+ position_ids=position_ids,
858
+ output_hidden_states=output_hidden_states,
859
+ return_dict=return_dict,
860
+ )
861
+
862
+ last_hidden_state = outputs["last_hidden_state"] if return_dict else outputs[0]
863
+
864
+ # Apply prediction head, dropout, and classification layer to each token
865
+ sequence_output = self.head(last_hidden_state)
866
+ sequence_output = self.drop(sequence_output)
867
+ logits = self.classifier(sequence_output)
868
+
869
+ # Process logits for inference
870
+ processed_logits = mx.softmax(logits, axis=-1)
871
+
872
+ loss = None
873
+ if labels is not None:
874
+ # Compute token classification loss
875
+ loss = nn.losses.cross_entropy(
876
+ logits.reshape(-1, self.num_labels),
877
+ labels.reshape(-1), reduction='mean'
878
+ )
879
+
880
+ if not return_dict:
881
+ return [loss, processed_logits, outputs[1:]]
882
+
883
+ return {
884
+ "loss": loss,
885
+ "probabilities": processed_logits,
886
+ "hidden_states": outputs.get("hidden_states", None),
887
+ }
888
+
889
+ def sanitize(self, weights):
890
+ sanitized_weights = {}
891
+ for k, v in weights.items():
892
+ if "position_ids" in k:
893
+ # Remove unused position_ids
894
+ continue
895
+ if k in ["decoder.bias"]:
896
+ ### drop the MLM decoder bias; it has no counterpart in this classification head
897
+ continue
898
+ else:
899
+ sanitized_weights[k] = v
900
+ return sanitized_weights
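+
+ # Usage sketch (comments only): an illustrative NER-style config; ids and labels
+ # are placeholders, and real weights/tokenization are assumed elsewhere.
+ #
+ #   config = ModelArgs(id2label={0: "O", 1: "B-ORG", 2: "I-ORG"},
+ #                      label2id={"O": 0, "B-ORG": 1, "I-ORG": 2})
+ #   tagger = ModelForTokenClassification(config)
+ #   out = tagger(mx.array([[50281, 100, 200, 300, 50282]]))
+ #   out["probabilities"].shape  # -> (1, 5, 3), per-token softmax over the 3 tags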