mlx-raclate 0.1.0b1__py3-none-any.whl
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0
mlx_raclate/py.typed
ADDED
File without changes

mlx_raclate/tuner/TUNER.md
ADDED
@@ -0,0 +1,305 @@
# RACLATE Tuner

The `tuner` module is the training engine of RACLATE (**R**etrieval **A**nd **C**lassification including **LATE** interaction models). Built entirely on Apple's [MLX](https://github.com/ml-explore/mlx) framework, it provides a highly efficient, unified interface for fine-tuning *small* Transformer-based classifiers on Apple Silicon.

This trainer supports standard dense retrieval, classification, and masked language modeling, as well as **Late Interaction (ColBERT-style)** training patterns.

## Key Features

* **Apple Silicon Native:** Fully optimized for M-series chips using MLX.
* **Full Training:** Full fine-tuning of pretrained models (_see supported architectures below_). LoRA fine-tuning is not supported (yet). The library allows transfer learning: existing heads can be stripped from pretrained models, and new heads can be added to base models for specific tasks.
* **Memory Efficiency:** Built-in support for **Gradient Accumulation** and **Gradient Checkpointing** to train larger batches/models on limited Unified Memory.
* **Flexible Schedulers:** Linear, Cosine, and Constant learning rate schedules with warmup.
* **Smart Collators:** Task-specific data collators that handle padding, masking, and chat templates automatically.
* **Embedding Freezing:** Option to freeze embedding layers to speed up fine-tuning or prevent catastrophic forgetting.
* **HF Hub Integration (TODO):** Seamless saving and pushing of checkpoints to the Hugging Face Hub.

## Supported Architectures

The trainer supports a variety of modern architectures with long context windows (relative to classic BERT models). Because these models are meant to be trained and run on local machines, the implementations are specifically optimized for small-to-mid-sized models:

* **ModernBERT**: MLX implementation of `answerdotai/ModernBERT-base` (encoder-only). Long context (8k) and high efficiency.
* **Qwen 3**: MLX implementation of `Qwen/Qwen3-Embedding-0.6B` (32k context window), which leverages the qwen3 architecture.
* **Gemma 3**: MLX implementation of `google/embeddinggemma-300m` (2k context window), which leverages the gemma3 text-variant architecture with a few tweaks. As per the official embeddinggemma architecture, the attention mask is set to causal or bi-directional based on a config parameter (`use_bidirectional_attn` or `use_bidirectional_attention`). It is therefore possible to switch between encoder and decoder mode, and standard gemma3_text models (32k context window) are also supported.
* **T5Gemma-Encoder**: MLX implementation of `google/t5gemma-b-b-ul2`, keeping only the encoder weights at initialization (the encoder config is merged into the main model config).
* **LFM2**: MLX implementation of `LiquidAI/LFM2-350M` (causal/AR), which also supports `LiquidAI/LFM2-ColBERT-350M` when the model config file includes `use_late_interaction=True`. These models have a context window of 128k tokens; in training mode, 128k-token sequences exceed the RAM capacity of most Apple hardware. _See the parameters below to cap sequences to a more reasonable length during training._ (A loading sketch follows this list.)
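
As a rough sketch of how these switches can be toggled (using the `load()` helper shown in the Quick Start below and the config overrides described under *Model Config*; the checkpoints and flags here are illustrative, not required settings):

```python
from mlx_raclate.utils.utils import load

# LFM2-ColBERT with late interaction enabled through a config override.
model, tokenizer = load(
    "LiquidAI/LFM2-ColBERT-350M",
    model_config={"use_late_interaction": True},
    pipeline="sentence-similarity",
    train=True,
)

# An EmbeddingGemma checkpoint run in encoder mode via the bi-directional attention flag.
encoder, encoder_tokenizer = load(
    "google/embeddinggemma-300m",
    model_config={"use_bidirectional_attention": True},
    pipeline="sentence-similarity",
    train=True,
)
```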


## Supported Tasks & Pipelines

The `Trainer` adapts its logic based on the `task_type` and the specific model class initialized.

### 1. Sentence Similarity (Embedding & Retrieval)
Train models for semantic search, clustering, or RAG.
* **Task Type:** `sentence-similarity`
* **Training Modes:**
    * **Bi-Encoder (Dense):** Standard cosine similarity optimization.
    * **Late Interaction (MaxSim):** ColBERT-style interaction where fine-grained token-level similarities are computed (requires `use_late_interaction=True`).
* **Loss Functions:** Automatically selects between **MNRL (Multiple Negatives Ranking Loss)** for triplets/pairs and **MSE/Cosine Loss** for scored pairs (a minimal MNRL sketch follows this section).
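
For intuition, here is a minimal sketch of the in-batch MNRL objective (illustrative only, not the trainer's exact implementation; `anchor_emb` and `positive_emb` stand for pooled embeddings produced by the model):

```python
import mlx.core as mx
import mlx.nn as nn

def mnrl_loss(anchor_emb: mx.array, positive_emb: mx.array, scale: float = 20.0) -> mx.array:
    """In-batch Multiple Negatives Ranking Loss over (batch, dim) embeddings.

    Every other positive in the batch serves as a negative for a given anchor.
    """
    # L2-normalize so the dot product becomes cosine similarity
    a = anchor_emb / mx.maximum(mx.linalg.norm(anchor_emb, axis=-1, keepdims=True), 1e-9)
    p = positive_emb / mx.maximum(mx.linalg.norm(positive_emb, axis=-1, keepdims=True), 1e-9)

    scores = scale * (a @ p.T)            # (batch, batch) similarity matrix
    targets = mx.arange(scores.shape[0])  # the diagonal holds the true pairs
    return nn.losses.cross_entropy(scores, targets, reduction="mean")
```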

### 2. Sequence Classification
Train discriminative models for sentiment analysis, intent detection, etc.
* **Task Type:** `text-classification`
* **Features:**
    * Supports Multi-class and Binary classification.
    * Supports Regression (if `is_regression=True`).
    * Native support for Chat Templates in the tokenizer.

### 3. Masked Language Modeling (MLM)
Perform domain adaptation on raw text.
* **Task Type:** `masked-lm`
* **Features:** Implements the standard 80% mask / 10% random / 10% original masking strategy dynamically during training.

### 4. Token Classification (NER/POS)
Named Entity Recognition and Part-of-Speech tagging.
* **Task Type:** `token-classification`
* **Features:** Handles label alignment for sub-word tokens automatically (an example row is shown below).
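
For illustration, a pre-tokenized row for this task could look like the following (a made-up example; it uses the `tokens`/`labels` columns described under *Data Preparation* below):

```python
# Hypothetical NER example: one tag per pre-tokenized word.
example = {
    "tokens": ["Barolo", "is", "produced", "in", "Piedmont", "."],
    "labels": ["B-WINE", "O", "O", "O", "B-LOC", "O"],
}
```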


## Data Preparation

The `datasets.py` module handles loading (JSONL, Parquet, CSV, HF Hub) and column standardization. It is built on top of Hugging Face's `datasets` library.

### 1. Column Mapping
The trainer looks for specific column names.

| Task | Required Columns | Description |
| :--- | :--- | :--- |
| **Classification** | `text`, `label` | Input text and target class/score. |
| **Pairs (Sim.)** | `text`, `text_pair` | Anchor and Positive/Candidate. |
| **Triplets** | `text`, `text_pair`, `negative` | Anchor, Positive, Hard Negative. |
| **MLM** | `text` | Raw text for masking. |
| **NER** | `tokens`, `labels` | Pre-tokenized words and aligned tags. |

*Note: For Sentence Similarity, if a `label` column is present with floats, the trainer switches to Regression/MSE loss (e.g., for scored Bi-Encoders).*

You can map your custom dataset fields via `DatasetArgs`.
```python
from mlx_raclate.tuner.datasets import DatasetArgs

# Load datasets
dataset_args = DatasetArgs(
    data=dataset,                   # dataset path
    task_type=task_type,
    text_field="question",          # maps column 'question' to 'text'
    text_pair_field="response",     # maps column 'response' to 'text_pair'
    negative_field="semantically_different_response",  # maps column 'semantically_different_response' to 'negative'
    label_field="classification",   # maps column 'classification' to 'label'
    test=True                       # creates a test split out of the training set if one is not already present (the validation set is not affected)
)
```

The anchor is automatically mapped to `text` and the positive is automatically mapped to `text_pair`. See `_standardize_column_names()` in `datasets.py` for more information on column mapping.
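
For illustration, rows matching the custom field names above could be written to a JSONL file like this (the path and values are made up):

```python
import json

# Hypothetical rows using the custom column names mapped above.
rows = [
    {
        "question": "Which grape is Barolo made from?",
        "response": "Barolo is made exclusively from Nebbiolo.",
        "semantically_different_response": "Rioja blends are built around Tempranillo.",
        "classification": "factual",
    },
]

# datasets.py also accepts Parquet, CSV, or a Hub identifier; JSONL is one object per line.
with open("data/wine_qa.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")
```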

### 2. Text Pairs and Chat Template

For certain tasks like text-classification, you may want to classify how two token sequences (`text` and `text_pair`) relate to each other.

For encoder models, it is highly recommended to let the tokenizer combine `text` and `text_pair` rather than concatenating them manually. This ensures that the correct separator token (e.g., `[SEP]`) is used:

```python
batch = self.tokenizer(
    texts,
    text_pairs,
    padding="longest",
    truncation=True,
    max_length=self.max_length,
    return_tensors="mlx"
)
```

In some cases, you may want to use the chat template that was used to train the model you intend to fine-tune. For example, LFM2-350M recommends using a chat template.

If `use_chat_template` is set to `True` when initializing the `Trainer` (default `False`) and a chat template is available in the tokenizer (do check!), the `text` and `text_pair` values are combined through the template and `text_pair` is set to `None`.

You can also force a specific string as separator via `force_separator`.

This is how it works under the hood:

```python
if text_pairs is not None:
    if getattr(self.tokenizer, "chat_template", None) and self.use_chat_template:
        # This ensures the model sees exactly what it expects for Q&A
        formatted_texts = []
        for prompt, response in zip(texts, text_pairs):
            messages = [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response}
            ]
            formatted_texts.append(
                self.tokenizer.apply_chat_template(messages, tokenize=False)
            )
        texts = formatted_texts
        text_pairs = None  # Handled by template

    elif self.force_separator is not None:
        # Use the forced separator for decoder models
        texts = [
            f"{t}{self.force_separator}{p}"
            for t, p in zip(texts, text_pairs)
        ]
        text_pairs = None
```

See `DataCollatorForSequenceClassification` in `collators.py` for more information on `text_pair` handling for text-classification.
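
As a rough usage sketch (the Trainer normally builds this collator for you; the tokenizer, labels, and separator below are assumptions for illustration):

```python
from mlx_raclate.tuner.collators import DataCollatorForSequenceClassification

collator = DataCollatorForSequenceClassification(
    tokenizer=tokenizer,                      # a tokenizer returned by load()
    max_length=512,
    force_separator="\n\n",                   # join text / text_pair for a decoder model
    label2id={"negative": 0, "positive": 1},  # hypothetical label mapping
)

batch = collator({
    "text": ["Is this wine dry?"],
    "text_pair": ["Yes, bone dry with high acidity."],
    "label": ["positive"],
})
# batch["input_ids"], batch["attention_mask"], and batch["labels"] are MLX arrays.
```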


## Quick Start (Programmatic)

Below is a simplified example of how to set up a training run programmatically.

```python
from mlx_raclate.utils.utils import load
from mlx_raclate.tuner.datasets import load_dataset, DatasetArgs
from mlx_raclate.tuner.trainer import Trainer, TrainingArgs

# 1. Configuration variables
model_path = "Qwen/Qwen3-Embedding-0.6B"
dataset_path = "data/wines"
task_type = "text-classification"

# 2. Load and Prepare Dataset
dataset_args = DatasetArgs(
    data=dataset_path,
    task_type=task_type,
    # Optional: override field names if your data isn't standard
    # text_field="question",
    # text_pair_field="response",
    # label_field="classification"
)
train_ds, valid_ds, test_ds, id2label, label2id = load_dataset(dataset_args)

# 3. Load Model and Tokenizer
# Pass label mappings to model config for classification tasks
model_config = {"id2label": id2label, "label2id": label2id} if id2label else {}

model, tokenizer = load(
    model_path,
    model_config=model_config,
    pipeline=task_type,
    train=True
)

# 4. Define Training Arguments
args = TrainingArgs(
    output_dir="outputs/my_run",
    batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    num_train_epochs=3,
    lr_scheduler_type="cosine_decay",
    warmup_ratio=0.03,
    save_steps=500,
    logging_steps=10,
    max_length=2048,
    freeze_embeddings=False
)

# 5. Initialize Trainer
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    task_type=task_type,
    training_args=args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    label2id=label2id,
    # For decoder models doing classification on pairs:
    use_chat_template=False
)

# 6. Run Training
trainer.train()

# 7. Evaluate on Test Set (Optional)
if test_ds:
    trainer.test(test_ds)
```

## CLI usage

An example CLI tool, including **all** the parameters needed to train a model, is available in `mlx_raclate/utils/train.py`.

WARNING: this example sets its own default values, which override the defaults of the `DatasetArgs`, `TrainingArgs`, and `Trainer` classes presented below.

## API Reference

### DatasetArgs

Used to configure how data is loaded and mapped.

| Parameter | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `data` | `str` | *Required* | Local path or HF identifier of the dataset. |
| `task_type` | `str` | *Required* | The type of task (e.g., `text-classification`). |
| `text_field` | `str` | `None` | Name of the text input column. |
| `text_pair_field` | `str` | `None` | Name of the second text input column (for pairs). |
| `label_field` | `str` | `None` | Name of the label/target column. |
| `negative_field` | `str` | `None` | Name of the negative samples column. |
| `test` | `bool` | `False` | If `True`, creates a test split from the training set if one doesn't exist. |

Note: use `load_dataset()` from `datasets.py` (as shown in the Quick Start above) to fetch the dataset splits and the `label2id` dictionary.

### TrainingArgs

Controls the hyperparameters and runtime configuration.

#### Hyperparameters
| Parameter | Default | Description |
| :--- | :--- | :--- |
| `batch_size` | `2` | The physical batch size per device/step. Reduce to 1 on MacBooks with limited RAM if training on long contexts. |
| `gradient_accumulation_steps` | `8` | Number of steps to accumulate gradients before updating weights. |
| `num_train_epochs` | `2` | Total number of training epochs. |
| `max_length` | `512` | Max sequence length. If `None`, uses the model's default config. |
| `freeze_embeddings` | `False` | If `True`, freezes the embedding layer to save memory/compute. |

#### Optimizer & Scheduler
| Parameter | Default | Description |
| :--- | :--- | :--- |
| `learning_rate` | `3e-5` | Initial learning rate (Peak LR). |
| `weight_decay` | `0.01` | Weight decay factor for AdamW. |
| `lr_scheduler_type` | `"constant"` | Scheduler type: `"cosine_decay"`, `"linear_schedule"`, or `"constant"`. |
| `min_lr` | `0.0` | Minimum learning rate at the end of the schedule. |
| `warmup_ratio` | `0.0` | Ratio of total training steps used for warmup (see the worked example after this table). |
| `warmup_steps` | `0` | Absolute number of warmup steps (overrides `warmup_ratio` if set). |
| `max_grad_norm` | `1.0` | Gradient clipping threshold. |
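
To relate these knobs, here is a back-of-the-envelope example (illustrative numbers; the Trainer computes the actual step counts internally and may round differently):

```python
# Hypothetical run: 10,000 training examples with batch_size=2,
# gradient_accumulation_steps=8, num_train_epochs=2, warmup_ratio=0.03.
num_examples = 10_000
batch_size = 2
gradient_accumulation_steps = 8
num_train_epochs = 2
warmup_ratio = 0.03

effective_batch_size = batch_size * gradient_accumulation_steps   # 16 examples per weight update
updates_per_epoch = num_examples // effective_batch_size          # 625
total_updates = updates_per_epoch * num_train_epochs              # 1250
warmup_steps = int(total_updates * warmup_ratio)                  # 37 warmup steps
```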

#### Checkpointing & Logging
| Parameter | Default | Description |
| :--- | :--- | :--- |
| `output_dir` | `None` | Directory to save checkpoints and logs. Defaults to a timestamped folder. |
| `save_steps` | `1000` | Frequency of saving model checkpoints (in steps). |
| `logging_steps` | `16` | Frequency of logging metrics to console/files. |
| `eval_batch_size` | `4` | Batch size used during evaluation/testing. |
| `resume_from_step` | `0` | Step to resume training from. If this is after the last warmup step (either declared or calculated via `warmup_ratio`), warmup is skipped. |

Gradient checkpointing is enabled by default due to the RAM constraints of consumer hardware.

### Model Config

When loading a pretrained model, you can create a `model_config` dictionary with new parameters and pass it to the `load()` function. Common examples:

| Parameter | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `is_regression` | `bool` | `False` | For text-classification tasks, whether the task is treated as a regression. |
| `use_late_interaction` | `bool` | `False` | For sentence similarity tasks, whether late interaction (MaxSim) should be used instead of cosine similarity. |

## Trainer

The main class that orchestrates the training.

**Constructor Parameters:**

* **`model`**: The loaded MLX model.
* **`tokenizer`**: The loaded tokenizer. If you want to use a chat template, make sure that the tokenizer includes one; if not, add it manually before instantiating the Trainer (see the snippet after this list).
* **`task_type`**: String identifier for the pipeline (e.g., "text-classification").
* **`training_args`**: Instance of `TrainingArgs`.
* **`train_dataset`**: The processed training dataset.
* **`eval_dataset`**: (Optional) The processed validation dataset.
* **`label2id`**: (Optional) Dictionary mapping labels to IDs (required for classification metrics).
* **`use_chat_template`** *(bool)*: If `True`, applies the tokenizer's chat template to inputs. Useful for decoder models (like Qwen/Llama) performing classification on text pairs.
* **`force_separator`** (Optional *str*): If not using a chat template, this string is used to join text pairs for decoder models.
* **`optimizer`** (Optional *mlx.optimizers* optimizer): If no optimizer is passed, AdamW is used with the hyperparameters set in `TrainingArgs`.
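
If you do need to attach a chat template yourself, a minimal sketch looks like this (the Jinja template below is illustrative, not one shipped with any particular model):

```python
# Hypothetical: only set a template if the tokenizer does not already ship one.
if getattr(tokenizer, "chat_template", None) is None:
    tokenizer.chat_template = (
        "{% for message in messages %}"
        "<|{{ message['role'] }}|>{{ message['content'] }}\n"
        "{% endfor %}"
    )
```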

**Methods:**

* `train()`: Starts the training loop.
* `test(dataset)`: Runs evaluation on the provided dataset.

mlx_raclate/tuner/__init__.py
ADDED
File without changes

mlx_raclate/tuner/collators.py
ADDED
@@ -0,0 +1,291 @@
import mlx.core as mx
import numpy as np
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass

@dataclass
class DataCollator:
    tokenizer: Any
    max_length: int = 512

    def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
        raise NotImplementedError

@dataclass
class DataCollatorForSequenceClassification(DataCollator):
    """
    Handles tokenization and padding for classification tasks.
    """
    use_chat_template: bool = False  # Whether to use chat templates for decoder models
    force_separator: Optional[str] = None  # If set, forces this separator between text pairs
    default_decoder_separator: str = "\n"  # Used for decoder models when concatenating text pairs
    label2id: Optional[Dict[str, int]] = None

    def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
        texts = features.get("text")
        text_pairs = features.get("text_pair", None)

        if text_pairs is not None:
            if getattr(self.tokenizer, "chat_template", None) and self.use_chat_template:
                # This ensures the model sees exactly what it expects for Q&A
                formatted_texts = []
                for prompt, response in zip(texts, text_pairs):
                    messages = [
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": response}
                    ]
                    formatted_texts.append(
                        self.tokenizer.apply_chat_template(messages, tokenize=False)
                    )
                texts = formatted_texts
                text_pairs = None  # Handled by template

            elif self.force_separator is not None:
                # Use the forced separator for decoder models
                texts = [
                    f"{t}{self.force_separator}{p}"
                    for t, p in zip(texts, text_pairs)
                ]
                text_pairs = None

            else:
                # Check if tokenizer has a standard separator (like [SEP] in BERT).
                # Qwen tokenizers often have sep_token as None or the same as EOS.
                has_sep_token = getattr(self.tokenizer, "sep_token", None) is not None

                if not has_sep_token or self.tokenizer.sep_token == self.tokenizer.eos_token:
                    texts = [
                        f"{t}{self.default_decoder_separator}{p}"
                        for t, p in zip(texts, text_pairs)
                    ]
                    # Set pairs to None so tokenizer treats it as a single string
                    text_pairs = None

        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        batch = self.tokenizer(
            texts,
            text_pairs,
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            return_tensors="mlx"
        )

        if "label" in features:
            labels = features["label"]
            # On-the-fly String to ID conversion
            if self.label2id and len(labels) > 0 and isinstance(labels[0], str):
                labels = [self.label2id.get(l, -1) for l in labels]  # Default to -1 if missing

            # Detect regression (float) vs classification (int)
            if len(labels) > 0 and isinstance(labels[0], float):
                dtype = mx.float32
            else:
                dtype = mx.int32

            batch["labels"] = mx.array(labels, dtype=dtype)

        return dict(batch)

@dataclass
class DataCollatorForTokenClassification(DataCollator):
    """
    Handles tokenization and aligns labels for token classification.
    """
    label_pad_token_id: int = -100
    # Strategy: 'first' (label only first subword), 'all' (label all subwords with same tag)
    label_all_tokens: bool = False
    label2id: Optional[Dict[str, int]] = None

    def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
        texts = features["text"]
        labels = features["labels"]  # Note: 'labels' is a list of label lists (one per example)

        # SANITY CHECK: The library expects pre-tokenized inputs (List[str])
        if isinstance(texts[0], str):
            raise ValueError(
                "DataCollatorForTokenClassification expects 'text' to be a list of strings "
                "(tokens), not a single string. Please pre-tokenize your dataset."
            )

        batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="mlx",
            is_split_into_words=True
        )

        batch_size, seq_len = batch["input_ids"].shape

        # Create a numpy buffer filled with the ignore index
        padded_labels = np.full((batch_size, seq_len), self.label_pad_token_id, dtype=np.int32)

        for i, label_seq in enumerate(labels):
            # On-the-fly conversion for lists of strings (to avoid memory issues with dataset.map)
            current_labels = label_seq
            if self.label2id and len(label_seq) > 0 and isinstance(label_seq[0], str):
                current_labels = [self.label2id.get(l, self.label_pad_token_id) for l in label_seq]

            # word_ids returns a list mapping each token to its original word index
            # e.g., [None, 0, 1, 1, 2, None] for "[CLS] My name is John [SEP]"
            word_ids = batch.word_ids(batch_index=i)
            previous_word_idx = None

            for k, word_idx in enumerate(word_ids):
                # Skip Special Tokens (None)
                if word_idx is None:
                    continue

                # Safety check: tokenizer truncation might leave word_ids that point to label indices larger than the label list provided.
                if word_idx >= len(current_labels):
                    break

                if word_idx != previous_word_idx:
                    padded_labels[i, k] = current_labels[word_idx]
                else:
                    # This is a subsequent subword of the same word
                    if self.label_all_tokens:
                        padded_labels[i, k] = current_labels[word_idx]
                    else:
                        # Standard BERT NER behavior: ignore subsequent subwords
                        padded_labels[i, k] = self.label_pad_token_id

                previous_word_idx = word_idx

        batch["labels"] = mx.array(padded_labels, dtype=mx.int32)

        return dict(batch)

@dataclass
class DataCollatorForMaskedLanguageModeling(DataCollator):
    """
    Handles dynamic masking for MLM.
    """
    mlm_probability: float = 0.15
    mask_token_id: Optional[int] = None

    def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
        texts = features["text"]

        batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="mlx"
        )

        input_ids = batch["input_ids"]

        # Create Mask
        probability_matrix = mx.random.uniform(shape=input_ids.shape) < self.mlm_probability

        # Protect special tokens
        special_tokens_mask = mx.array([
            [1 if token_id in [self.tokenizer.cls_token_id, self.tokenizer.sep_token_id, self.tokenizer.pad_token_id]
             else 0 for token_id in seq]
            for seq in input_ids.tolist()
        ])

        probability_matrix = mx.where(special_tokens_mask, 0, probability_matrix)

        # Create labels (-100 for unmasked)
        labels = mx.where(probability_matrix, input_ids, -100)

        # Apply masking (80% mask, 10% random, 10% original)
        random_matrix = mx.random.uniform(shape=input_ids.shape)
        mask_indices = (probability_matrix) & (random_matrix < 0.8)
        random_indices = (probability_matrix) & (random_matrix >= 0.8) & (random_matrix < 0.9)

        # Create masked input
        masked_inputs = input_ids

        mask_token_id = self.tokenizer.mask_token_id
        if mask_token_id is None:
            if self.mask_token_id is not None:
                mask_token_id = self.mask_token_id
            else:
                raise ValueError(
                    "Tokenizer does not have a mask token defined and no mask_token_id provided."
                )

        # Apply the [MASK] token to the 80% subset
        masked_inputs = mx.where(mask_indices, mask_token_id, masked_inputs)
        random_tokens = mx.random.randint(
            0, self.tokenizer.vocab_size,
            shape=input_ids.shape
        )

        # Replace the 10% subset with random tokens (the remaining 10% keep the original token)
        inputs = mx.where(random_indices, random_tokens, masked_inputs)

        batch["input_ids"] = inputs
        batch["labels"] = labels

        return dict(batch)

@dataclass
class DataCollatorForSentenceSimilarity(DataCollator):
    """
    Handles data for Bi-Encoder models (Sentence Similarity / Retrieval).
    Unlike SequenceClassification, this keeps sentences SEPARATE to produce
    independent embeddings.

    Expected keys in features (from datasets.py standardization):
    - 'text': The Anchor / Sentence A
    - 'text_pair': The Positive / Reference / Sentence B
    - 'negative' (optional): The Hard Negative / Sentence C
    - 'label' (optional): Similarity score for Regression
    """
    def __call__(self, features: Dict[str, List[Any]]) -> Dict[str, mx.array]:
        batch = {}

        # Tokenize Anchor (Sentence A) -> 'input_ids'
        if "text" in features:
            out_a = self.tokenizer(
                features["text"],
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="mlx"
            )
            batch["input_ids"] = out_a["input_ids"]
            batch["attention_mask"] = out_a["attention_mask"]

        # Tokenize Reference (Sentence B) -> 'reference_input_ids'
        if "text_pair" in features:
            out_b = self.tokenizer(
                features["text_pair"],
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="mlx"
            )
            batch["reference_input_ids"] = out_b["input_ids"]
            batch["reference_attention_mask"] = out_b["attention_mask"]

        # Tokenize Negative (Sentence C) -> 'negative_input_ids'
        neg_key = None
        if "negative" in features:
            neg_key = "negative"
        elif "text_negative" in features:
            neg_key = "text_negative"

        if neg_key:
            out_n = self.tokenizer(
                features[neg_key],
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="mlx"
            )
            batch["negative_input_ids"] = out_n["input_ids"]
            batch["negative_attention_mask"] = out_n["attention_mask"]

        # Handle Scores (for Regression)
        if "label" in features:
            # Ensure float32 for regression targets
            batch["similarity_scores"] = mx.array(features["label"], dtype=mx.float32)

        return batch
|