distil-trainer 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,843 @@
+ """Main DistilTrainer class for knowledge distillation."""
+
+ from __future__ import annotations
+
+ import logging
+ import os
+ from pathlib import Path
+ from typing import Any, Callable
+
+ import torch
+ from datasets import Dataset, DatasetDict, load_dataset
+ from sklearn.decomposition import PCA
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from transformers import (
+     PreTrainedModel,
+     PreTrainedTokenizer,
+     get_scheduler,
+ )
+
+ from sentence_transformers import SentenceTransformer
+
+ from distil_trainer.core.config import (
+     DistilTrainerConfig,
+     DistillationConfig,
+     LayerReductionConfig,
+     WidthPruningConfig,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class DistilTrainer:
+     """
+     Main trainer class for knowledge distillation.
+
+     Supports multiple distillation strategies:
+     - Classical embedding distillation (MSE/Cosine loss)
+     - Layer reduction (depth pruning)
+     - Width pruning
+     - Combined pruning
+
+     Example:
+         >>> config = DistilTrainerConfig(
+         ...     teacher_model="sentence-transformers/all-mpnet-base-v2",
+         ...     student_model="sentence-transformers/paraphrase-TinyBERT-L6-v2",
+         ...     output_dir="./distilled_model"
+         ... )
+         >>> trainer = DistilTrainer(config)
+         >>> trainer.load_data(train_data="sentence-transformers/all-nli")
+         >>> trainer.train()
+         >>> trainer.save_model("./final_model")
+     """
+
+     def __init__(self, config: DistilTrainerConfig):
+         """
+         Initialize the DistilTrainer.
+
+         Args:
+             config: Configuration for distillation training.
+         """
+         self.config = config
+         self.device = self._get_device()
+
+         # Initialize models
+         self.teacher_model = self._load_teacher_model()
+         self.student_model = self._initialize_student_model()
+
+         # PCA for dimension reduction (if needed)
+         self.pca = None
+         self.teacher_projection = None
+
+         # Data
+         self.train_dataset: Dataset | None = None
+         self.eval_dataset: Dataset | None = None
+         self.test_dataset: Dataset | None = None
+
+         # Training state
+         self.optimizer = None
+         self.scheduler = None
+         self.global_step = 0
+         # Track the best eval metric seen so far; the starting value depends on
+         # whether the configured metric is maximized or minimized.
+         if self.config.training_config.greater_is_better:
+             self.best_metric = float("-inf")
+         else:
+             self.best_metric = float("inf")
+
+         logger.info(f"Initialized DistilTrainer with device: {self.device}")
+         logger.info(f"Teacher model: {self._get_model_info(self.teacher_model)}")
+         logger.info(f"Student model: {self._get_model_info(self.student_model)}")
+
+     def _get_device(self) -> torch.device:
+         """Determine the device to use for training."""
+         if self.config.device == "auto":
+             if torch.cuda.is_available():
+                 return torch.device("cuda")
+             elif torch.backends.mps.is_available():
+                 return torch.device("mps")
+             else:
+                 return torch.device("cpu")
+         return torch.device(self.config.device)
+
+     def _load_teacher_model(self) -> SentenceTransformer | PreTrainedModel:
+         """Load the teacher model."""
+         teacher = self.config.teacher_model
+
+         if isinstance(teacher, str):
+             logger.info(f"Loading teacher model from: {teacher}")
+             teacher = SentenceTransformer(teacher)
+
+         teacher.to(self.device)
+         teacher.eval()
+         return teacher
+
+     def _initialize_student_model(self) -> SentenceTransformer | PreTrainedModel:
+         """Initialize the student model based on the configured strategy."""
+         strategy = self.config.student_init_strategy
+
+         if strategy == "from_pretrained":
+             return self._load_pretrained_student()
+         elif strategy in ("layer_reduction", "depth_pruning"):
+             return self._create_layer_reduced_student()
+         elif strategy == "width_pruning":
+             return self._create_width_pruned_student()
+         elif strategy == "combined_pruning":
+             return self._create_combined_pruned_student()
+         else:
+             raise ValueError(f"Unknown student initialization strategy: {strategy}")
+
+     def _load_pretrained_student(self) -> SentenceTransformer | PreTrainedModel:
+         """Load a pretrained student model."""
+         student = self.config.student_model
+
+         if student is None:
+             raise ValueError("student_model must be provided for 'from_pretrained' strategy")
+
+         if isinstance(student, str):
+             logger.info(f"Loading student model from: {student}")
+             student = SentenceTransformer(student)
+
+         student.to(self.device)
+         return student
+
+     def _create_layer_reduced_student(self) -> SentenceTransformer | PreTrainedModel:
+         """Create a student model by removing layers from the teacher."""
+         from distil_trainer.pruning import DepthPruner
+
+         pruning_config = self.config.pruning_config
+         if not isinstance(pruning_config, LayerReductionConfig):
+             raise ValueError("pruning_config must be LayerReductionConfig for layer reduction")
+
+         logger.info("Creating layer-reduced student from teacher")
+
+         # Clone teacher model
+         student = SentenceTransformer(self.config.teacher_model)
+
+         # Apply depth pruning
+         pruner = DepthPruner(student)
+         student = pruner.prune(
+             layers_to_keep=pruning_config.layers_to_keep,
+             num_layers_to_keep=pruning_config.num_layers_to_keep,
+             layers_to_drop=pruning_config.layers_to_drop,
+             layer_selection=pruning_config.layer_selection,
+         )
+
+         student.to(self.device)
+         return student
+
+     def _create_width_pruned_student(self) -> SentenceTransformer | PreTrainedModel:
+         """Create a student model by pruning width dimensions."""
+         from distil_trainer.pruning import WidthPruner
+
+         pruning_config = self.config.pruning_config
+         if not isinstance(pruning_config, WidthPruningConfig):
+             raise ValueError("pruning_config must be WidthPruningConfig for width pruning")
+
+         logger.info("Creating width-pruned student from teacher")
+
+         # Clone teacher model
+         student = SentenceTransformer(self.config.teacher_model)
+
+         # Apply width pruning
+         pruner = WidthPruner(student)
+         student = pruner.prune(pruning_config)
+
+         student.to(self.device)
+         return student
+
+     def _create_combined_pruned_student(self) -> SentenceTransformer | PreTrainedModel:
+         """Create a student model using both depth and width pruning."""
+         from distil_trainer.pruning import CombinedPruner
+
+         pruning_config = self.config.pruning_config
+
+         logger.info("Creating combined-pruned student from teacher")
+
+         # Clone teacher model
+         student = SentenceTransformer(self.config.teacher_model)
+
+         # Apply combined pruning
+         pruner = CombinedPruner(student)
+         student = pruner.prune(pruning_config)
+
+         student.to(self.device)
+         return student
+
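+     # The student-initialization strategies above are selected entirely through
+     # the config object. As a rough illustration (field names follow the
+     # prune() call above; exact defaults live in distil_trainer.core.config),
+     # a depth-pruned student could be requested with something like:
+     #
+     #     config = DistilTrainerConfig(
+     #         teacher_model="sentence-transformers/all-mpnet-base-v2",
+     #         student_init_strategy="layer_reduction",
+     #         pruning_config=LayerReductionConfig(num_layers_to_keep=6),
+     #         output_dir="./distilled_model",
+     #     )
+     #
+     # With "from_pretrained" the pruning_config is not consulted and
+     # student_model must point at an existing checkpoint instead.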
+     def _get_model_info(self, model: SentenceTransformer | PreTrainedModel) -> str:
+         """Get a string representation of model info."""
+         if isinstance(model, SentenceTransformer):
+             num_params = sum(p.numel() for p in model.parameters())
+             embedding_dim = model.get_sentence_embedding_dimension()
+             return f"SentenceTransformer(params={num_params:,}, dim={embedding_dim})"
+         else:
+             num_params = sum(p.numel() for p in model.parameters())
+             return f"PreTrainedModel(params={num_params:,})"
+
+     def load_data(
+         self,
+         train_data: str | Dataset | None = None,
+         eval_data: str | Dataset | None = None,
+         test_data: str | Dataset | None = None,
+         text_column: str | None = None,
+         max_samples: int | None = None,
+     ) -> None:
+         """
+         Load training, evaluation, and test datasets.
+
+         Args:
+             train_data: Path or name of training dataset, or Dataset object.
+             eval_data: Path or name of evaluation dataset, or Dataset object.
+             test_data: Path or name of test dataset, or Dataset object.
+             text_column: Name of the column containing text/sentences.
+                 Overrides config.data_config.text_column if provided.
+             max_samples: Maximum number of samples to use from the dataset.
+                 Useful for quick testing. Overrides config.data_config.max_samples.
+         """
+         if text_column is not None:
+             self.config.data_config.text_column = text_column
+             logger.info(f"Set text_column to: {text_column}")
+
+         if max_samples is not None:
+             self.config.data_config.max_samples = max_samples
+             logger.info(f"Set max_samples to: {max_samples}")
+
+         if train_data is not None:
+             self.train_dataset = self._load_dataset(train_data, "train")
+             logger.info(f"Loaded training dataset: {len(self.train_dataset)} samples")
+
+         if eval_data is not None:
+             self.eval_dataset = self._load_dataset(eval_data, "validation")
+             logger.info(f"Loaded evaluation dataset: {len(self.eval_dataset)} samples")
+
+         if test_data is not None:
+             self.test_dataset = self._load_dataset(test_data, "test")
+             logger.info(f"Loaded test dataset: {len(self.test_dataset)} samples")
+
+     def _load_dataset(self, data: str | Dataset, split: str = "train") -> Dataset:
+         """Load a dataset from a path or name."""
+         if isinstance(data, Dataset):
+             dataset = data
+         else:
+             logger.info(f"Loading dataset: {data}")
+
+             try:
+                 dataset = load_dataset(data, split=split)
+             except Exception:
+                 # Try loading as a DatasetDict and getting the split
+                 dataset_dict = load_dataset(data)
+                 if isinstance(dataset_dict, DatasetDict):
+                     if split in dataset_dict:
+                         dataset = dataset_dict[split]
+                     else:
+                         # Use the first available split
+                         dataset = list(dataset_dict.values())[0]
+                 else:
+                     dataset = dataset_dict
+
+         # Apply max_samples limit if configured
+         max_samples = self.config.data_config.max_samples
+         if max_samples is not None and max_samples > 0:
+             original_size = len(dataset)
+             if max_samples < original_size:
+                 dataset = dataset.select(range(max_samples))
+                 logger.info(f"Limited dataset from {original_size} to {max_samples} samples")
+
+         return dataset
+
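+     # Hedged usage sketch for the two loaders above (dataset id, column name
+     # and sample cap are illustrative, not defaults shipped with the package):
+     #
+     #     trainer.load_data(
+     #         train_data="sentence-transformers/all-nli",  # Hub id or local path
+     #         eval_data=my_eval_dataset,                   # or a datasets.Dataset object
+     #         text_column="sentence",                      # column holding raw text
+     #         max_samples=50_000,                          # cap for quick experiments
+     #     )
+     #
+     # Strings are resolved via datasets.load_dataset(); already-materialized
+     # Dataset objects are passed through unchanged before the max_samples cap.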
+     def setup_pca_projection(self) -> None:
+         """Set up PCA projection if student dimension is smaller than teacher."""
+         if not isinstance(self.teacher_model, SentenceTransformer):
+             return
+         if not isinstance(self.student_model, SentenceTransformer):
+             return
+
+         teacher_dim = self.teacher_model.get_sentence_embedding_dimension()
+         student_dim = self.student_model.get_sentence_embedding_dimension()
+
+         if student_dim >= teacher_dim:
+             logger.info("Student dimension >= teacher dimension, no PCA needed")
+             return
+
+         logger.info(f"Setting up PCA projection: {teacher_dim} -> {student_dim}")
+
+         # Collect sample sentences for PCA
+         if self.train_dataset is None:
+             raise ValueError("Training dataset required for PCA projection")
+
+         text_column = self.config.data_config.text_column
+         num_samples = min(
+             self.config.distillation_config.pca_num_samples,
+             len(self.train_dataset),
+         )
+
+         sample_sentences = self.train_dataset[:num_samples][text_column]
+
+         # Compute teacher embeddings
+         logger.info(f"Computing teacher embeddings for {num_samples} samples")
+         with torch.no_grad():
+             embeddings = self.teacher_model.encode(
+                 sample_sentences,
+                 convert_to_numpy=True,
+                 show_progress_bar=True,
+                 batch_size=self.config.distillation_config.teacher_inference_batch_size,
+             )
+
+         # Fit PCA
+         logger.info("Fitting PCA...")
+         self.pca = PCA(n_components=student_dim)
+         self.pca.fit(embeddings)
+
+         # Create projection layer for teacher
+         from distil_trainer.models import DenseProjection
+
+         self.teacher_projection = DenseProjection(
+             in_features=teacher_dim,
+             out_features=student_dim,
+             weights=torch.tensor(self.pca.components_, dtype=torch.float32),
+         )
+         self.teacher_projection.to(self.device)
+
+         logger.info(f"PCA projection ready: explained variance ratio = {sum(self.pca.explained_variance_ratio_):.4f}")
+
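+     # What the PCA projection amounts to, as a minimal sketch: the fitted
+     # components form a (student_dim, teacher_dim) matrix, and projecting a
+     # teacher embedding is a single matrix multiply. Whether DenseProjection
+     # also subtracts the PCA mean is an implementation detail of
+     # distil_trainer.models; the sketch below assumes a plain linear map.
+     #
+     #     W = pca.components_                   # shape: (student_dim, teacher_dim)
+     #     projected = teacher_embeddings @ W.T  # shape: (n, student_dim)
+     #
+     # This mirrors sklearn's pca.transform() up to mean-centering and gives the
+     # student a regression target in its own embedding dimension.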
+     def precompute_teacher_embeddings(self) -> None:
+         """Precompute teacher embeddings for the training dataset."""
+         if not self.config.distillation_config.precompute_teacher_embeddings:
+             return
+
+         if self.train_dataset is None:
+             raise ValueError("Training dataset required")
+
+         logger.info("Precomputing teacher embeddings...")
+
+         text_column = self.config.data_config.text_column
+         sentences = self.train_dataset[text_column]
+
+         batch_size = self.config.distillation_config.teacher_inference_batch_size
+
+         with torch.no_grad():
+             embeddings = self.teacher_model.encode(
+                 sentences,
+                 convert_to_numpy=False,
+                 convert_to_tensor=True,
+                 show_progress_bar=True,
+                 batch_size=batch_size,
+             )
+
+             # Apply projection if needed
+             if self.teacher_projection is not None:
+                 embeddings = self.teacher_projection(embeddings)
+
+         # Add embeddings to dataset
+         if isinstance(embeddings, torch.Tensor):
+             embeddings_list = embeddings.cpu().tolist()
+         else:
+             # Already a list (e.g., when encode returns list directly)
+             embeddings_list = embeddings
+         self.train_dataset = self.train_dataset.add_column("label", embeddings_list)
+
+         logger.info("Teacher embeddings precomputed and cached")
+
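+     # Note on the "label" column: _training_step and evaluate() treat any
+     # "label" entry in a batch as the (possibly projected) teacher embedding,
+     # so precomputing trades a one-off teacher inference pass and extra dataset
+     # memory for skipping the teacher forward pass on every training step. For
+     # large corpora the cached column holds len(dataset) * embedding_dim
+     # floats, e.g. roughly 1M sentences * 384 dims * 4 bytes ≈ 1.5 GB as
+     # float32.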
+     def train(self) -> dict[str, float]:
+         """
+         Run the distillation training.
+
+         Returns:
+             Dictionary of training metrics.
+         """
+         if self.train_dataset is None:
+             raise ValueError("Training dataset required. Call load_data() first.")
+
+         logger.info("Starting distillation training...")
+
+         # Setup WandB
+         is_wandb_avail = False
+         if "wandb" in self.config.training_config.report_to or (self.config.wandb_config.project is not None):
+             try:
+                 import wandb
+                 from dataclasses import asdict
+
+                 # Check if already initialized
+                 if wandb.run is None:
+                     wandb.init(
+                         project=self.config.wandb_config.project,
+                         entity=self.config.wandb_config.entity,
+                         name=self.config.wandb_config.name or self.config.training_config.run_name,
+                         tags=self.config.wandb_config.tags,
+                         group=self.config.wandb_config.group,
+                         notes=self.config.wandb_config.notes,
+                         config=asdict(self.config),
+                     )
+                 is_wandb_avail = True
+             except ImportError:
+                 logger.warning("wandb not installed, skipping logging")
+
+         # Setup PCA if needed
+         if self.config.distillation_config.use_pca_projection:
+             self.setup_pca_projection()
+
+         # Precompute teacher embeddings if enabled
+         self.precompute_teacher_embeddings()
+
+         # Setup optimizer and scheduler
+         self._setup_optimizer()
+
+         # Get loss function
+         loss_fn = self._get_loss_function()
+
+         # Create data loader
+         train_dataloader = self._create_dataloader(self.train_dataset, shuffle=True)
+
+         # Training loop
+         training_config = self.config.training_config
+         num_epochs = training_config.num_train_epochs
+         total_steps = len(train_dataloader) * num_epochs
+
+         if training_config.max_steps > 0:
+             total_steps = min(total_steps, training_config.max_steps)
+
+         logger.info(f"Training for {num_epochs} epochs, {total_steps} total steps")
+
+         self.student_model.train()
+         self.global_step = 0
+
+         avg_epoch_loss = 0.0
+
+         for epoch in range(num_epochs):
+             epoch_loss = 0.0
+             num_batches = 0
+
+             progress_bar = tqdm(
+                 train_dataloader,
+                 desc=f"Epoch {epoch + 1}/{num_epochs}",
+                 disable=False,
+             )
+
+             for batch in progress_bar:
+                 loss = self._training_step(batch, loss_fn)
+
+                 epoch_loss += loss.item()
+                 num_batches += 1
+                 self.global_step += 1
+
+                 current_loss = loss.item()
+                 progress_bar.set_postfix({"loss": f"{current_loss:.4f}"})
+
+                 # Logging
+                 if self.global_step % training_config.logging_steps == 0:
+                     avg_loss = epoch_loss / num_batches
+                     logger.info(f"Step {self.global_step}: loss = {avg_loss:.4f}")
+
+                     if is_wandb_avail:
+                         wandb.log(
+                             {
+                                 "train/loss": current_loss,
+                                 "train/avg_loss": avg_loss,
+                                 "train/epoch": epoch + (num_batches / len(train_dataloader)),
+                                 "train/learning_rate": self.scheduler.get_last_lr()[0],
+                             },
+                             step=self.global_step,
+                         )
+
+                 # Evaluation
+                 if (
+                     training_config.eval_strategy == "steps"
+                     and self.global_step % training_config.eval_steps == 0
+                 ):
+                     eval_metrics = self.evaluate()
+                     logger.info(f"Step {self.global_step}: {eval_metrics}")
+
+                     if is_wandb_avail:
+                         wandb_metrics = {f"eval/{k}": v for k, v in eval_metrics.items()}
+                         wandb.log(wandb_metrics, step=self.global_step)
+
+                     # Save best model
+                     if self._is_better_metric(eval_metrics):
+                         self._save_checkpoint("best")
+
+                 # Save checkpoint
+                 if self.global_step % training_config.save_steps == 0:
+                     self._save_checkpoint(f"checkpoint-{self.global_step}")
+
+                 # Check max steps
+                 if training_config.max_steps > 0 and self.global_step >= training_config.max_steps:
+                     break
+
+             # End of epoch
+             avg_epoch_loss = epoch_loss / num_batches
+             logger.info(f"Epoch {epoch + 1} completed: avg_loss = {avg_epoch_loss:.4f}")
+
+             if is_wandb_avail:
+                 wandb.log({"train/epoch_loss": avg_epoch_loss}, step=self.global_step)
+
+             if training_config.max_steps > 0 and self.global_step >= training_config.max_steps:
+                 break
+
+         logger.info("Training completed!")
+
+         # Load best model if configured
+         if training_config.load_best_model_at_end:
+             self._load_checkpoint("best")
+
+         # Push to Hub at end if configured
+         if self.config.hub_config.push_to_hub:
+             self._push_to_hub_with_config()
+
+         if is_wandb_avail:
+             wandb.finish()
+
+         return {"train_loss": avg_epoch_loss}
+
+     def _setup_optimizer(self) -> None:
+         """Set up optimizer and learning rate scheduler."""
+         training_config = self.config.training_config
+
+         # Optimizer
+         if training_config.optimizer == "adamw":
+             self.optimizer = torch.optim.AdamW(
+                 self.student_model.parameters(),
+                 lr=training_config.learning_rate,
+                 betas=(training_config.adam_beta1, training_config.adam_beta2),
+                 eps=training_config.adam_epsilon,
+                 weight_decay=training_config.weight_decay,
+             )
+         elif training_config.optimizer == "adam":
+             self.optimizer = torch.optim.Adam(
+                 self.student_model.parameters(),
+                 lr=training_config.learning_rate,
+             )
+         elif training_config.optimizer == "sgd":
+             self.optimizer = torch.optim.SGD(
+                 self.student_model.parameters(),
+                 lr=training_config.learning_rate,
+                 weight_decay=training_config.weight_decay,
+             )
+         else:
+             raise ValueError(f"Unknown optimizer: {training_config.optimizer}")
+
+         # Scheduler
+         num_training_steps = self._get_num_training_steps()
+         warmup_steps = training_config.warmup_steps
+         if warmup_steps == 0 and training_config.warmup_ratio > 0:
+             warmup_steps = int(num_training_steps * training_config.warmup_ratio)
+
+         self.scheduler = get_scheduler(
+             training_config.lr_scheduler_type,
+             optimizer=self.optimizer,
+             num_warmup_steps=warmup_steps,
+             num_training_steps=num_training_steps,
+         )
+
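+     # Worked example for the warmup logic above: with 100,000 training
+     # examples, per_device_train_batch_size=64 and num_train_epochs=2,
+     # _get_num_training_steps() yields 1562 * 2 = 3124 steps; warmup_steps=0
+     # combined with warmup_ratio=0.1 then resolves to int(3124 * 0.1) = 312
+     # warmup steps. An explicit warmup_steps > 0 always wins over the ratio.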
+     def _get_num_training_steps(self) -> int:
+         """Calculate the total number of training steps."""
+         if self.train_dataset is None:
+             return 0
+
+         training_config = self.config.training_config
+         num_batches = len(self.train_dataset) // training_config.per_device_train_batch_size
+         total_steps = num_batches * training_config.num_train_epochs
+
+         if training_config.max_steps > 0:
+             total_steps = min(total_steps, training_config.max_steps)
+
+         return total_steps
+
+     def _get_loss_function(self) -> Callable:
+         """Get the loss function based on configuration."""
+         from distil_trainer.distillation import DistillationLosses
+
+         loss_type = self.config.distillation_config.loss_type
+
+         if loss_type == "mse":
+             return DistillationLosses.mse_loss
+         elif loss_type == "cosine":
+             return DistillationLosses.cosine_loss
+         elif loss_type == "kl_divergence":
+             temperature = self.config.distillation_config.temperature
+             return lambda s, t: DistillationLosses.kl_divergence_loss(s, t, temperature)
+         elif loss_type == "combined":
+             from distil_trainer.distillation import CombinedDistillationLoss
+
+             return CombinedDistillationLoss(
+                 logit_weight=self.config.distillation_config.logit_loss_weight,
+                 embedding_weight=self.config.distillation_config.embedding_loss_weight,
+                 intermediate_weight=self.config.distillation_config.intermediate_loss_weight,
+                 attention_weight=self.config.distillation_config.attention_loss_weight,
+                 temperature=self.config.distillation_config.temperature,
+                 layer_mapping=self.config.distillation_config.layer_mapping,
+             )
+         else:
+             raise ValueError(f"Unknown loss type: {loss_type}")
+
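+     # Every loss returned above is called as loss_fn(student_emb, teacher_emb)
+     # on two (batch, dim) tensors and must return a scalar tensor. The actual
+     # implementations live in distil_trainer.distillation; purely as an
+     # illustration of that contract, the mse/cosine variants are typically no
+     # more than:
+     #
+     #     import torch.nn.functional as F
+     #
+     #     def mse_loss(student, teacher):
+     #         return F.mse_loss(student, teacher)
+     #
+     #     def cosine_loss(student, teacher):
+     #         return (1 - F.cosine_similarity(student, teacher, dim=-1)).mean()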
+     def _create_dataloader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
+         """Create a DataLoader from a dataset."""
+         from distil_trainer.data import DistillationCollator
+
+         batch_size = self.config.training_config.per_device_train_batch_size
+
+         # Get tokenizer from student model
+         tokenizer = None
+         if isinstance(self.student_model, SentenceTransformer):
+             tokenizer = self.student_model.tokenizer
+
+         collator = DistillationCollator(
+             tokenizer=tokenizer,
+             max_length=self.config.data_config.max_seq_length,
+             text_column=self.config.data_config.text_column,
+         )
+
+         return DataLoader(
+             dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             collate_fn=collator,
+             num_workers=self.config.data_config.num_workers,
+             pin_memory=True,
+         )
+
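+     # _training_step and evaluate() expect the collator to produce batches
+     # shaped roughly like the dict below (key names are assumed from how the
+     # batch is consumed: tokenized inputs for the student forward pass, plus an
+     # optional "label" tensor when teacher embeddings were precomputed):
+     #
+     #     {
+     #         "input_ids": torch.LongTensor(batch, seq_len),
+     #         "attention_mask": torch.LongTensor(batch, seq_len),
+     #         "label": torch.FloatTensor(batch, emb_dim),  # optional
+     #     }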
+     def _training_step(self, batch: dict[str, torch.Tensor], loss_fn: Callable) -> torch.Tensor:
+         """Perform a single training step."""
+         # Move batch to device
+         batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+
+         # Get student embeddings
+         student_output = self.student_model(batch)
+
+         # SentenceTransformer returns a dict with 'sentence_embedding' key
+         if isinstance(student_output, dict) and "sentence_embedding" in student_output:
+             student_output = student_output["sentence_embedding"]
+
+         # Get teacher embeddings (from precomputed or compute on-the-fly)
+         if "label" in batch:
+             teacher_output = batch["label"]
+         else:
+             with torch.no_grad():
+                 teacher_output = self.teacher_model(batch)
+                 # Unwrap the teacher output the same way as the student output
+                 if isinstance(teacher_output, dict) and "sentence_embedding" in teacher_output:
+                     teacher_output = teacher_output["sentence_embedding"]
+                 if self.teacher_projection is not None:
+                     teacher_output = self.teacher_projection(teacher_output)
+
+         # Compute loss
+         loss = loss_fn(student_output, teacher_output)
+
+         # Backward pass
+         self.optimizer.zero_grad()
+         loss.backward()
+
+         # Gradient clipping
+         if self.config.training_config.max_grad_norm > 0:
+             torch.nn.utils.clip_grad_norm_(
+                 self.student_model.parameters(),
+                 self.config.training_config.max_grad_norm,
+             )
+
+         self.optimizer.step()
+         self.scheduler.step()
+
+         return loss
+
+     def evaluate(self) -> dict[str, float]:
+         """
+         Evaluate the student model.
+
+         Returns:
+             Dictionary of evaluation metrics.
+         """
+         if self.eval_dataset is None:
+             logger.warning("No evaluation dataset provided")
+             return {}
+
+         self.student_model.eval()
+         eval_dataloader = self._create_dataloader(self.eval_dataset, shuffle=False)
+
+         total_loss = 0.0
+         num_batches = 0
+         loss_fn = self._get_loss_function()
+
+         with torch.no_grad():
+             for batch in tqdm(eval_dataloader, desc="Evaluating"):
+                 batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
+
+                 student_output = self.student_model(batch)
+
+                 # SentenceTransformer returns a dict with 'sentence_embedding' key
+                 if isinstance(student_output, dict) and "sentence_embedding" in student_output:
+                     student_output = student_output["sentence_embedding"]
+
+                 if "label" in batch:
+                     teacher_output = batch["label"]
+                 else:
+                     teacher_output = self.teacher_model(batch)
+                     # Unwrap the teacher output the same way as the student output
+                     if isinstance(teacher_output, dict) and "sentence_embedding" in teacher_output:
+                         teacher_output = teacher_output["sentence_embedding"]
+                     if self.teacher_projection is not None:
+                         teacher_output = self.teacher_projection(teacher_output)
+
+                 loss = loss_fn(student_output, teacher_output)
+                 total_loss += loss.item()
+                 num_batches += 1
+
+         self.student_model.train()
+
+         avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
+         return {"eval_loss": avg_loss}
+
+     def _is_better_metric(self, metrics: dict[str, float]) -> bool:
+         """Check if current metrics are better than the best."""
+         metric_name = self.config.training_config.metric_for_best_model
+         if metric_name not in metrics:
+             return False
+
+         current_value = metrics[metric_name]
+
+         if self.config.training_config.greater_is_better:
+             is_better = current_value > self.best_metric
+         else:
+             is_better = current_value < self.best_metric
+
+         if is_better:
+             self.best_metric = current_value
+
+         return is_better
+
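+     # Example: with metric_for_best_model="eval_loss" and
+     # greater_is_better=False, an evaluation returning {"eval_loss": 0.42}
+     # beats a stored best of 0.45, so the "best" checkpoint is overwritten and
+     # self.best_metric becomes 0.42.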
+     def _save_checkpoint(self, name: str) -> None:
+         """Save a checkpoint."""
+         output_dir = Path(self.config.output_dir) / "checkpoints" / name
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         if isinstance(self.student_model, SentenceTransformer):
+             self.student_model.save(str(output_dir))
+         else:
+             self.student_model.save_pretrained(output_dir)
+
+         logger.info(f"Saved checkpoint: {output_dir}")
+
+         # Push to Hub logic
+         if self.config.hub_config.push_to_hub and self.config.hub_config.push_to_hub_interval == "every_save":
+             self._push_to_hub_with_config(commit_message=f"Upload checkpoint {name}")
+
+     def _push_to_hub_with_config(self, commit_message: str = "Upload distilled model") -> None:
+         """Helper to push to hub using config settings."""
+         if not self.config.hub_config.push_to_hub:
+             return
+
+         repo_id = self.config.hub_config.hub_model_id
+         if not repo_id:
+             logger.warning("push_to_hub is True but hub_model_id is not set. Skipping push.")
+             return
+
+         try:
+             url = self.push_to_hub(
+                 repo_id=repo_id,
+                 private=self.config.hub_config.hub_private_repo,
+                 commit_message=commit_message,
+                 token=self.config.hub_config.hub_token,
+             )
+             logger.info(f"Pushed model to Hub: {url}")
+         except Exception as e:
+             logger.error(f"Failed to push to Hub: {e}")
+
+     def _load_checkpoint(self, name: str) -> None:
+         """Load a checkpoint."""
+         checkpoint_dir = Path(self.config.output_dir) / "checkpoints" / name
+
+         if not checkpoint_dir.exists():
+             logger.warning(f"Checkpoint not found: {checkpoint_dir}")
+             return
+
+         if isinstance(self.student_model, SentenceTransformer):
+             self.student_model = SentenceTransformer(str(checkpoint_dir))
+         else:
+             self.student_model = self.student_model.__class__.from_pretrained(checkpoint_dir)
+
+         self.student_model.to(self.device)
+         logger.info(f"Loaded checkpoint: {checkpoint_dir}")
+
+     def save_model(self, output_path: str | None = None) -> None:
+         """
+         Save the trained student model.
+
+         Args:
+             output_path: Path to save the model. Defaults to output_dir/final.
+         """
+         if output_path is None:
+             output_path = os.path.join(self.config.output_dir, "final")
+
+         output_dir = Path(output_path)
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         if isinstance(self.student_model, SentenceTransformer):
+             self.student_model.save(str(output_dir))
+         else:
+             self.student_model.save_pretrained(output_dir)
+
+         logger.info(f"Model saved to: {output_dir}")
+
+     def push_to_hub(
+         self,
+         repo_id: str,
+         private: bool = False,
+         commit_message: str = "Upload distilled model",
+         token: str | None = None,
+     ) -> str:
+         """
+         Push the model to HuggingFace Hub.
+
+         Args:
+             repo_id: Repository ID on HuggingFace Hub.
+             private: Whether the repository should be private.
+             commit_message: Commit message for the upload.
+             token: HuggingFace Hub token for authentication.
+
+         Returns:
+             URL of the uploaded model.
+         """
+         if isinstance(self.student_model, SentenceTransformer):
+             return self.student_model.push_to_hub(
+                 repo_id=repo_id,
+                 private=private,
+                 commit_message=commit_message,
+                 token=token,
+                 exist_ok=True,
+             )
+         else:
+             return self.student_model.push_to_hub(
+                 repo_id=repo_id,
+                 private=private,
+                 commit_message=commit_message,
+                 token=token,
+             )
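+
+
+ # Saving and publishing, as a short sketch: save_model() writes to
+ # <output_dir>/final when no path is given, while push_to_hub() uploads the
+ # student directly (the repo id below is a placeholder):
+ #
+ #     trainer.save_model()                                 # -> <output_dir>/final
+ #     trainer.push_to_hub("your-org/distilled-mpnet", private=True)
+ #
+ # Automatic pushes during training are controlled by hub_config: push_to_hub
+ # enables pushing, hub_model_id names the target repo, and with
+ # push_to_hub_interval == "every_save" every checkpoint is uploaded in
+ # addition to the final model.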