mlx-raclate 0.1.0b1__py3-none-any.whl

@@ -0,0 +1,648 @@
+ import time
+ import json
+ import gc
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Optional, Dict
+ from functools import partial
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import mlx.optimizers
+ from datasets import Dataset as HFDataset
+
+ from mlx.utils import tree_flatten, tree_map
+
+ from .collators import DataCollator
+ from .utils import EMBEDDING_LAYER_NAMES, build_schedule
+ from mlx_raclate.tuner.model_card_utils import get_code_for_trained_model
+
+ @dataclass
+ class TrainingArgs:
+     batch_size: int = 2
+     eval_batch_size: int = 4
+     max_length: int = 512
+     resume_from_step: int = 0
+     num_train_epochs: int = 2
+     learning_rate: float = 3e-5
+     weight_decay: float = 0.01
+     freeze_embeddings: bool = False
+     warmup_ratio: float = 0.0
+     # warmup_steps takes precedence over warmup_ratio; counted in optimizer
+     # steps, i.e. dataset_size / (batch_size * gradient_accumulation_steps)
+     warmup_steps: int = 0
+     # "constant", "linear_schedule" or "cosine_decay", see
+     # https://ml-explore.github.io/mlx/build/html/python/optimizers/schedulers.html
+     lr_scheduler_type: str = "constant"
+     min_lr: float = 0.0  # minimum learning rate for schedulers that need it
+     gradient_accumulation_steps: int = 8
+     max_grad_norm: float = 1.0
+     save_steps: int = 1000
+     logging_steps: int = 100
+     output_dir: str = "outputs"
+     save_total_limit: Optional[int] = None
+     grad_checkpoint: bool = True  # may not be necessary but helps anticipate hardware constraints
+     push_to_hub: bool = False
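+
+ # Example (illustrative): batch_size=2 with gradient_accumulation_steps=8 gives
+ # an effective batch size of 16, so a 10_000-example dataset yields 5_000
+ # forward steps and 625 optimizer updates per epoch.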
+
+ class Trainer:
+     """
+     A trainer that adapts to the model's training objective.
+     The training logic is determined by the model's class implementation.
+
+     TODO : add basemodel and upload repo arguments to upload to HF hub
+     """
+     def __init__(
+         self,
+         model: nn.Module,
+         tokenizer,
+         task_type: str,
+         training_args: TrainingArgs,
+         train_dataset: HFDataset,
+         use_chat_template: bool = False,  # for decoder-based models, you may want to use chat templates when preparing the data
+         force_separator: Optional[str] = None,  # for decoder-based models, you may want to force a specific separator when preparing the data
+         eval_dataset: Optional[HFDataset] = None,
+         optimizer=None,
+         label2id: Optional[Dict[str, int]] = None
+     ):
+         self.model = model
+         self.tokenizer = tokenizer._tokenizer  ### tokenizer is a wrapper around the HF tokenizer (see utils/tokenizer_utils.py)
+         self.task_type = task_type
+
+         self.args = training_args
+         # Adjust logging and saving steps based on gradient accumulation
+         if training_args.logging_steps % training_args.gradient_accumulation_steps != 0:
+             closest_multiple = (training_args.logging_steps // training_args.gradient_accumulation_steps) * training_args.gradient_accumulation_steps
+             self.logging_steps = closest_multiple if closest_multiple > 0 else training_args.gradient_accumulation_steps
+         else:
+             self.logging_steps = training_args.logging_steps
+         if training_args.save_steps % self.logging_steps != 0:
+             closest_multiple = (training_args.save_steps // self.logging_steps) * self.logging_steps
+             self.save_steps = closest_multiple if closest_multiple > 0 else self.logging_steps
+         else:
+             self.save_steps = training_args.save_steps
+
+         self.resume_from_step = training_args.resume_from_step
+         # TODO : handle resuming from checkpoint (load model + optimizer state)
+         # For now, no optimizer state loading
+
+         self.train_dataset = train_dataset
+         self.use_chat_template = use_chat_template
+         self.force_separator = force_separator
+         self.eval_dataset = eval_dataset
+         self.label2id = label2id
+         self.data_collator = self._get_collator()
+
+         if training_args.freeze_embeddings:
+             print("Freezing embedding layers.")
+             if model.config.model_type in EMBEDDING_LAYER_NAMES:
+                 model.model.freeze(keys=EMBEDDING_LAYER_NAMES[model.config.model_type])
+             else:
+                 print(f"Warning: No embedding layer names defined for model type {model.config.model_type}. Using common names (embed_tokens, embeddings).")
+                 model.model.freeze(keys=["embed_tokens", "embeddings"])
+
+         # Initialize optimizer
+         if optimizer is not None:
+             self.optimizer = optimizer
+         elif training_args.lr_scheduler_type == "constant" and not (training_args.warmup_steps or training_args.warmup_ratio):
+             self.optimizer = mlx.optimizers.AdamW(
+                 learning_rate=training_args.learning_rate,
+                 weight_decay=training_args.weight_decay
+             )
+         else:
+             # Build learning rate schedule
+             steps_per_epoch = len(train_dataset) // training_args.batch_size
+             if len(train_dataset) % training_args.batch_size != 0:
+                 steps_per_epoch += 1
+
+             # Effective steps considering gradient accumulation
+             num_update_steps_per_epoch = max(steps_per_epoch // training_args.gradient_accumulation_steps, 1)
+             resumed_update_steps = self.resume_from_step // training_args.gradient_accumulation_steps
+             total_update_steps = num_update_steps_per_epoch * training_args.num_train_epochs
+             if resumed_update_steps >= total_update_steps:
+                 raise ValueError("resume_from_step is greater than total training steps. Steps = dataset_size / batch_size * num_epochs")
+             max_steps = max(total_update_steps - resumed_update_steps, 0)
+
+             if training_args.warmup_steps > 0:
+                 warmup_steps = training_args.warmup_steps
+             else:
+                 warmup_steps = int(max_steps * training_args.warmup_ratio)
+
+             if self.resume_from_step and warmup_steps <= (self.resume_from_step // training_args.gradient_accumulation_steps):
+                 warmup_steps = 0
+
+             decay_steps = max_steps - warmup_steps
+
+             scheduler_type = training_args.lr_scheduler_type  # e.g. "constant", "cosine_decay"
+
+             # Arguments list depends on the function signature in mlx.optimizers
+             if scheduler_type == "constant":
+                 schedule_args = [training_args.learning_rate]
+             elif scheduler_type == "linear_schedule":
+                 schedule_args = [training_args.learning_rate, training_args.min_lr if training_args.min_lr else 0.0, decay_steps]
+             elif scheduler_type == "cosine_decay":
+                 schedule_args = [training_args.learning_rate, decay_steps, training_args.min_lr if training_args.min_lr else 0.0]
+             else:
+                 raise ValueError(f"Unsupported lr_scheduler_type: {scheduler_type}")
+
+             print(f"Scheduler: {scheduler_type} | Warmup: {warmup_steps} | Total: {max_steps}")
+
+             schedule_config = {
+                 "name": scheduler_type,
+                 "arguments": schedule_args,
+                 "warmup_steps": warmup_steps,
+                 "warmup_init": 0.0
+             }
+
+             lr_schedule = build_schedule(schedule_config)
+
+             self.optimizer = mlx.optimizers.AdamW(
+                 learning_rate=lr_schedule,
+                 weight_decay=training_args.weight_decay
+             )
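+
+             # For reference, with warmup_steps > 0, build_schedule (see .utils) is
+             # expected to combine a linear warmup with the main schedule, roughly
+             # (a sketch, assuming an mlx-lm-style implementation):
+             #   warmup = mlx.optimizers.linear_schedule(0.0, learning_rate, warmup_steps)
+             #   main = getattr(mlx.optimizers, name)(*arguments)
+             #   schedule = mlx.optimizers.join_schedules([warmup, main], [warmup_steps])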
+
+         # Setup output directory
+         self.output_dir = Path("trained_models") / training_args.output_dir
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+
+         # Setup training state
+         self.global_step = 0
+         self.epoch = 0
+         self.next_save_step = self.resume_from_step + self.save_steps
+         self.next_log_step = self.resume_from_step + self.logging_steps
+
+         # Capture state that needs updating (random state for Dropout, etc.)
+         self.state = [self.model.state, self.optimizer.state, mx.random.state]
+
+         # Enable gradient checkpointing if requested
+         if training_args.grad_checkpoint:
+             self._apply_grad_checkpointing()
+
+         def loss_fn(model, batch):
+             outputs = model(**batch)
+             return mx.mean(outputs["loss"])
+
+         grad_fn = nn.value_and_grad(self.model, loss_fn)
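+         # nn.value_and_grad differentiates loss_fn with respect to the model's
+         # trainable parameters and returns the loss together with a gradient
+         # tree that mirrors the parameter structure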
+
+         @partial(mx.compile, inputs=self.state, outputs=self.state)
+         def step_calc(batch):
+             loss, grads = grad_fn(self.model, batch)
+             return loss, grads
+
+         self.step_calc = step_calc
+
+         # Optimizer Update Function
+         # We define a compiled function that applies the ACCUMULATED grads to the model
+         @partial(mx.compile, inputs=self.state, outputs=self.state)
+         def update_fn(accumulated_grads):
+             # Flatten gradients to compute the global norm
+             flattened_grads = tree_flatten(accumulated_grads)
+
+             squares = [mx.sum(mx.square(g[1])) for g in flattened_grads]
+             total_norm = mx.sqrt(mx.sum(mx.array(squares)))
+
+             # Computing the clipping coefficient
+             clip_coeff = training_args.max_grad_norm / (total_norm + 1e-6)
+             scale = mx.minimum(1.0, clip_coeff)
+
+             # Gradient clipping
+             accumulated_grads = tree_map(lambda g: g * scale, accumulated_grads)
+
+             self.optimizer.update(self.model, accumulated_grads)
+
+             return total_norm
+
+         self.step_update = update_fn
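+         # e.g. with max_grad_norm=1.0 and total_norm=4.0, clip_coeff=0.25 and the
+         # accumulated gradients are rescaled to unit global norm before the update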
+         self.push_to_hub = training_args.push_to_hub
+
+         print(f"Training {model.__class__.__name__}")
+         # Log model type and config
+         self._save_config()
+
+     def _apply_grad_checkpointing(self):
+         """
+         Apply gradient checkpointing to the model's forward pass to reduce memory usage.
+         Uses MLX's checkpoint mechanism to recompute activations during backpropagation
+         instead of storing them.
+         """
+         def checkpoint_fn(module):
+             # Patch the class, not the instance: Python looks up __call__ on the
+             # type, so an instance-level assignment would be silently ignored.
+             cls = type(module)
+             if getattr(cls, "_is_checkpointed", False):
+                 return
+             original_call = cls.__call__
+
+             def checkpointed_call(self, *args, **kwargs):
+                 def inner_fn(params, *args, **kwargs):
+                     self.update(params)
+                     return original_call(self, *args, **kwargs)
+                 # Pass the trainable parameters explicitly so gradients flow
+                 # through the checkpointed segment (same pattern as mlx-lm)
+                 return mx.checkpoint(inner_fn)(self.trainable_parameters(), *args, **kwargs)
+
+             cls.__call__ = checkpointed_call
+             cls._is_checkpointed = True
+
+         layers = None
+
+         # Handling various model architectures
+         if hasattr(self.model, "layers"):
+             layers = self.model.layers
+         elif hasattr(self.model, "model"):
+             if hasattr(self.model.model, "layers"):
+                 layers = self.model.model.layers
+             elif hasattr(self.model.model, "encoder"):  # Others TBC
+                 if hasattr(self.model.model.encoder, "layers"):
+                     layers = self.model.model.encoder.layers
+
+         if layers is None:
+             print("WARNING: Could not find layers to checkpoint; memory usage will be much higher.")
+             return
+
+         print(f"Checkpointing {len(layers)} layers.")
+         for layer in layers:
+             checkpoint_fn(layer)
+
+         ### TODO : optionally checkpoint other layers (head, classifier)
+
+     def _compute_loss(self, batch_inputs):
+         """Compute the loss for training"""
+         outputs = self.model(**batch_inputs)
+         return mx.mean(outputs["loss"])
+
+     def _get_collator(self) -> DataCollator:
+         if self.task_type == "masked-lm":
+             from .collators import DataCollatorForMaskedLanguageModeling
+             return DataCollatorForMaskedLanguageModeling(
+                 tokenizer=self.tokenizer,
+                 max_length=self.args.max_length
+             )
+         elif self.task_type == "text-classification":
+             from .collators import DataCollatorForSequenceClassification
+             # For decoder-based models:
+             # the collator will apply the chat template in priority if specified;
+             # if not, it will force the separator if specified;
+             # if not, it will use the tokenizer default
+             return DataCollatorForSequenceClassification(
+                 tokenizer=self.tokenizer,
+                 max_length=self.args.max_length,
+                 use_chat_template=self.use_chat_template,
+                 force_separator=self.force_separator,
+                 label2id=self.label2id
+             )
+         elif self.task_type == "token-classification":
+             from .collators import DataCollatorForTokenClassification
+             return DataCollatorForTokenClassification(
+                 tokenizer=self.tokenizer,
+                 max_length=self.args.max_length,
+                 label2id=self.label2id
+             )
+         elif self.task_type in ("sentence-similarity", "sentence-transformers"):
+             from .collators import DataCollatorForSentenceSimilarity
+             return DataCollatorForSentenceSimilarity(
+                 tokenizer=self.tokenizer,
+                 max_length=self.args.max_length
+             )
+         # TODO : Add other tasks & collators if needed
+         raise ValueError(f"No collator defined for {self.task_type}")
+
+     def _create_batches(self, dataset, batch_size, shuffle=False, seed=42):
+         """
+         Iterates over an HF dataset, slices it, and passes the slices to the collator.
+         """
+         data_len = len(dataset)
+
+         # Use the HF dataset's efficient shuffle, which works with memory mapping
+         if shuffle:
+             dataset = dataset.shuffle(seed=seed)
+
+         # Standard iteration
+         for start_idx in range(0, data_len, batch_size):
+             end_idx = min(start_idx + batch_size, data_len)
+             yield dataset[start_idx:end_idx]
+
+     def train(self):
+         """Main training loop."""
+         print("Starting training...")
+
+         for epoch in range(self.args.num_train_epochs):
+             self.epoch = epoch
+             print(f"\nEpoch {epoch + 1}/{self.args.num_train_epochs}")
+             self._train_epoch()
+
+             if self.eval_dataset is not None:
+                 print(f"Evaluating after epoch {self.epoch + 1}...")
+                 metrics = self.evaluate()
+                 self._save_checkpoint(metrics)
+             else:
+                 # Save a checkpoint even if no eval dataset is provided
+                 print(f"Saving checkpoint after epoch {self.epoch + 1} without evaluation...")
+                 self._save_checkpoint({})
+
+     def _train_epoch(self):
+         """Training logic for one epoch."""
+         self.model.train()
+         running_loss = 0
+         running_grad_norm = 0.0
+         n_steps = 0
+         start_time = time.time()
+
+         # Accumulation container
+         accumulated_grads = None
+         steps_to_accumulate = self.args.gradient_accumulation_steps
+         scale_factor = 1.0 / steps_to_accumulate if steps_to_accumulate > 1 else 1.0
+
+         # Ensure different shuffling each epoch
+         current_seed = 42 + self.epoch
+
+         for raw_batch in self._create_batches(self.train_dataset, self.args.batch_size, shuffle=True, seed=current_seed):
+
+             self.global_step += 1
+
+             # Skip steps if resuming from a specific step
+             if self.global_step <= self.resume_from_step:
+                 continue
+
+             # HF Dataset slicing returns a dict of lists: {'text': ['a', 'b'], 'label': [0, 1]}
+             # The collator converts this columnar batch (Dict[str, List]) to an MLX batch (Dict[str, mx.array])
+             batch = self.data_collator(raw_batch)
+             n_steps += 1
+
+             # Calculate grads
+             loss, grads = self.step_calc(batch)
+
+             if accumulated_grads is None:
+                 accumulated_grads = grads
+             else:
+                 accumulated_grads = tree_map(lambda x, y: x + y, accumulated_grads, grads)
+
+             # Depending on hardware and model size, you may want to keep this as
+             # `running_loss += loss` (no .item()) to avoid forcing a sync here
+             running_loss += loss.item()
+
+             # Update optimizer if accumulation is done
+             if n_steps % steps_to_accumulate == 0:
+
+                 # Scale grads for accumulation (only once per accumulation cycle)
+                 if steps_to_accumulate > 1:
+                     accumulated_grads = tree_map(lambda g: g * scale_factor, accumulated_grads)
+
+                 # Apply updates
+                 grad_norm = self.step_update(accumulated_grads)
+                 running_grad_norm += grad_norm.item()
+
+                 # Reset
+                 accumulated_grads = None
+
+                 # Eval state to actually trigger the computation graph
+                 mx.eval(self.model.state, self.optimizer.state)
+
+             if self.global_step >= self.next_log_step:
+                 # if running_loss is an mx.array (see the sync comment above), convert to float
+                 if isinstance(running_loss, mx.array):
+                     running_loss = running_loss.item()
+
+                 avg_loss = running_loss / max(n_steps, 1)
+                 avg_grad_norm = running_grad_norm / (max(n_steps, 1) / steps_to_accumulate)
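+                 # (n_steps counts forward passes, so dividing by steps_to_accumulate
+                 # above yields the number of optimizer updates in this logging window)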
+
+                 # Handle both a static float and a dynamic schedule
+                 if callable(self.optimizer.learning_rate):
+                     # We must pass the optimizer step index
+                     current_lr = self.optimizer.learning_rate(self.optimizer.step)
+                 else:
+                     current_lr = self.optimizer.learning_rate
+                 if isinstance(current_lr, mx.array):
+                     current_lr = current_lr.item()
+
+                 mem_gb = mx.get_active_memory() / 1e9
+                 elapsed = time.time() - start_time
+                 steps_per_sec = n_steps / elapsed
+
+                 print(
+                     f"Step {self.global_step} | Loss: {avg_loss:.4f} | LR: {current_lr:.2e} | GradNorm: {avg_grad_norm:.2f} | Mem: {mem_gb:.1f}GB | Speed: {steps_per_sec:.2f} steps/s"
+                 )
+
+                 # Reset window counters
+                 self.next_log_step += self.logging_steps
+                 running_loss = 0.0
+                 running_grad_norm = 0.0
+                 n_steps = 0
+                 start_time = time.time()
+
+                 if self.global_step >= self.next_save_step:
+                     print("Saving checkpoint...")
+                     self._save_checkpoint({"step": self.global_step, "step_loss": avg_loss, "grad_norm": avg_grad_norm, "learning_rate": current_lr, "memory_gb": mem_gb, "steps_per_sec": steps_per_sec})
+                     self.next_save_step += self.save_steps
+
+             # May not be optimal from a speed perspective, but MLX caches memory very
+             # aggressively; as in utils/server, we force garbage collection here to avoid OOMs on large models
+             gc.collect()
+             mx.clear_cache()
+
+         return 0.0  # placeholder
+
+     def evaluate(self):
+         """Evaluation loop."""
+         self.model.eval()
+         total_loss = 0
+         n_steps = 0
+
+         for raw_batch in self._create_batches(self.eval_dataset, self.args.eval_batch_size):
+             batch = self.data_collator(raw_batch)
+             outputs = self.model(**batch)
+             loss = mx.mean(outputs["loss"])
+             total_loss += loss.item()
+             n_steps += 1
+             mx.clear_cache()
+
+         metrics = {"eval_loss": total_loss / max(n_steps, 1)}
+         print(f"\nEvaluation metrics: {metrics}")
+
+         return metrics
+
+     def test(self, test_dataset=None):
+         """
+         Evaluate the model on the test set after training is complete.
+         Args: test_dataset: Optional test dataset. If None, uses self.eval_dataset
+         """
+         print("\nPerforming final evaluation on test set...")
+
+         # Save the model's training state
+         training = self.model.training
+         self.model.eval()
+         total_loss = 0
+         n_steps = 0
+
+         # Use the provided test dataset or fall back to the eval dataset
+         dataset_to_test = test_dataset or self.eval_dataset
+         if dataset_to_test is None:
+             raise ValueError("No test dataset provided")
+
+         # Perform evaluation
+         for raw_batch in self._create_batches(dataset_to_test, self.args.eval_batch_size):
+             batch = self.data_collator(raw_batch)
+             outputs = self.model(**batch)
+             loss = mx.mean(outputs["loss"])
+             total_loss += loss.item()
+             n_steps += 1
+             mx.clear_cache()
+         metrics = {"eval_loss": total_loss / max(n_steps, 1)}
+
+         # Save test results
+         results_path = self.output_dir / "test_results.json"
+         with open(results_path, "w") as f:
+             json.dump(metrics, f, indent=2)
+
+         print(f"Test results: {metrics}")
+
+         # Restore the model's training state
+         self.model.train(training)
+
+         return metrics
+
+     def _save_checkpoint(self, metrics: Dict[str, float]):
+         save_path = self.output_dir / f"checkpoint-{self.global_step}"
+         save_path.mkdir(exist_ok=True)
+
+         hf_transformers_arch = self.model.get_hf_transformers_arch()
+         if hf_transformers_arch:
+             self.model.config.architectures = [hf_transformers_arch]
+
+         with open(save_path / "config.json", "w") as f:
+             json.dump(self.model.config.__dict__, f, indent=2)
+
+         model_card_kwargs = {
+             "pipeline": self.task_type,
+             "model_path": save_path,  # TODO : replace by upload repo id
+             "base_model": self.model.config.model_type,  # TODO : replace by base model name
+         }
+         if hasattr(self.model.config, "use_late_interaction"):
+             model_card_kwargs["use_late_interaction"] = self.model.config.use_late_interaction
+         if hasattr(self.model.config, "is_regression"):
+             model_card_kwargs["is_regression"] = self.model.config.is_regression
+
+         card_text = get_code_for_trained_model(**model_card_kwargs)
+         with open(save_path / "README.md", "w") as f:
+             f.write(card_text)
+
+         self.tokenizer.save_pretrained(save_path)
+
+         weights = dict(tree_flatten(self.model.parameters()))
+         if hasattr(self.model, "decoder"):
+             print("Removing tied decoder weights from checkpoint...")
+             weights.pop("decoder.weight", None)
+         mx.save_safetensors(str(save_path / "model.safetensors"), weights)
+
+         with open(save_path / "metrics.json", "w") as f:
+             json.dump(metrics, f, indent=2)
+
+         # Push to Hub (PLACEHOLDER)
+         if self.args.push_to_hub:
+             ### TODO
+             repo_id = self.args.output_dir.split("/")[-1]  # Simple heuristic
+             print(f"Pushing to hub: {repo_id}")
+             upload_to_hub(
+                 path=str(save_path),
+                 upload_repo=repo_id,
+                 hf_path=self.model.config.model_type,  # Or base model name
+                 task_type=self.task_type,
+                 card_text=card_text
+             )
+
+         # Manage checkpoint rotation
+         if self.args.save_total_limit:
+             ### TODO
+             raise NotImplementedError("Checkpoint rotation not implemented yet")
+             # self._rotate_checkpoints()  # unreachable until the rotation logic is implemented
+
+     def _save_config(self):
+         """Save training configuration."""
+         config = {
+             "model_type": self.model.__class__.__name__,
+             "training_args": vars(self.args)
+         }
+         with open(self.output_dir / "training_config.json", "w") as f:
+             json.dump(config, f, indent=2)
+
+ def upload_to_hub(
+     path: str,
+     upload_repo: str,
+     hf_path: str,
+     task_type: str,
+     card_text: str,
+ ):
+     """
+     Uploads the model to the Hugging Face Hub.
+
+     Args:
+         path (str): Local path to the model.
+         upload_repo (str): Name of the HF repo to upload to.
+         hf_path (str): Path to the original Hugging Face model.
+         task_type (str): Type of task the model was trained on.
+         card_text (str): Body text for the model card.
+     """
+     from huggingface_hub import HfApi, ModelCard, logging
+
+     model_path = Path(path)
+
+     # Load the base model's card if it exists on the Hub, otherwise start from an empty card
+     try:
+         card = ModelCard.load(hf_path)
+     except Exception:
+         card = ModelCard("")
+     card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
+     card.data.base_model = hf_path
+     card.data.task_type = task_type
+
+     card.text = card_text
+     # Overwrite README.md to add metadata
+     card.save(model_path / "README.md")
+
+     logging.set_verbosity_info()
+
+     api = HfApi()
+     api.create_repo(repo_id=upload_repo, exist_ok=True)
+     api.upload_folder(
+         folder_path=path,
+         repo_id=upload_repo,
+         repo_type="model",
+         multi_commits=True,
+         multi_commits_verbose=True,
+     )
+     print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
+
+ ## COMMENTED OUT FOR NOW (sharding not needed for small models)
+ # def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
+ #     """
+ #     Splits the weights into smaller shards.
+
+ #     Args:
+ #         weights (dict): Model weights.
+ #         max_file_size_gb (int): Maximum size of each shard in gigabytes.
+
+ #     Returns:
+ #         list: List of weight shards.
+ #     """
+ #     max_file_size_bytes = max_file_size_gb << 30
+ #     shards = []
+ #     shard, shard_size = {}, 0
+ #     for k, v in weights.items():
+ #         if shard_size + v.nbytes > max_file_size_bytes:
+ #             shards.append(shard)
+ #             shard, shard_size = {}, 0
+ #         shard[k] = v
+ #         shard_size += v.nbytes
+ #     shards.append(shard)
+ #     return shards