mlx-raclate 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlx_raclate/py.typed ADDED
File without changes
@@ -0,0 +1,305 @@
+ # RACLATE Tuner
+
+ The `tuner` module is the training engine of RACLATE (**R**etrieval **A**nd **C**lassification including **LATE** interaction models). Built entirely on Apple's [MLX](https://github.com/ml-explore/mlx) framework, it provides a highly efficient, unified interface for fine-tuning *small* Transformer-based classifiers on Apple Silicon.
+
+ This trainer supports standard dense retrieval, classification, and masked language modeling, as well as **Late Interaction (ColBERT-style)** training patterns.
+
+ ## Key Features
+
+ * **Apple Silicon Native:** Fully optimized for M-series chips using MLX.
+ * **Full Training:** Full fine-tuning of pretrained models (_see supported architectures below_). LoRA fine-tuning is not supported (yet). The library allows transfer learning: existing heads can be stripped out of pretrained models, and new heads can be added to base models for specific tasks.
+ * **Memory Efficiency:** Built-in support for **Gradient Accumulation** and **Gradient Checkpointing** to train larger batches/models on limited Unified Memory.
+ * **Flexible Schedulers:** Linear, Cosine, and Constant learning rate schedules with warmup.
+ * **Smart Collators:** Task-specific data collators that handle padding, masking, and chat templates automatically.
+ * **Embedding Freezing:** Option to freeze embedding layers to speed up fine-tuning or prevent catastrophic forgetting.
+ * **HF Hub Integration (TODO):** Seamless saving and pushing of checkpoints to the Hugging Face Hub.
+
+ ## Supported Architectures
+
+ The trainer supports a variety of modern architectures with long context windows (relative to BERT models). As these models are meant to be trained and run on local machines, the implementations are specifically optimized for small-to-mid-sized models:
+
+ * **ModernBERT**: MLX implementation of `answerdotai/ModernBERT-base` (encoder-only). Long context (8k) and high efficiency.
+ * **Qwen 3**: MLX implementation of `Qwen/Qwen3-Embedding-0.6B` (32k context window), which leverages the qwen3 architecture.
+ * **Gemma 3**: MLX implementation of `google/embeddinggemma-300m` (2k context window), which leverages the gemma3 text variant architecture with a few tweaks. As in the official embeddinggemma architecture, the attention mask is set to causal or bi-directional based on a config parameter (`use_bidirectional_attn` or `use_bidirectional_attention`). It is therefore possible to switch between encoder and decoder mode, and standard gemma3_text models (32k context window) are also supported.
+ * **T5Gemma-Encoder**: MLX implementation of `google/t5gemma-b-b-ul2`, keeping only the encoder weights at initialization (the encoder config is merged into the main model config).
+ * **LFM2**: MLX implementation of `LiquidAI/LFM2-350M` (causal/autoregressive), which also supports `LiquidAI/LFM2-ColBERT-350M` when the model config file includes `use_late_interaction=True`. These models have a context window of 128k tokens; in training mode, 128k tokens exceeds the RAM capacity of most Apple hardware. _See the sketch below and the training parameters further down to cap sequences to a more reasonable length._
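+
+ For instance, a minimal sketch capping sequence length during training via `TrainingArgs.max_length` (documented below; the output directory is illustrative):
+
+ ```python
+ from mlx_raclate.tuner.trainer import TrainingArgs
+
+ # Cap sequences at 4k tokens instead of LFM2's 128k context window
+ args = TrainingArgs(
+     output_dir="outputs/lfm2_run",
+     max_length=4096
+ )
+ ```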
+
+ ## Supported Tasks & Pipelines
+
+ The `Trainer` adapts its logic based on the `task_type` and the specific model class initialized.
+
+ ### 1. Sentence Similarity (Embedding & Retrieval)
+ Train models for semantic search, clustering, or RAG.
+ * **Task Type:** `sentence-similarity`
+ * **Training Modes:**
+   * **Bi-Encoder (Dense):** Standard cosine similarity optimization.
+   * **Late Interaction (MaxSim):** ColBERT-style interaction where fine-grained token-level similarities are computed (requires `use_late_interaction=True`).
+ * **Loss Functions:** Automatically selects between **MNRL (Multiple Negatives Ranking Loss)** for triplets/pairs and **MSE/Cosine loss** for scored pairs (an illustrative sketch of MNRL follows).
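+
+ For intuition, MNRL can be sketched as in-batch contrastive cross-entropy over embedding similarities. The snippet below is an illustrative form (the function name and the `scale` temperature are assumptions), not necessarily the library's exact implementation:
+
+ ```python
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ def mnrl(anchor_emb, positive_emb, scale=20.0):
+     # Cosine similarity between every anchor and every positive in the batch
+     a = anchor_emb / mx.linalg.norm(anchor_emb, axis=-1, keepdims=True)
+     p = positive_emb / mx.linalg.norm(positive_emb, axis=-1, keepdims=True)
+     logits = (a @ p.T) * scale
+     # The true pair sits on the diagonal; other in-batch positives act as negatives
+     targets = mx.arange(logits.shape[0])
+     return nn.losses.cross_entropy(logits, targets).mean()
+ ```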
+
+ ### 2. Sequence Classification
+ Train discriminative models for sentiment analysis, intent detection, etc.
+ * **Task Type:** `text-classification`
+ * **Features:**
+   * Supports Multi-class and Binary classification.
+   * Supports Regression (if `is_regression=True`).
+   * Native support for Chat Templates in the tokenizer.
+
+ ### 3. Masked Language Modeling (MLM)
+ Perform domain adaptation on raw text.
+ * **Task Type:** `masked-lm`
+ * **Features:** Implements the standard 80% mask / 10% random / 10% original masking strategy dynamically during training. With the default 15% selection rate, roughly 12% of all tokens become the mask token, 1.5% become random tokens, and 1.5% stay unchanged but are still predicted.
+
+ ### 4. Token Classification (NER/POS)
+ Named Entity Recognition and Part-of-Speech tagging.
+ * **Task Type:** `token-classification`
+ * **Features:** Handles label alignment for sub-word tokens automatically, as shown in the sketch below.
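+
+ A quick illustration of the default alignment (the tokenization is hypothetical; `-100` is the ignore index):
+
+ ```python
+ words = ["Cupertino", "is", "sunny"]
+ tags  = ["B-LOC", "O", "O"]
+ # Possible sub-word split: ["Cup", "##ert", "##ino", "is", "sunny"]
+ # Aligned labels:          ["B-LOC", -100,  -100,    "O", "O"]
+ ```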
+
+ ## Data Preparation
+
+ The `datasets.py` module handles loading (JSONL, Parquet, CSV, HF Hub) and column standardization. It is built on top of Hugging Face's `datasets` library.
+
+ ### 1. Column Mapping
+ The trainer looks for specific column names.
+
+ | Task | Required Columns | Description |
+ | :--- | :--- | :--- |
+ | **Classification** | `text`, `label` | Input text and target class/score. |
+ | **Pairs (Sim.)** | `text`, `text_pair` | Anchor and Positive/Candidate. |
+ | **Triplets** | `text`, `text_pair`, `negative` | Anchor, Positive, Hard Negative. |
+ | **MLM** | `text` | Raw text for masking. |
+ | **NER** | `tokens`, `labels` | Pre-tokenized words and aligned tags. |
+
+ *Note: For Sentence Similarity, if a `label` column is present with floats, the trainer switches to Regression/MSE loss (e.g., for scored Bi-Encoders).*
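+
+ For instance, a triplet row and a scored-pair row would look like this (illustrative data; a single dataset would normally use one format or the other):
+
+ ```jsonl
+ {"text": "How do I reset my password?", "text_pair": "Click 'Forgot password' on the login page.", "negative": "Our offices are closed on Sundays."}
+ {"text": "How do I reset my password?", "text_pair": "Instructions for resetting a password.", "label": 0.9}
+ ```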
+
+ You can map your custom dataset fields via `DatasetArgs`.
+ ```python
+ # Load datasets
+ dataset_args = DatasetArgs(
+     data=dataset,  # dataset path
+     task_type=task_type,
+     text_field="question",  # maps column 'question' to 'text'
+     text_pair_field="response",  # maps column 'response' to 'text_pair'
+     negative_field="semantically_different_response",  # maps column 'semantically_different_response' to 'negative'
+     label_field="classification",  # maps column 'classification' to 'label'
+     test=True  # creates a test split, if not already present in the dataset, out of the training set (validation set not affected)
+ )
+ ```
+ Anchor is automatically mapped to `text` and Positive is automatically mapped to `text_pair`. See `_standardize_column_names()` in `datasets.py` for more information on column mapping.
+
+ ### 2. Text Pairs and Chat Template
+
+ For certain tasks like text-classification, you may want to classify how two token sequences (`text` and `text_pair`) relate to each other.
+
+ For encoder models, it is highly recommended to let the tokenizer combine the text and the text_pair rather than concatenating them manually. This ensures that the correct separator token is used.
+
+ ```python
+ batch = self.tokenizer(
+     texts,
+     text_pairs,
+     padding="longest",
+     truncation=True,
+     max_length=self.max_length,
+     return_tensors="mlx"
+ )
+ ```
+
+ In some cases, you may want to use the chat template that was used to train the model you intend to fine-tune. For example, LFM2-350M recommends using a chat template.
+
+ If `use_chat_template` is set to `True` when initializing the training (default `False`) and a chat template is available in the tokenizer (do check!), the `text` and `text_pair` values will be combined and `text_pair` will be set to `None`.
+
+ You can also force a specific string as separator.
+
+ This is how it works under the hood:
+
+ ```python
+ if text_pairs is not None:
+     if getattr(self.tokenizer, "chat_template", None) and self.use_chat_template:
+         # This ensures the model sees exactly what it expects for Q&A
+         formatted_texts = []
+         for prompt, response in zip(texts, text_pairs):
+             messages = [
+                 {"role": "user", "content": prompt},
+                 {"role": "assistant", "content": response}
+             ]
+             formatted_texts.append(
+                 self.tokenizer.apply_chat_template(messages, tokenize=False)
+             )
+         texts = formatted_texts
+         text_pairs = None  # Handled by template
+
+     elif self.force_separator is not None:
+         # Use the forced separator for decoder models
+         texts = [
+             f"{t}{self.force_separator}{p}"
+             for t, p in zip(texts, text_pairs)
+         ]
+         text_pairs = None
+ ```
+
+ See `DataCollatorForSequenceClassification` in `collators.py` for more information on `text_pair` handling for text-classification.
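+
+ Both behaviors are configured on the `Trainer` (see constructor parameters below). A minimal sketch with illustrative values:
+
+ ```python
+ trainer = Trainer(
+     model=model,
+     tokenizer=tokenizer,
+     task_type="text-classification",
+     training_args=args,
+     train_dataset=train_ds,
+     use_chat_template=False,  # set to True only if the tokenizer ships a chat template
+     force_separator="\n\n"    # illustrative: joins text/text_pair for decoder models
+ )
+ ```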
+
+ ## Quick Start (Programmatic)
+
+ Below is a simplified example of how to set up a training run programmatically.
+
+ ```python
+ from mlx_raclate.utils.utils import load
+ from mlx_raclate.tuner.datasets import load_dataset, DatasetArgs
+ from mlx_raclate.tuner.trainer import Trainer, TrainingArgs
+
+ # 1. Configuration variables
+ model_path = "Qwen/Qwen3-Embedding-0.6B"
+ dataset_path = "data/wines"
+ task_type = "text-classification"
+
+ # 2. Load and Prepare Dataset
+ dataset_args = DatasetArgs(
+     data=dataset_path,
+     task_type=task_type,
+     # Optional: override field names if your data isn't standard
+     # text_field="question",
+     # text_pair_field="response",
+     # label_field="classification"
+ )
+ train_ds, valid_ds, test_ds, id2label, label2id = load_dataset(dataset_args)
+
+ # 3. Load Model and Tokenizer
+ # Pass label mappings to model config for classification tasks
+ model_config = {"id2label": id2label, "label2id": label2id} if id2label else {}
+
+ model, tokenizer = load(
+     model_path,
+     model_config=model_config,
+     pipeline=task_type,
+     train=True
+ )
+
+ # 4. Define Training Arguments
+ args = TrainingArgs(
+     output_dir="outputs/my_run",
+     batch_size=4,
+     gradient_accumulation_steps=4,
+     learning_rate=1e-5,
+     num_train_epochs=3,
+     lr_scheduler_type="cosine_decay",
+     warmup_ratio=0.03,
+     save_steps=500,
+     logging_steps=10,
+     max_length=2048,
+     freeze_embeddings=False
+ )
+
+ # 5. Initialize Trainer
+ trainer = Trainer(
+     model=model,
+     tokenizer=tokenizer,
+     task_type=task_type,
+     training_args=args,
+     train_dataset=train_ds,
+     eval_dataset=valid_ds,
+     label2id=label2id,
+     # For decoder models doing classification on pairs:
+     use_chat_template=False
+ )
+
+ # 6. Run Training
+ trainer.train()
+
+ # 7. Evaluate on Test Set (Optional)
+ if test_ds:
+     trainer.test(test_ds)
+ ```
+
+ ## CLI usage
+
+ An example CLI tool including **all** parameters needed to train a model is available in `mlx_raclate/utils/train.py`.
+
+ **Warning:** this example includes default values that override the defaults of the `DatasetArgs`, `TrainingArgs`, and `Trainer` classes presented below.
+
+ ## API Reference
+
+ ### DatasetArgs
+
+ Used to configure how data is loaded and mapped.
+
+ | Parameter | Type | Default | Description |
+ | :--- | :--- | :--- | :--- |
+ | `data` | `str` | *Required* | Local path or HF identifier of the dataset. |
+ | `task_type` | `str` | *Required* | The type of task (e.g., `text-classification`). |
+ | `text_field` | `str` | `None` | Name of the text input column. |
+ | `text_pair_field`| `str` | `None` | Name of the second text input column (for pairs). |
+ | `label_field` | `str` | `None` | Name of the label/target column. |
+ | `negative_field`| `str` | `None` | Name of the negative samples column. |
+ | `test` | `bool` | `False` | If True, creates a test split from the training set if one doesn't exist. |
+
+ Note: use `load_dataset()` from `datasets.py` (see the Quick Start) to fetch the dataset splits and the `id2label`/`label2id` dictionaries.
+
+ ### TrainingArgs
+
+ Controls the hyperparameters and runtime configuration.
+
+ #### Hyperparameters
+ | Parameter | Default | Description |
+ | :--- | :--- | :--- |
+ | `batch_size` | `2` | The physical batch size per device/step. Reduce to 1 on MacBooks with limited RAM if training on long contexts. |
+ | `gradient_accumulation_steps` | `8` | Number of steps to accumulate gradients before updating weights. With the defaults, the effective batch size is 2 × 8 = 16. |
+ | `num_train_epochs` | `2` | Total number of training epochs. |
+ | `max_length` | `512` | Max sequence length. If `None`, uses the model's default config. |
+ | `freeze_embeddings` | `False` | If `True`, freezes the embedding layer to save memory/compute. |
+
+ #### Optimizer & Scheduler
+ | Parameter | Default | Description |
+ | :--- | :--- | :--- |
+ | `learning_rate` | `3e-5` | Initial learning rate (peak LR). |
+ | `weight_decay` | `0.01` | Weight decay factor for AdamW. |
+ | `lr_scheduler_type` | `"constant"` | Scheduler type: `"cosine_decay"`, `"linear_schedule"`, or `"constant"`. |
+ | `min_lr` | `0.0` | Minimum learning rate at the end of the schedule. |
+ | `warmup_ratio` | `0.0` | Ratio of total training steps used for warmup. |
+ | `warmup_steps` | `0` | Absolute number of warmup steps (overrides `warmup_ratio` if set). |
+ | `max_grad_norm` | `1.0` | Gradient clipping threshold. |
+
+ #### Checkpointing & Logging
+ | Parameter | Default | Description |
+ | :--- | :--- | :--- |
+ | `output_dir` | `None` | Directory to save checkpoints and logs. Defaults to a timestamped folder. |
+ | `save_steps` | `1000` | Frequency of saving model checkpoints (in steps). |
+ | `logging_steps` | `16` | Frequency of logging metrics to console/files. |
+ | `eval_batch_size` | `4` | Batch size used during evaluation/testing. |
+ | `resume_from_step`| `0` | Step to resume training from. If this is after the last warmup step (either declared via `warmup_steps` or calculated via `warmup_ratio`), warmup will be skipped. |
+
+ Gradient checkpointing is enabled by default due to the RAM constraints of consumer hardware.
+
+ ### Model Config
+
+ When loading a pretrained model, you can create a `model_config` dictionary with new parameters and pass it to the `load()` function. Common examples:
+
+ | Parameter | Type | Default | Description |
+ | :--- | :--- | :--- | :--- |
+ | `is_regression` | `bool` | `False` | For text-classification tasks, whether the task is a regression (float labels). |
+ | `use_late_interaction`| `bool` | `False` | For sentence-similarity tasks, whether late interaction (MaxSim) should be used instead of cosine similarity. |
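+
+ For example, to load a ColBERT-style checkpoint for late-interaction training (a minimal sketch reusing the `load()` signature from the Quick Start):
+
+ ```python
+ from mlx_raclate.utils.utils import load
+
+ model_config = {"use_late_interaction": True}
+ model, tokenizer = load(
+     "LiquidAI/LFM2-ColBERT-350M",
+     model_config=model_config,
+     pipeline="sentence-similarity",
+     train=True
+ )
+ ```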
+
+ ## Trainer
+
+ The main class that orchestrates the training.
+
+ **Constructor Parameters:**
+
+ * **`model`**: The loaded MLX model.
+ * **`tokenizer`**: The loaded tokenizer. If you want to use a chat template, make sure that the tokenizer includes the chat template. If not, add it manually before instantiating the Trainer.
+ * **`task_type`**: String identifier for the pipeline (e.g., "text-classification").
+ * **`training_args`**: Instance of `TrainingArgs`.
+ * **`train_dataset`**: The processed training dataset.
+ * **`eval_dataset`**: (Optional) The processed validation dataset.
+ * **`label2id`**: (Optional) Dictionary mapping labels to IDs (required for classification metrics).
+ * **`use_chat_template`** *(bool)*: If `True`, applies the tokenizer's chat template to inputs. Useful for decoder models (like Qwen/Llama) performing classification on text pairs.
+ * **`force_separator`** *(Optional str)*: If not using a chat template, this string is used to join text pairs for decoder models.
+ * **`optimizer`** *(Optional mlx.optimizers.Optimizer)*: If no optimizer is passed, AdamW is used with the hyperparameters set in `TrainingArgs`.
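+
+ For example, to supply a custom optimizer (a minimal sketch using MLX's `AdamW`; the hyperparameter values are illustrative):
+
+ ```python
+ import mlx.optimizers as optim
+
+ optimizer = optim.AdamW(learning_rate=2e-5, weight_decay=0.05)
+ trainer = Trainer(
+     model=model,
+     tokenizer=tokenizer,
+     task_type=task_type,
+     training_args=args,
+     train_dataset=train_ds,
+     optimizer=optimizer
+ )
+ ```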
+
+ **Methods:**
+
+ * `train()`: Starts the training loop.
+ * `test(dataset)`: Runs evaluation on the provided dataset.
+
File without changes
@@ -0,0 +1,291 @@
+ import mlx.core as mx
+ import numpy as np
+ from typing import List, Dict, Any, Optional
+ from dataclasses import dataclass
+
+ @dataclass
+ class DataCollator:
+     tokenizer: Any
+     max_length: int = 512
+
+     def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
+         raise NotImplementedError
+
+ @dataclass
+ class DataCollatorForSequenceClassification(DataCollator):
+     """
+     Handles tokenization and padding for classification tasks.
+     """
+     use_chat_template: bool = False  # Whether to use chat templates for decoder models
+     force_separator: Optional[str] = None  # If set, forces this separator between text pairs
+     default_decoder_separator: str = "\n"  # Used for decoder models when concatenating text pairs
+     label2id: Optional[Dict[str, int]] = None
+
+     def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
+         texts = features.get("text")
+         text_pairs = features.get("text_pair", None)
+
+         if text_pairs is not None:
+             if getattr(self.tokenizer, "chat_template", None) and self.use_chat_template:
+                 # This ensures the model sees exactly what it expects for Q&A
+                 formatted_texts = []
+                 for prompt, response in zip(texts, text_pairs):
+                     messages = [
+                         {"role": "user", "content": prompt},
+                         {"role": "assistant", "content": response}
+                     ]
+                     formatted_texts.append(
+                         self.tokenizer.apply_chat_template(messages, tokenize=False)
+                     )
+                 texts = formatted_texts
+                 text_pairs = None  # Handled by template
+
+             elif self.force_separator is not None:
+                 # Use the forced separator for decoder models
+                 texts = [
+                     f"{t}{self.force_separator}{p}"
+                     for t, p in zip(texts, text_pairs)
+                 ]
+                 text_pairs = None
+
+             else:
+                 # Check if tokenizer has a standard separator (like [SEP] in BERT)
+                 # Qwen tokenizers often have sep_token as None or the same as EOS
+                 has_sep_token = getattr(self.tokenizer, "sep_token", None) is not None
+
+                 if not has_sep_token or self.tokenizer.sep_token == self.tokenizer.eos_token:
+                     texts = [
+                         f"{t}{self.default_decoder_separator}{p}"
+                         for t, p in zip(texts, text_pairs)
+                     ]
+                     # Set pairs to None so the tokenizer treats it as a single string
+                     text_pairs = None
+
+         # Decoder tokenizers often lack a pad token; fall back to EOS for padding
+         if self.tokenizer.pad_token_id is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+         batch = self.tokenizer(
+             texts,
+             text_pairs,
+             padding="longest",
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="mlx"
+         )
+
+         if "label" in features:
+             labels = features["label"]
+             # On-the-fly string-to-ID conversion
+             if self.label2id and len(labels) > 0 and isinstance(labels[0], str):
+                 labels = [self.label2id.get(l, -1) for l in labels]  # Default to -1 if missing
+
+             # Detect regression (float) vs classification (int)
+             if len(labels) > 0 and isinstance(labels[0], float):
+                 dtype = mx.float32
+             else:
+                 dtype = mx.int32
+
+             batch["labels"] = mx.array(labels, dtype=dtype)
+
+         return dict(batch)
+
+ @dataclass
+ class DataCollatorForTokenClassification(DataCollator):
+     """
+     Handles tokenization and aligns labels for token classification.
+     """
+     label_pad_token_id: int = -100
+     # Strategy: 'first' (label only first subword), 'all' (label all subwords with same tag)
+     label_all_tokens: bool = False
+     label2id: Optional[Dict[str, int]] = None
+
+     def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
+         texts = features["text"]
+         labels = features["labels"]  # Note: usually plural 'labels', a list of lists
+
+         # SANITY CHECK: The library expects pre-tokenized inputs (List[str])
+         if isinstance(texts[0], str):
+             raise ValueError(
+                 "DataCollatorForTokenClassification expects 'text' to be a list of strings "
+                 "(tokens), not a single string. Please pre-tokenize your dataset."
+             )
+
+         batch = self.tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="mlx",
+             is_split_into_words=True
+         )
+
+         batch_size, seq_len = batch["input_ids"].shape
+
+         # Create a numpy buffer filled with the ignore index
+         padded_labels = np.full((batch_size, seq_len), self.label_pad_token_id, dtype=np.int32)
+
+         for i, label_seq in enumerate(labels):
+             # On-the-fly conversion for lists of strings (to avoid memory issues with dataset.map)
+             current_labels = label_seq
+             if self.label2id and len(label_seq) > 0 and isinstance(label_seq[0], str):
+                 current_labels = [self.label2id.get(l, self.label_pad_token_id) for l in label_seq]
+
+             # word_ids returns a list mapping each token to its original word index
+             # e.g., [None, 0, 1, 1, 2, None] for "[CLS] My name is John [SEP]"
+             word_ids = batch.word_ids(batch_index=i)
+             previous_word_idx = None
+
+             for k, word_idx in enumerate(word_ids):
+                 # Skip special tokens (None)
+                 if word_idx is None:
+                     continue
+
+                 # Safety check: tokenizer truncation might leave word_ids that point to
+                 # label indices larger than the label list provided.
+                 if word_idx >= len(current_labels):
+                     break
+
+                 if word_idx != previous_word_idx:
+                     padded_labels[i, k] = current_labels[word_idx]
+                 else:
+                     # This is a subsequent subword of the same word
+                     if self.label_all_tokens:
+                         padded_labels[i, k] = current_labels[word_idx]
+                     else:
+                         # Standard BERT NER behavior: ignore subsequent subwords
+                         padded_labels[i, k] = self.label_pad_token_id
+
+                 previous_word_idx = word_idx
+
+         batch["labels"] = mx.array(padded_labels, dtype=mx.int32)
+
+         return dict(batch)
+
+ @dataclass
+ class DataCollatorForMaskedLanguageModeling(DataCollator):
+     """
+     Handles dynamic masking for MLM.
+     """
+     mlm_probability: float = 0.15
+     mask_token_id: Optional[int] = None
+
+     def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
+         texts = features["text"]
+
+         batch = self.tokenizer(
+             texts,
+             padding=True,
+             truncation=True,
+             max_length=self.max_length,
+             return_tensors="mlx"
+         )
+
+         input_ids = batch["input_ids"]
+
+         # Create mask
+         probability_matrix = mx.random.uniform(shape=input_ids.shape) < self.mlm_probability
+
+         # Protect special tokens
+         special_tokens_mask = mx.array([
+             [1 if token_id in [self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.pad_token_id]
+              else 0 for token_id in seq]
+             for seq in input_ids.tolist()
+         ])
+
+         probability_matrix = mx.where(special_tokens_mask, 0, probability_matrix)
+
+         # Create labels (-100 for unmasked)
+         labels = mx.where(probability_matrix, input_ids, -100)
+
+         # Apply masking (80% mask, 10% random, 10% original)
+         random_matrix = mx.random.uniform(shape=input_ids.shape)
+         mask_indices = probability_matrix & (random_matrix < 0.8)
+         random_indices = probability_matrix & (random_matrix >= 0.8) & (random_matrix < 0.9)
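+         # The remaining selected positions (random_matrix >= 0.9) keep their original
+         # token but are still predicted, since 'labels' already marks them.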
+
+         # Create masked input
+         masked_inputs = input_ids
+
+         mask_token_id = self.tokenizer.mask_token_id
+         if mask_token_id is None:
+             if self.mask_token_id is not None:
+                 mask_token_id = self.mask_token_id
+             else:
+                 raise ValueError(
+                     "Tokenizer does not have a mask token defined and no mask_token_id provided."
+                 )
+
+         # Apply the [MASK] token to the 80% bucket of selected positions
+         masked_inputs = mx.where(mask_indices, mask_token_id, masked_inputs)
+         random_tokens = mx.random.randint(
+             0, self.tokenizer.vocab_size,
+             shape=input_ids.shape
+         )
+
+         # Replace the 10% "random" bucket with random vocabulary tokens
+         inputs = mx.where(random_indices, random_tokens, masked_inputs)
+
+         batch["input_ids"] = inputs
+         batch["labels"] = labels
+
+         return dict(batch)
+
+ @dataclass
+ class DataCollatorForSentenceSimilarity(DataCollator):
+     """
+     Handles data for Bi-Encoder models (Sentence Similarity / Retrieval).
+     Unlike SequenceClassification, this keeps sentences SEPARATE to produce
+     independent embeddings.
+
+     Expected keys in features (from datasets.py standardization):
+     - 'text': The Anchor / Sentence A
+     - 'text_pair': The Positive / Reference / Sentence B
+     - 'negative' (optional): The Hard Negative / Sentence C
+     - 'label' (optional): Similarity score for Regression
+     """
+     def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
+         batch = {}
+
+         # Tokenize Anchor (Sentence A) -> 'input_ids'
+         if "text" in features:
+             out_a = self.tokenizer(
+                 features["text"],
+                 padding=True,
+                 truncation=True,
+                 max_length=self.max_length,
+                 return_tensors="mlx"
+             )
+             batch["input_ids"] = out_a["input_ids"]
+             batch["attention_mask"] = out_a["attention_mask"]
+
+         # Tokenize Reference (Sentence B) -> 'reference_input_ids'
+         if "text_pair" in features:
+             out_b = self.tokenizer(
+                 features["text_pair"],
+                 padding=True,
+                 truncation=True,
+                 max_length=self.max_length,
+                 return_tensors="mlx"
+             )
+             batch["reference_input_ids"] = out_b["input_ids"]
+             batch["reference_attention_mask"] = out_b["attention_mask"]
+
+         # Tokenize Negative (Sentence C) -> 'negative_input_ids'
+         neg_key = None
+         if "negative" in features:
+             neg_key = "negative"
+         elif "text_negative" in features:
+             neg_key = "text_negative"
+
+         if neg_key:
+             out_n = self.tokenizer(
+                 features[neg_key],
+                 padding=True,
+                 truncation=True,
+                 max_length=self.max_length,
+                 return_tensors="mlx"
+             )
+             batch["negative_input_ids"] = out_n["input_ids"]
+             batch["negative_attention_mask"] = out_n["attention_mask"]
+
+         # Handle scores (for Regression)
+         if "label" in features:
+             # Ensure float32 for regression targets
+             batch["similarity_scores"] = mx.array(features["label"], dtype=mx.float32)
+
+         return batch