mlx-raclate 0.1.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx_raclate/__init__.py +1 -0
- mlx_raclate/models/__init__.py +0 -0
- mlx_raclate/models/base.py +225 -0
- mlx_raclate/models/gemma3_text.py +913 -0
- mlx_raclate/models/lfm2.py +671 -0
- mlx_raclate/models/modernbert.py +900 -0
- mlx_raclate/models/qwen3.py +582 -0
- mlx_raclate/models/t5gemma_encoder.py +857 -0
- mlx_raclate/py.typed +0 -0
- mlx_raclate/tuner/TUNER.md +305 -0
- mlx_raclate/tuner/__init__.py +0 -0
- mlx_raclate/tuner/collators.py +291 -0
- mlx_raclate/tuner/datasets.py +247 -0
- mlx_raclate/tuner/model_card_utils.py +206 -0
- mlx_raclate/tuner/trainer.py +648 -0
- mlx_raclate/tuner/utils.py +292 -0
- mlx_raclate/utils/__init__.py +0 -0
- mlx_raclate/utils/server.py +390 -0
- mlx_raclate/utils/tokenizer_utils.py +353 -0
- mlx_raclate/utils/train.py +249 -0
- mlx_raclate/utils/utils.py +625 -0
- mlx_raclate-0.1.0b1.dist-info/METADATA +216 -0
- mlx_raclate-0.1.0b1.dist-info/RECORD +25 -0
- mlx_raclate-0.1.0b1.dist-info/WHEEL +4 -0
- mlx_raclate-0.1.0b1.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,648 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import json
|
|
3
|
+
import gc
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Optional, Dict
|
|
8
|
+
from functools import partial
|
|
9
|
+
|
|
10
|
+
import mlx.core as mx
|
|
11
|
+
import mlx.nn as nn
|
|
12
|
+
import mlx.optimizers
|
|
13
|
+
from datasets import Dataset as HFDataset
|
|
14
|
+
|
|
15
|
+
from mlx.utils import tree_flatten, tree_map
|
|
16
|
+
|
|
17
|
+
from .collators import DataCollator
|
|
18
|
+
from .utils import EMBEDDING_LAYER_NAMES, build_schedule
|
|
19
|
+
from mlx_raclate.tuner.model_card_utils import get_code_for_trained_model
|
|
20
|
+
|
|
21
|
+
@dataclass
class TrainingArgs:
    """Hyper-parameters and bookkeeping options consumed by ``Trainer``.

    The fields are declared as real dataclass fields so that the generated
    ``__init__``, ``__repr__`` and ``__eq__`` all reflect the configured
    values. Field order matches the historical positional-argument order.
    """

    batch_size: int = 2
    eval_batch_size: int = 4
    max_length: int = 512
    # Number of already-consumed micro-batches to skip when resuming.
    resume_from_step: int = 0
    num_train_epochs: int = 2
    learning_rate: float = 3e-5
    weight_decay: float = 0.01
    freeze_embeddings: bool = False
    warmup_ratio: float = 0
    # warmup steps take precedence over warmup ratio, warmup_steps are optimizer steps
    # (dataset size / (batch_size * grad_accumulation))
    warmup_steps: int = 0
    # "constant", "cosine_decay", "linear_schedule",
    # https://ml-explore.github.io/mlx/build/html/python/optimizers/schedulers.html
    lr_scheduler_type: str = "constant"
    # minimum learning rate for schedulers that need it
    min_lr: float = 0.0
    gradient_accumulation_steps: int = 8
    max_grad_norm: float = 1
    save_steps: int = 1000
    logging_steps: int = 100
    output_dir: str = "outputs"
    save_total_limit: Optional[int] = None
    # may not be necessary but helps anticipating hardware constraints
    grad_checkpoint: bool = True
    push_to_hub: bool = False
