mlx-raclate 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,582 @@
1
+ # Copyright © 2023-2024 Apple Inc.
2
+ import logging
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List, Optional, Union, Any, Literal
5
+
6
+ import mlx.core as mx
7
+ import mlx.nn as nn
8
+ from .base import (
9
+ BaseModelArgs,
10
+ RaclateBaseModel,
11
+ last_token_pooling,
12
+ normalize_embeddings,
13
+ compute_similarity_and_loss,
14
+ )
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ @dataclass
20
+ class ModelArgs(BaseModelArgs):
21
+ architectures: List[str] = field(default_factory=lambda: ["Qwen3Model"])
22
+ attention_bias: Optional[bool] = False
23
+ attention_dropout: Optional[float] = 0.0
24
+ bos_token_id: Optional[int] = None
25
+ eos_token_id: Optional[int] = None
26
+ head_dim: int = 128
27
+ hidden_act: Optional[str] = "silu"
28
+ hidden_size: int = 1024
29
+ initializer_range: Optional[float] = (
30
+ 0.02 # Only needed when initializing weights from scratch
31
+ )
32
+ intermediate_size: int = 3072
33
+ max_position_embeddings: int = 32768
34
+ max_window_layers: Optional[int] = 28
35
+ model_type: str = "qwen3"
36
+ num_attention_heads: int = 16
37
+ num_hidden_layers: int = 28
38
+ num_key_value_heads: int = 8
39
+ rms_norm_eps: float = 1.0e-6
40
+ rope_scaling: Optional[Dict[str, Union[float, str]]] = None
41
+ rope_theta: float = 1000000.0
42
+ tie_word_embeddings: bool = True
43
+ vocab_size: int = 151669
44
+
45
+ ### pipeline args
46
+ decoder_bias: bool = True
47
+ classifier_dropout: float = 0.0
48
+ classifier_bias: bool = False
49
+ sparse_prediction: bool = True ### True seems a more appropriate value for MLM
50
+ sparse_pred_ignore_index: int = -100
51
+ is_regression: Optional[bool] = None
52
+ label2id: Optional[Dict[str, int]] = None
53
+ id2label: Optional[Dict[int, str]] = None
54
+ pipeline_config: Optional[Dict[str, Any]] = None # for Sequence Classification
55
+ use_late_interaction: bool = False
56
+
57
+ @property
58
+ def num_labels(self) -> int:
59
+ """
60
+ Number of labels is determined by:
61
+ - For zero-shot classification: length of label_candidates
62
+ - For regression or binary with sigmoid: 1
63
+ - For classification: length of id2label mapping
64
+ """
65
+
66
+ if self.is_regression:
67
+ return 1
68
+
69
+ if self.pipeline_config and self.pipeline_config.get("binary_sigmoid", False):
70
+ return 1
71
+
72
+ if self.id2label is None:
73
+ raise ValueError(
74
+ "id2label mapping must be provided for categorical classification. "
75
+ "For regression or binary classification with sigmoid output, "
76
+ "set is_regression=True or binary_sigmoid=True in pipeline_config."
77
+ )
78
+
79
+ return len(self.id2label)
80
+
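+ # Illustrative examples of how num_labels is resolved (label names are made up and
+ # we assume BaseModelArgs adds no required fields):
+ #
+ #   ModelArgs(is_regression=True).num_labels                        -> 1
+ #   ModelArgs(pipeline_config={"binary_sigmoid": True}).num_labels  -> 1
+ #   ModelArgs(id2label={0: "neg", 1: "pos"}).num_labels             -> 2
+ #   ModelArgs().num_labels                                          -> raises ValueError
+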
81
+ def _sanitize_backbone(weights: Dict[str, Any]) -> Dict[str, Any]:
82
+ """
83
+ Standardizes keys for the Qwen3 embedding backbone.
84
+ """
85
+ # no need for lm_head.weight in Qwen3 for embedding models
86
+ sanitized_weights = {}
87
+
88
+ for key, value in weights.items():
89
+ # Skip language model head weights (not used for embeddings)
90
+ if "lm_head.weight" in key or "classifier.weight" in key:
91
+ continue
92
+
93
+ # Handle different checkpoint formats
94
+ new_key = key
95
+
96
+ # Map common parameter naming patterns
97
+ if key.startswith("transformer."):
98
+ # Some checkpoints use "transformer." prefix
99
+ new_key = key.replace("transformer.", "model.")
100
+ # Handle weights without any prefix
101
+ elif not key.startswith("model.") and not key.startswith("score."):
102
+ # Add model prefix for transformer parameters
103
+ new_key = f"model.{key}"
104
+ else:
105
+ # Keep as is for other parameters
106
+ new_key = key
107
+
108
+ sanitized_weights[new_key] = value
109
+
110
+ return sanitized_weights
111
+
112
+
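+ # Example key mappings produced by _sanitize_backbone (hypothetical checkpoint keys):
+ #
+ #   "transformer.layers.0.self_attn.q_proj.weight" -> "model.layers.0.self_attn.q_proj.weight"
+ #   "layers.0.mlp.gate_proj.weight"                -> "model.layers.0.mlp.gate_proj.weight"
+ #   "model.norm.weight"                            -> "model.norm.weight" (unchanged)
+ #   "lm_head.weight"                               -> dropped
+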
113
+ class Attention(nn.Module):
114
+ def __init__(self, config: ModelArgs):
115
+ super().__init__()
116
+
117
+ dim = config.hidden_size
118
+ self.n_heads = n_heads = config.num_attention_heads
119
+ assert config.num_key_value_heads is not None
120
+ self.n_kv_heads = n_kv_heads = config.num_key_value_heads
121
+
122
+ head_dim = config.head_dim
123
+ self.scale = head_dim**-0.5
124
+
125
+ self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
126
+ self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
127
+ self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
128
+ self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
129
+
130
+ self.q_norm = nn.RMSNorm(head_dim, eps=config.rms_norm_eps)
131
+ self.k_norm = nn.RMSNorm(head_dim, eps=config.rms_norm_eps)
132
+ self.rope = nn.RoPE(dims=head_dim, base=config.rope_theta)
133
+
134
+ def __call__(
135
+ self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None
136
+ ) -> mx.array:
137
+ B, L, D = hidden_states.shape
138
+
139
+ queries, keys, values = (
140
+ self.q_proj(hidden_states),
141
+ self.k_proj(hidden_states),
142
+ self.v_proj(hidden_states),
143
+ )
144
+
145
+ queries = self.q_norm(queries.reshape(B, L, self.n_heads, -1)).transpose(
146
+ 0, 2, 1, 3
147
+ )
148
+
149
+ keys = self.k_norm(keys.reshape(B, L, self.n_kv_heads, -1)).transpose(
150
+ 0, 2, 1, 3
151
+ )
152
+
153
+ values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
154
+
155
+ queries = self.rope(queries)
156
+ keys = self.rope(keys)
157
+
158
+ output = mx.fast.scaled_dot_product_attention(
159
+ queries, keys, values, scale=self.scale, mask=attention_mask
160
+ )
161
+
162
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
163
+
164
+ hidden_states = self.o_proj(output)
165
+
166
+ return (hidden_states,)
167
+
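+ # Shape walk-through for the attention block above (B = batch, L = sequence length):
+ # after the reshape/transpose, queries are (B, n_heads, L, head_dim) and keys/values are
+ # (B, n_kv_heads, L, head_dim); mx.fast.scaled_dot_product_attention broadcasts the
+ # key/value heads for grouped-query attention when n_kv_heads < n_heads, and the additive
+ # attention_mask is applied to the pre-softmax scores.
+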
168
+
169
+ class MLP(nn.Module):
170
+ def __init__(self, dim, hidden_dim):
171
+ super().__init__()
172
+ self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
173
+ self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
174
+ self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
175
+
176
+ def __call__(self, x) -> mx.array:
177
+ return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
178
+
179
+
180
+ class TransformerBlock(nn.Module):
181
+ def __init__(self, config: ModelArgs):
182
+ super().__init__()
183
+ self.num_attention_heads = config.num_attention_heads
184
+ self.hidden_size = config.hidden_size
185
+ self.self_attn = Attention(config)
186
+ self.mlp = MLP(config.hidden_size, config.intermediate_size)
187
+ self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
188
+ self.post_attention_layernorm = nn.RMSNorm(
189
+ config.hidden_size, eps=config.rms_norm_eps
190
+ )
191
+ self.config = config
192
+
193
+ def __call__(
194
+ self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None
195
+ ) -> mx.array:
196
+ attention_output = self.self_attn(
197
+ self.input_layernorm(hidden_states), attention_mask
198
+ )
199
+ hidden_states = hidden_states + attention_output[0]
200
+ mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
201
+ hidden_states = mlp_output + hidden_states
202
+ return (hidden_states,)
203
+
204
+
205
+ class Qwen3Model(nn.Module):
206
+ def __init__(self, config: ModelArgs):
207
+ super().__init__()
208
+ self.config = config
209
+ self.vocab_size = config.vocab_size
210
+ self.num_hidden_layers = config.num_hidden_layers
211
+ assert self.vocab_size > 0
212
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
213
+ self.layers = [
214
+ TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
215
+ ]
216
+ self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
217
+
218
+ def get_input_embeddings(self) -> nn.Embedding:
219
+ return self.embed_tokens
220
+
221
+ def set_input_embeddings(self, value):
222
+ self.embed_tokens = value
223
+
224
+ def _update_attention_mask(self, attention_mask: mx.array, dtype=None):
225
+ """
226
+ Builds an additive causal mask and combines it with the (B, L) padding mask.
227
+ """
228
+
229
+ B, L = attention_mask.shape
230
+
231
+ # Upper-triangular additive mask blocks attention to future positions
232
+ causal_mask = mx.triu(mx.full((L, L), -1e9, dtype), k=1)
233
+
234
+ # Reshape padding mask from (B, L) to (B, 1, 1, L) to be broadcastable
235
+ padding_mask = attention_mask[:, None, None, :]
236
+ additive_padding_mask = mx.where(padding_mask == 0, -1e9, 0.0).astype(dtype)
237
+
238
+ causal_mask = causal_mask + additive_padding_mask
239
+
240
+ return causal_mask.astype(dtype)
241
+
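+ # Worked example (illustrative) for L = 3 with the last position padded:
+ #   causal part:  [[0, -1e9, -1e9],      padding part (broadcast over rows):
+ #                  [0,    0, -1e9],        [[0, 0, -1e9]]
+ #                  [0,    0,    0]]
+ # Their sum hides both future positions and padded keys before the softmax.
+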
242
+ def __call__(
243
+ self,
244
+ input_ids: mx.array,
245
+ attention_mask: Optional[mx.array] = None,
246
+ output_hidden_states: Optional[bool] = False,
247
+ position_ids: Optional[mx.array] = None,
248
+ return_dict: Optional[bool] = True
249
+ ):
250
+
251
+ hidden_states = self.embed_tokens(input_ids)
252
+ model_dtype = hidden_states.dtype
253
+
254
+ if attention_mask is None:
255
+ attention_mask = mx.ones(input_ids.shape, dtype=model_dtype)
256
+
257
+ attention_mask = self._update_attention_mask(attention_mask, dtype=model_dtype)
258
+
259
+ for layer in self.layers:
260
+ layer_outputs = layer(hidden_states, attention_mask)
261
+ hidden_states = layer_outputs[0]
262
+
263
+ hidden_states = self.norm(hidden_states)
264
+
265
+ return {
266
+ "last_hidden_state": hidden_states,
267
+ }
268
+
269
+ # Not used for now
270
+ class Qwen3PredictionHead(nn.Module):
271
+ def __init__(self, config: ModelArgs):
272
+ super().__init__()
273
+ self.config = config
274
+ self.dense = nn.Linear(
275
+ config.hidden_size, config.hidden_size, config.classifier_bias
276
+ )
277
+ self.act = nn.GELU(approx="precise")
278
+ self.norm = nn.RMSNorm(
279
+ config.hidden_size, eps=config.rms_norm_eps
280
+ )
281
+
282
+ def __call__(self, hidden_states: mx.array) -> mx.array:
283
+ return self.norm(self.act(self.dense(hidden_states)))
284
+
285
+
286
+ class Model(RaclateBaseModel):
287
+ def __init__(self, config: ModelArgs):
288
+ super().__init__()
289
+ self.config = config
290
+ self.model = Qwen3Model(config)
291
+
292
+ # transformer architecture name for compatibility
293
+ self.hf_transformers_arch = "Qwen3ForCausalLM"
294
+
295
+ def __call__(
296
+ self,
297
+ input_ids: mx.array,
298
+ position_ids: Optional[mx.array] = None,
299
+ attention_mask: Optional[mx.array] = None,
300
+ output_hidden_states: Optional[bool] = False,
301
+ return_dict: Optional[bool] = True,
302
+ ) -> Dict:
303
+ if attention_mask is None:
304
+ batch_size, seq_len = input_ids.shape
305
+ attention_mask = mx.ones(
306
+ (batch_size, seq_len),
307
+ dtype=self.model.embed_tokens.weight.dtype,
308
+ )
309
+
310
+ out = self.model(input_ids, attention_mask)
311
+ last_hidden_state = (
312
+ out["last_hidden_state"] if isinstance(out, dict) else out[0]
313
+ )
314
+
315
+ # pooling for AR models such as Qwen3 leverages the last token
316
+ pooled_embeddings = last_token_pooling(last_hidden_state, attention_mask)
317
+ text_embeds = normalize_embeddings(pooled_embeddings)
318
+
319
+ if not return_dict:
320
+ return (text_embeds, last_hidden_state)
321
+
322
+ return {
323
+ "embeddings": text_embeds, # normalized embeddings
324
+ "last_hidden_state": last_hidden_state,
325
+ }
326
+
327
+ def sanitize(self, weights):
328
+
329
+ sanitized_weights = _sanitize_backbone(weights)
330
+
331
+ # Keep only backbone ("model.") weights; drop any remaining head or extra keys
332
+ final_weights = {}
333
+ for k, v in sanitized_weights.items():
334
+
335
+ if not k.startswith("model."):
336
+ continue
337
+ final_weights[k] = v
338
+
339
+ return final_weights
340
+
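+ # last_token_pooling and normalize_embeddings come from .base; a rough sketch of the
+ # intended behaviour (the real implementation may differ) is:
+ #
+ #   def _last_token_pooling_sketch(last_hidden_state, attention_mask):
+ #       # index of the last non-padding token per sequence
+ #       last_idx = attention_mask.sum(axis=1).astype(mx.int32) - 1
+ #       return last_hidden_state[mx.arange(last_hidden_state.shape[0]), last_idx]
+ #
+ #   def _normalize_embeddings_sketch(x, eps=1e-9):
+ #       return x / (mx.linalg.norm(x, axis=-1, keepdims=True) + eps)
+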
341
+
342
+ class ModelForSentenceSimilarity(RaclateBaseModel):
343
+ """
344
+ Computes similarity scores between input sequences and reference sentences.
345
+ """
346
+ def __init__(self, config : ModelArgs):
347
+ super().__init__()
348
+ self.config = config
349
+ self.model_type = config.model_type # not used for now (placeholder)
350
+ self.model = Qwen3Model(config)
351
+
352
+ def _call_model(self, input_ids, attention_mask=None, return_dict=True):
353
+ out = self.model(input_ids, attention_mask)
354
+ last_hidden_state = (
355
+ out["last_hidden_state"] if isinstance(out, dict) else out[0]
356
+ )
357
+
358
+ # text_embeds = normalize_embeddings(last_hidden_state)
359
+ if self.config.use_late_interaction:
360
+ text_embeds = normalize_embeddings(last_hidden_state)
361
+ # Keep unpooled for ColBERT style
362
+ # Mask padding tokens to avoid them affecting MaxSim
363
+ if attention_mask is not None:
364
+ text_embeds = text_embeds * attention_mask[..., None]
365
+ else:
366
+ # Standard causal model retrieval: Last Token Pooling
367
+ text_embeds = last_token_pooling(last_hidden_state, attention_mask)
368
+ text_embeds = normalize_embeddings(text_embeds)
369
+
370
+ if not return_dict:
371
+ return (text_embeds, last_hidden_state)
372
+
373
+ return {
374
+ "embeddings": text_embeds, # normalized embeddings
375
+ "last_hidden_state": last_hidden_state,
376
+ }
377
+
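+ # In late-interaction mode the per-token embeddings above are kept; as the comments in
+ # _call_model suggest, downstream scoring (compute_similarity_and_loss from .base) is then
+ # presumably ColBERT-style MaxSim: score(q, d) = sum over query tokens i of
+ # max over document tokens j of q_i . d_j.
+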
378
+ def __call__(
379
+ self,
380
+ input_ids,
381
+ reference_input_ids : Optional[mx.array] = None, # Shape: [num_references, seq_len]
382
+ negative_input_ids : Optional[mx.array] = None, # Shape: [num_negatives, seq_len]
383
+ attention_mask: Optional[mx.array] = None,
384
+ reference_attention_mask: Optional[mx.array] = None,
385
+ negative_attention_mask: Optional[mx.array] = None,
386
+ similarity_scores: Optional[mx.array] = None, # Shape: [batch_size, num_references]
387
+ position_ids: Optional[mx.array] = None,
388
+ return_dict: Optional[bool] = True,
389
+ ):
390
+
391
+ if attention_mask is None:
392
+ batch_size, seq_len = input_ids.shape
393
+ attention_mask = mx.ones(
394
+ (batch_size, seq_len),
395
+ dtype=self.model.embed_tokens.weight.dtype,
396
+ )
397
+
398
+ # Get embeddings for input batch
399
+ batch_outputs = self._call_model(
400
+ input_ids=input_ids,
401
+ attention_mask=attention_mask,
402
+ return_dict=True
403
+ )
404
+ embeddings = batch_outputs["embeddings"] # [batch_size, hidden_size]
405
+
406
+ loss = None
407
+ similarities = None
408
+
409
+ if reference_input_ids is not None:
410
+
411
+ # Get embeddings for reference sentences
412
+ ref_outputs = self._call_model(
413
+ input_ids=reference_input_ids,
414
+ attention_mask=reference_attention_mask,
415
+ return_dict=True
416
+ )
417
+ reference_embeddings = ref_outputs["embeddings"] # [num_references, hidden_size]
418
+
419
+ similarities, loss = compute_similarity_and_loss(
420
+ self.config,
421
+ input_ids,
422
+ embeddings,
423
+ reference_embeddings,
424
+ self._call_model,
425
+ similarity_scores,
426
+ negative_input_ids,
427
+ negative_attention_mask,
428
+ )
429
+
430
+ if not return_dict:
431
+ return (loss, similarities, embeddings)
432
+
433
+ return {
434
+ "loss": loss,
435
+ "similarities": similarities, # [batch_size, num_references]
436
+ "embeddings": embeddings, # [batch_size, hidden_size]
437
+ }
438
+
439
+ def sanitize(self, weights):
440
+
441
+ sanitized_weights = _sanitize_backbone(weights)
442
+
443
+ # Keep only backbone ("model.") weights; drop any remaining head or extra keys
444
+ final_weights = {}
445
+ for k, v in sanitized_weights.items():
446
+
447
+ if not k.startswith("model."):
448
+ continue
449
+ final_weights[k] = v
450
+
451
+ return final_weights
452
+
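+ # Hypothetical usage sketch (tokenization and weight loading omitted; names and shapes
+ # are illustrative):
+ #
+ #   model = ModelForSentenceSimilarity(config)
+ #   out = model(
+ #       input_ids=query_ids,                  # [batch_size, seq_len]
+ #       attention_mask=query_mask,
+ #       reference_input_ids=doc_ids,          # [num_references, seq_len]
+ #       reference_attention_mask=doc_mask,
+ #   )
+ #   out["similarities"]  # [batch_size, num_references]
+ #   out["embeddings"]    # [batch_size, hidden_size]
+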
453
+ class ModelForSentenceTransformers(ModelForSentenceSimilarity):
454
+ """
455
+ Extends ModelForSentenceSimilarity to provide embeddings for input sequences.
456
+ This class sanitizes typical sentence transformers weights to align with the Qwen3 model.
457
+ """
458
+ def __init__(self, config: ModelArgs):
459
+ super().__init__(config)
460
+
461
+ def sanitize(self, weights):
462
+ """Convert sentence transformer weights to Qwen3 format."""
463
+ sanitized_weights = {}
464
+
465
+ for k, v in weights.items():
466
+ if "position_ids" in k:
467
+ # Remove unused position_ids
468
+ continue
469
+ else:
470
+ new_key = "model." + k
471
+ sanitized_weights[new_key] = v
472
+ return sanitized_weights
473
+
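+ # Example: a sentence-transformers checkpoint key such as "layers.0.self_attn.q_proj.weight"
+ # (assumed input format) becomes "model.layers.0.self_attn.q_proj.weight".
+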
474
+ class ModelForSequenceClassification(RaclateBaseModel):
475
+ """
476
+ Computes sequence classification probabilities for input sequences.
477
+ Sanitization aligns backbone checkpoint weights with HF's Qwen3ForSequenceClassification architecture.
478
+
479
+ NOTE: regression and binary classification are not tested.
480
+ """
481
+ def __init__(self, config: ModelArgs):
482
+ super().__init__()
483
+ self.config = config
484
+ self.num_labels = config.num_labels
485
+ self.is_regression = config.is_regression
486
+
487
+ self.model = Qwen3Model(config)
488
+
489
+ ### The HF architecture Qwen3ForSequenceClassification
490
+ ### does not have head and drop
491
+ #### and uses 'score' as the final layer name
492
+ # self.head = Qwen3PredictionHead(config)
493
+ # self.drop = nn.Dropout(p=config.classifier_dropout)
494
+
495
+ self.score = nn.Linear(
496
+ config.hidden_size,
497
+ config.num_labels,
498
+ bias=False
499
+ )
500
+
501
+ self.hf_transformers_arch = "Qwen3ForSequenceClassification"
502
+
503
+ def _process_outputs(self, logits: mx.array) -> mx.array:
504
+ """Apply the appropriate activation function to the logits."""
505
+ if self.is_regression:
506
+ return logits # No activation for regression
507
+ elif self.num_labels == 1:
508
+ return mx.sigmoid(logits) # Binary classification
509
+ else:
510
+ # Using softmax for multi-class classification
511
+ return mx.softmax(logits, axis=-1)
512
+
513
+ def _compute_loss(self, logits: mx.array, labels: mx.array) -> mx.array:
514
+ """Compute the appropriate loss based on label characteristics."""
515
+ if self.is_regression:
516
+ return nn.losses.mse_loss(logits.squeeze(), labels.squeeze())
517
+ elif self.num_labels == 1:
518
+ return nn.losses.binary_cross_entropy(logits, labels) # expects raw logits (with_logits=True by default)
519
+ else:
520
+ return nn.losses.cross_entropy(
521
+ logits.reshape(-1, self.num_labels),
522
+ labels.reshape(-1), reduction="mean" # cross_entropy defaults to reduction="none"
523
+ )
524
+
525
+ def __call__(
526
+ self,
527
+ input_ids,
528
+ attention_mask: Optional[mx.array] = None,
529
+ position_ids: Optional[mx.array] = None, ### need this?
530
+ labels: Optional[mx.array] = None,
531
+ output_hidden_states: Optional[bool] = False,
532
+ return_dict: Optional[bool] = True,
533
+ ) -> Dict:
534
+ if attention_mask is None:
535
+ batch_size, seq_len = input_ids.shape
536
+ attention_mask = mx.ones(
537
+ (batch_size, seq_len),
538
+ dtype=self.model.embed_tokens.weight.dtype,
539
+ )
540
+
541
+ outputs = self.model(
542
+ input_ids,
543
+ attention_mask,
544
+ position_ids=position_ids,
545
+ output_hidden_states=output_hidden_states,
546
+ return_dict=return_dict
547
+ )
548
+ last_hidden_state = (
549
+ outputs["last_hidden_state"] if isinstance(outputs, dict) else outputs[0]
550
+ )
551
+
552
+ # pooling for AR models such as Qwen3 leverages the last token
553
+ pooled = last_token_pooling(last_hidden_state, attention_mask)
554
+
555
+ ### The HF architecture Qwen3ForSequenceClassification
556
+ ### does not have head and drop
557
+ #### and uses 'score' as the final layer name
558
+ # pooled = self.head(pooled)
559
+ # pooled = self.drop(pooled)
560
+ logits = self.score(pooled)
561
+
562
+ processed_logits = self._process_outputs(logits)
563
+
564
+ loss = None
565
+ if labels is not None :
566
+ loss = self._compute_loss(logits, labels)
567
+
568
+ if not return_dict:
569
+ return (loss, processed_logits, last_hidden_state)
570
+
571
+ return {
572
+ "loss": loss,
573
+ "probabilities": processed_logits,
574
+ "hidden_states": outputs.get("hidden_states", None),
575
+ }
576
+
577
+ def sanitize(self, weights):
578
+
579
+ return _sanitize_backbone(weights)
580
+
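+ # Hypothetical usage sketch (labels and tensors are illustrative):
+ #
+ #   args = ModelArgs(id2label={0: "negative", 1: "positive"},
+ #                    label2id={"negative": 0, "positive": 1})
+ #   clf = ModelForSequenceClassification(args)
+ #   out = clf(input_ids=batch_ids, attention_mask=batch_mask, labels=batch_labels)
+ #   out["probabilities"]  # softmax over the two classes
+ #   out["loss"]           # cross-entropy, present only when labels are passed
+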
581
+ # TokenClassification and MaskedLM are not implemented for now for AR models such as Qwen3;
582
+ # attempting to train them from pretrained causal-LM weights would be catastrophic.