|
|
67
|
+
|
|
68
|
+
class Trainer:
    """
    A trainer that adapts to the model's training objective.
    The training logic is determined by the model's class implementation.

    The model is called as ``model(**batch)`` and its output must expose a
    ``"loss"`` entry; batches are produced by a task-specific data collator.

    TODO : add basemodel and upload repo arguments to upload to HF hub
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        task_type: str,
        training_args: TrainingArgs,
        train_dataset: HFDataset,
        use_chat_template: bool = False,  # for decoder-based models, you may want to use chat templates when preparing the data
        force_separator: Optional[str] = None,  # for decoder-based models, you may want to force a specific separator when preparing the data
        eval_dataset: Optional[HFDataset] = None,
        optimizer=None,
        label2id: Optional[Dict[str, int]] = None
    ):
        """Set up collator, optimizer, output directory and compiled step functions.

        Args:
            model: Model to train; called as ``model(**batch)``.
            tokenizer: Wrapper exposing the underlying tokenizer as ``_tokenizer``.
            task_type: Selects the data collator (see ``_get_collator``).
            training_args: Training hyper-parameters.
            train_dataset: Dataset iterated (shuffled) every epoch.
            use_chat_template: Forwarded to the sequence-classification collator.
            force_separator: Forwarded to the sequence-classification collator.
            eval_dataset: Optional dataset evaluated after each epoch.
            optimizer: Optional pre-built optimizer; when omitted an AdamW
                optimizer is created from ``training_args``.
            label2id: Optional label mapping forwarded to classification collators.

        Raises:
            ValueError: If the task type has no collator, the scheduler type is
                unsupported, or ``resume_from_step`` exceeds the planned training.
        """
        self.model = model
        self.tokenizer = tokenizer._tokenizer  ### tokenizer is a wrapper around the HF tokenizer (see utils/tokenizer_utils.py)
        self.task_type = task_type

        self.args = training_args

        # Logging is only checked right after an optimizer update, so the
        # logging interval is rounded down to a multiple of the accumulation
        # steps (and never below one accumulation cycle).
        accumulation = training_args.gradient_accumulation_steps
        if training_args.logging_steps % accumulation != 0:
            closest_multiple = (training_args.logging_steps // accumulation) * accumulation
            self.logging_steps = closest_multiple if closest_multiple > 0 else accumulation
        else:
            self.logging_steps = training_args.logging_steps

        # Saving reuses the metrics computed while logging, so the save
        # interval is in turn rounded down to a multiple of the logging interval.
        if training_args.save_steps % self.logging_steps != 0:
            closest_multiple = (training_args.save_steps // self.logging_steps) * self.logging_steps
            self.save_steps = closest_multiple if closest_multiple > 0 else self.logging_steps
        else:
            self.save_steps = training_args.save_steps

        self.resume_from_step = training_args.resume_from_step
        # TODO : handle resuming from checkpoint (load model + optimizer state)
        # For now, no optimizer state loading

        self.train_dataset = train_dataset
        self.use_chat_template = use_chat_template
        self.force_separator = force_separator
        self.eval_dataset = eval_dataset
        self.label2id = label2id
        self.data_collator = self._get_collator()

        if training_args.freeze_embeddings:
            print("Freezing embedding layers.")
            if model.config.model_type in EMBEDDING_LAYER_NAMES:
                model.model.freeze(keys=EMBEDDING_LAYER_NAMES[model.config.model_type])
            else:
                print(f"Warning: No embedding layer names defined for model type {model.config.model_type}. Using common names (embed_tokens, embeddings).")
                model.model.freeze(keys=["embed_tokens", "embeddings"])

        # Initialize optimizer
        self.optimizer = self._build_optimizer(optimizer)

        # Setup output directory
        self.output_dir = Path("trained_models") / training_args.output_dir
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Setup training state
        self.global_step = 0
        self.epoch = 0
        self.next_save_step = self.resume_from_step + self.save_steps
        self.next_log_step = self.resume_from_step + self.logging_steps

        # Capture state that needs updating (random state for Dropout, etc.)
        self.state = [self.model.state, self.optimizer.state, mx.random.state]

        # Enable gradient checkpointing if requested
        if training_args.grad_checkpoint:
            self._apply_grad_checkpointing()

        def loss_fn(model, batch):
            outputs = model(**batch)
            return mx.mean(outputs["loss"])

        grad_fn = nn.value_and_grad(self.model, loss_fn)

        @partial(mx.compile, inputs=self.state, outputs=self.state)
        def step_calc(batch):
            loss, grads = grad_fn(self.model, batch)
            return loss, grads

        self.step_calc = step_calc

        # Optimizer Update Function
        # We define a function that takes the ACCUMULATED grads, clips them by
        # global norm and applies the optimizer update.
        @partial(mx.compile, inputs=self.state, outputs=self.state)
        def update_fn(accumulated_grads):
            # Flatten gradients to compute norm
            flattened_grads = tree_flatten(accumulated_grads)

            squares = [mx.sum(mx.square(g[1])) for g in flattened_grads]
            total_norm = mx.sqrt(mx.sum(mx.array(squares)))

            # Computing clipping coeff
            clip_coeff = training_args.max_grad_norm / (total_norm + 1e-6)
            scale = mx.minimum(1.0, clip_coeff)

            # Gradient clipping
            accumulated_grads = tree_map(lambda g: g * scale, accumulated_grads)

            self.optimizer.update(self.model, accumulated_grads)

            # The norm returned is the one measured BEFORE clipping.
            return total_norm

        self.step_update = update_fn
        self.push_to_hub = training_args.push_to_hub

        print(f"Training {model.__class__.__name__}")
        # Log model type and config
        self._save_config()

    def _build_optimizer(self, optimizer=None):
        """Return the optimizer to train with.

        A caller-supplied optimizer is used as-is. Otherwise an AdamW optimizer
        is created, with a plain float learning rate for the constant/no-warmup
        case and a schedule built by ``build_schedule`` in every other case.

        Raises:
            ValueError: If ``resume_from_step`` is past the planned number of
                update steps, or the scheduler type is not supported.
        """
        if optimizer is not None:
            return optimizer

        args = self.args
        if args.lr_scheduler_type == "constant" and not (args.warmup_steps or args.warmup_ratio):
            return mlx.optimizers.AdamW(
                learning_rate=args.learning_rate,
                weight_decay=args.weight_decay
            )

        # Build learning rate schedule
        steps_per_epoch = len(self.train_dataset) // args.batch_size
        if len(self.train_dataset) % args.batch_size != 0:
            steps_per_epoch += 1

        # Effective steps considering gradient accumulation
        num_update_steps_per_epoch = max(steps_per_epoch // args.gradient_accumulation_steps, 1)
        resumed_update_steps = self.resume_from_step // args.gradient_accumulation_steps
        total_update_steps = num_update_steps_per_epoch * args.num_train_epochs
        if resumed_update_steps >= total_update_steps:
            raise ValueError("resume_from_step is greater than total training steps. Steps = dataset_size / batch_size * num_epochs")
        max_steps = max(total_update_steps - resumed_update_steps, 0)

        # Explicit warmup_steps take precedence over warmup_ratio.
        if args.warmup_steps > 0:
            warmup_steps = args.warmup_steps
        else:
            warmup_steps = int(max_steps * args.warmup_ratio)

        # When resuming past the warmup phase, do not warm up again.
        if self.resume_from_step and warmup_steps <= resumed_update_steps:
            warmup_steps = 0

        decay_steps = max_steps - warmup_steps

        scheduler_type = args.lr_scheduler_type  # e.g. "constant", "cosine_decay"
        min_lr = args.min_lr if args.min_lr else 0.0

        # Arguments list depends on the function signature in mlx.optimizers
        if scheduler_type == "constant":
            schedule_args = [args.learning_rate]
        elif scheduler_type == "linear_schedule":
            schedule_args = [args.learning_rate, min_lr, decay_steps]
        elif scheduler_type == "cosine_decay":
            schedule_args = [args.learning_rate, decay_steps, min_lr]
        else:
            raise ValueError(f"Unsupported lr_scheduler_type: {scheduler_type}")

        print(f"Scheduler: {scheduler_type} | Warmup: {warmup_steps} | Total: {max_steps}")

        schedule_config = {
            "name": scheduler_type,
            "arguments": schedule_args,
            "warmup_steps": warmup_steps,
            "warmup_init": 0.0
        }

        lr_schedule = build_schedule(schedule_config)

        return mlx.optimizers.AdamW(
            learning_rate=lr_schedule,
            weight_decay=args.weight_decay
        )

    def _apply_grad_checkpointing(self):
        """
        Apply gradient checkpointing to the model's forward pass to reduce memory usage.
        Uses MLX's checkpoint mechanism to save memory during backpropagation.

        The layer list is looked up at ``model.layers``, ``model.model.layers``
        or ``model.model.encoder.layers``; a warning is printed if none exists.
        """
        def checkpoint_fn(module):
            original_call = module.__call__

            def checkpointed_call(self, **kwargs):
                # Let MLX handle the parameter management, just checkpoint the function call
                return mx.checkpoint(original_call)(self, **kwargs)

            # NOTE(review): Python resolves ``layer(...)`` through
            # ``type(layer).__call__``, so this instance-level assignment is
            # probably never used by implicit calls and checkpointing may be a
            # no-op. ``original_call`` is also already bound, so passing
            # ``self`` again looks wrong. Confirm against the model code before
            # switching to patching the layer class.
            module.__call__ = checkpointed_call

        layers = None

        # Handling various model architectures
        if hasattr(self.model, "layers"):
            layers = self.model.layers
        elif hasattr(self.model, "model"):
            if hasattr(self.model.model, "layers"):
                layers = self.model.model.layers
            elif hasattr(self.model.model, "encoder"):  # Others TBC
                if hasattr(self.model.model.encoder, "layers"):
                    layers = self.model.model.encoder.layers

        if layers is None:
            print("WARNING: Could not find layers to checkpoint. Memory will explode.")
            return

        print(f"Checkpointing {len(layers)} layers.")
        for layer in layers:
            checkpoint_fn(layer)

        ### TODO : optionally checkpoint other layers (head, classifier)

    def _compute_loss(self, batch_inputs):
        """Compute the loss for training"""
        outputs = self.model(**batch_inputs)
        return mx.mean(outputs["loss"])

    def _get_collator(self) -> DataCollator:
        """Return the data collator matching ``self.task_type``.

        Raises:
            ValueError: If no collator is defined for the task type.
        """
        if self.task_type == "masked-lm":
            from .collators import DataCollatorForMaskedLanguageModeling
            return DataCollatorForMaskedLanguageModeling(
                tokenizer=self.tokenizer,
                max_length=self.args.max_length
            )
        elif self.task_type == "text-classification":
            from .collators import DataCollatorForSequenceClassification
            # For decoder-based models:
            # the collator will apply chat template in priority if specified
            # if not, it will force the separator if specified
            # if not, it will use the tokenizer default
            return DataCollatorForSequenceClassification(
                tokenizer=self.tokenizer,
                max_length=self.args.max_length,
                use_chat_template=self.use_chat_template,
                force_separator=self.force_separator,
                label2id=self.label2id
            )
        elif self.task_type == "token-classification":
            from .collators import DataCollatorForTokenClassification
            return DataCollatorForTokenClassification(
                tokenizer=self.tokenizer,
                max_length=self.args.max_length,
                label2id=self.label2id
            )
        elif self.task_type == "sentence-similarity" or self.task_type == "sentence-transformers":
            from .collators import DataCollatorForSentenceSimilarity
            return DataCollatorForSentenceSimilarity(
                tokenizer=self.tokenizer,
                max_length=self.args.max_length
            )
        # TODO : Add other tasks & collators if needed
        raise ValueError(f"No collator defined for {self.task_type}")

    def _create_batches(self, dataset, batch_size, shuffle=False, seed=42):
        """
        Iterates over HF dataset, slices it, and passes to collator.

        Yields consecutive ``dataset[start:end]`` slices of at most
        ``batch_size`` rows; the last slice may be shorter.
        """
        data_len = len(dataset)

        # Use HF dataset's efficient shuffle which works with memory mapping
        if shuffle:
            dataset = dataset.shuffle(seed=seed)

        # Standard iteration
        for start_idx in range(0, data_len, batch_size):
            end_idx = min(start_idx + batch_size, data_len)
            yield dataset[start_idx:end_idx]

    def train(self):
        """Main training loop: one pass per epoch, then evaluate (if possible) and checkpoint."""
        print("Starting training...")

        for epoch in range(self.args.num_train_epochs):
            self.epoch = epoch
            print(f"\nEpoch {epoch + 1}/{self.args.num_train_epochs}")
            self._train_epoch()

            if self.eval_dataset is not None:
                print(f"Evaluating after epoch {self.epoch + 1}...")
                metrics = self.evaluate()
                self._save_checkpoint(metrics)
            else:
                # Save checkpoint even if no eval dataset is provided
                print(f"Saving checkpoint after epoch {self.epoch + 1} without evaluation...")
                self._save_checkpoint({})

    def _train_epoch(self):
        """Training logic for one epoch.

        Gradients are summed over ``gradient_accumulation_steps`` micro-batches,
        averaged, clipped and applied. Micro-batches left over at the end of
        the epoch (fewer than a full accumulation cycle) are not applied.
        """
        self.model.train()
        running_loss = 0
        running_grad_norm = 0.0
        n_steps = 0
        start_time = time.time()

        # Accumulation container
        accumulated_grads = None
        steps_to_accumulate = self.args.gradient_accumulation_steps
        scale_factor = 1.0 / steps_to_accumulate if steps_to_accumulate > 1 else 1.0

        # ensures different shuffling each epoch
        current_seed = 42 + self.epoch

        for raw_batch in self._create_batches(self.train_dataset, self.args.batch_size, shuffle=True, seed=current_seed):

            self.global_step += 1

            # Skip steps if resuming from a specific step
            if self.global_step <= self.resume_from_step:
                continue

            # HF Dataset slicing returns a Dict of lists: {'text': ['a', 'b'], 'label': [0, 1]}
            # Convert HF Columnar batch (Dict[str, List]) to MLX batch (Dict[str, mx.array])
            batch = self.data_collator(raw_batch)
            n_steps += 1

            # Calculate Grads
            loss, grads = self.step_calc(batch)

            if accumulated_grads is None:
                accumulated_grads = grads
            else:
                accumulated_grads = tree_map(lambda x, y: x + y, accumulated_grads, grads)

            # depending on hardware and model size, we may want to avoid syncing here
            running_loss += loss.item()  # running_loss += loss to avoid sync

            # Update Optimizer if Accumulation Done
            if n_steps % steps_to_accumulate == 0:

                # Scale Grads for Accumulation (only once per accumulation cycle)
                if steps_to_accumulate > 1:
                    accumulated_grads = tree_map(lambda g: g * scale_factor, accumulated_grads)

                # Apply updates
                grad_norm = self.step_update(accumulated_grads)
                running_grad_norm += grad_norm.item()

                # Reset
                accumulated_grads = None

                # Eval state to actually trigger the computation graph
                mx.eval(self.model.state, self.optimizer.state)

                if self.global_step >= self.next_log_step:
                    # if running_loss is mx.array (see comment on hardware above), convert to float
                    if isinstance(running_loss, mx.array):
                        running_loss = running_loss.item()

                    avg_loss = running_loss / max(n_steps, 1)
                    avg_grad_norm = running_grad_norm / (max(n_steps, 1) / steps_to_accumulate)

                    # Handle both static float and dynamic schedule
                    if callable(self.optimizer.learning_rate):
                        # We must pass the optimizer step index
                        current_lr = self.optimizer.learning_rate(self.optimizer.step)
                    else:
                        current_lr = self.optimizer.learning_rate
                    if isinstance(current_lr, mx.array):
                        current_lr = current_lr.item()

                    mem_gb = mx.get_active_memory() / 1e9
                    elapsed = time.time() - start_time
                    steps_per_sec = n_steps / elapsed

                    print(
                        f"Step {self.global_step} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e} | GradNorm: {avg_grad_norm:.2f} | Mem: {mem_gb:.1f}GB | Speed: {steps_per_sec:.2f} steps/s"
                    )

                    # Reset window counters
                    self.next_log_step += self.logging_steps
                    running_loss = 0.0
                    running_grad_norm = 0.0
                    n_steps = 0
                    start_time = time.time()

                    # The save interval is a multiple of the logging interval,
                    # so a save is only ever due on a step that also logs;
                    # keeping the check here guarantees the metrics exist.
                    if self.global_step >= self.next_save_step:
                        print("Saving checkpoint...")
                        self._save_checkpoint({"step": self.global_step, "step_loss": avg_loss, "grad_norm": avg_grad_norm, "learning_rate": current_lr, "memory_gb": mem_gb, "steps_per_sec": steps_per_sec})
                        self.next_save_step += self.save_steps

            # May not be optimal from a speed perspective but MLX is very aggressive in terms of memory caching
            # Like for the utils/server, we force garbage collection here to avoid OOMs on large models
            gc.collect()
            mx.clear_cache()

        return 0.0  # placeholder

    def _average_loss(self, dataset) -> float:
        """Return the mean per-batch loss of the model over ``dataset``.

        Returns NaN (after printing a warning) when the dataset yields no
        batch, instead of dividing by zero. The caller is responsible for
        putting the model in eval mode.
        """
        total_loss = 0
        n_steps = 0

        for raw_batch in self._create_batches(dataset, self.args.eval_batch_size):
            batch = self.data_collator(raw_batch)
            outputs = self.model(**batch)
            loss = mx.mean(outputs["loss"])
            total_loss += loss.item()
            n_steps += 1
            mx.clear_cache()

        if n_steps == 0:
            print("Warning: evaluation dataset is empty, reporting NaN loss.")
            return float("nan")
        return total_loss / n_steps

    def evaluate(self):
        """Evaluation loop over ``self.eval_dataset``; returns ``{"eval_loss": ...}``.

        The model is left in eval mode; ``_train_epoch`` switches it back.
        """
        self.model.eval()

        metrics = {"eval_loss": self._average_loss(self.eval_dataset)}
        print(f"\nEvaluation metrics: {metrics}")

        return metrics

    def test(self, test_dataset=None):
        """
        Evaluate the model on the test set after training is complete.

        Args:
            test_dataset: Optional test dataset. If None, uses self.eval_dataset

        Returns:
            ``{"eval_loss": ...}``, also written to ``test_results.json`` in
            the output directory.

        Raises:
            ValueError: If neither a test dataset nor an eval dataset is available.
        """
        print("\nPerforming final evaluation on test set...")

        # Use provided test dataset or fall back to eval dataset. An explicit
        # ``is not None`` check keeps an empty-but-provided dataset from being
        # silently replaced by the eval dataset.
        dataset_to_test = test_dataset if test_dataset is not None else self.eval_dataset
        if dataset_to_test is None:
            raise ValueError("No test dataset provided")

        # Save the model's training state and restore it even if evaluation fails
        training = self.model.training
        self.model.eval()
        try:
            metrics = {"eval_loss": self._average_loss(dataset_to_test)}
        finally:
            self.model.train(training)

        # Save test results
        results_path = self.output_dir / "test_results.json"
        with open(results_path, "w") as f:
            json.dump(metrics, f, indent=2)

        print(f"Test results: {metrics}")

        return metrics

    def _save_checkpoint(self, metrics: Dict[str, float]):
        """Write config, model card, tokenizer, weights and metrics to ``checkpoint-<global_step>``.

        Optionally pushes the folder to the Hub and prunes old checkpoints
        when ``save_total_limit`` is set.
        """
        save_path = self.output_dir / f"checkpoint-{self.global_step}"
        save_path.mkdir(exist_ok=True)

        hf_transformers_arch = self.model.get_hf_transformers_arch()
        if hf_transformers_arch:
            self.model.config.architectures = [hf_transformers_arch]

        with open(save_path / "config.json", "w") as f:
            json.dump(self.model.config.__dict__, f, indent=2)

        model_card_kwargs = {
            "pipeline": self.task_type,
            "model_path": save_path,  # TODO : replace by upload repo id
            "base_model": self.model.config.model_type,  # TODO : replace by base model name
        }
        if hasattr(self.model.config, "use_late_interaction"):
            model_card_kwargs["use_late_interaction"] = self.model.config.use_late_interaction
        if hasattr(self.model.config, "is_regression"):
            model_card_kwargs["is_regression"] = self.model.config.is_regression

        card_text = get_code_for_trained_model(**model_card_kwargs)
        with open(save_path / "README.md", "w") as f:
            f.write(card_text)

        self.tokenizer.save_pretrained(save_path)

        weights = dict(tree_flatten(self.model.parameters()))
        if hasattr(self.model, "decoder"):
            print("Removing tied decoder weights from checkpoint...")
            weights.pop("decoder.weight", None)
        mx.save_safetensors(str(save_path / "model.safetensors"), weights)

        with open(save_path / "metrics.json", "w") as f:
            json.dump(metrics, f, indent=2)

        # Push to Hub (PLACEHOLDER)
        if self.args.push_to_hub:
            ### TODO
            repo_id = self.args.output_dir.split("/")[-1]  # Simple heuristic
            print(f"Pushing to hub: {repo_id}")
            upload_to_hub(
                path=str(save_path),
                upload_repo=repo_id,
                hf_path=self.model.config.model_type,  # Or base model name
                task_type=self.task_type,
                card_text=card_text
            )

        # Manage checkpoint rotation
        if self.args.save_total_limit:
            self._rotate_checkpoints()

    def _rotate_checkpoints(self):
        """Delete the oldest ``checkpoint-<step>`` folders beyond ``save_total_limit``.

        Only directories directly under the output directory whose suffix is
        a plain integer are considered; the highest step numbers are kept.
        """
        import shutil

        limit = self.args.save_total_limit
        if not limit or limit < 1:
            return

        checkpoints = []
        for candidate in self.output_dir.glob("checkpoint-*"):
            step = candidate.name.rpartition("-")[2]
            if candidate.is_dir() and step.isdigit():
                checkpoints.append((int(step), candidate))

        # Sort by numeric step so that e.g. checkpoint-1000 is newer than checkpoint-200.
        checkpoints.sort(key=lambda item: item[0])
        for _, stale in checkpoints[:-limit]:
            print(f"Removing old checkpoint: {stale}")
            shutil.rmtree(stale)

    def _save_config(self):
        """Save training configuration."""
        config = {
            "model_type": self.model.__class__.__name__,
            "training_args": vars(self.args)
        }
        with open(self.output_dir / "training_config.json", "w") as f:
            json.dump(config, f, indent=2)
|
|
579
|
+
|
|
580
|
+
def upload_to_hub(
    path: str,
    upload_repo: str,
    hf_path: str,
    task_type: str,
    card_text: str,
):
    """
    Uploads the model to Hugging Face hub.

    Args:
        path (str): Local path to the model.
        upload_repo (str): Name of the HF repo to upload to.
        hf_path (str): Path to the original Hugging Face model.
        task_type (str): Type of task the model was trained on.
        card_text (str): Body of the model card; overwrites ``README.md`` in ``path``.
    """
    from huggingface_hub import HfApi, ModelCard, logging

    model_path = Path(path)

    # Start from the base model's card when it exists so its metadata is kept.
    card = ModelCard.load(hf_path) if ModelCard.exist_in_hub(hf_path) else ModelCard()
    card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
    card.data.base_model = hf_path
    card.data.task_type = task_type

    card.text = card_text
    # Overwrite README.md to add metadata
    card.save(model_path / "README.md")

    logging.set_verbosity_info()

    api = HfApi()
    api.create_repo(repo_id=upload_repo, exist_ok=True)
    # The experimental ``multi_commits`` arguments are intentionally not
    # passed: recent huggingface_hub releases reject them.
    api.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
|
|
625
|
+
|
|
626
|
+
## COMMENTED OUT FOR NOW (Sharding not needed for small models)
|
|
627
|
+
# def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
|
|
628
|
+
# """
|
|
629
|
+
# Splits the weights into smaller shards.
|
|
630
|
+
|
|
631
|
+
# Args:
|
|
632
|
+
# weights (dict): Model weights.
|
|
633
|
+
# max_file_size_gb (int): Maximum size of each shard in gigabytes.
|
|
634
|
+
|
|
635
|
+
# Returns:
|
|
636
|
+
# list: List of weight shards.
|
|
637
|
+
# """
|
|
638
|
+
# max_file_size_bytes = max_file_size_gb << 30
|
|
639
|
+
# shards = []
|
|
640
|
+
# shard, shard_size = {}, 0
|
|
641
|
+
# for k, v in weights.items():
|
|
642
|
+
# if shard_size + v.nbytes > max_file_size_bytes:
|
|
643
|
+
# shards.append(shard)
|
|
644
|
+
# shard, shard_size = {}, 0
|
|
645
|
+
# shard[k] = v
|
|
646
|
+
# shard_size += v.nbytes
|
|
647
|
+
# shards.append(shard)
|
|
648
|
+
# return shards
|