distil_trainer-0.1.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distil_trainer/__init__.py +31 -0
- distil_trainer/core/__init__.py +23 -0
- distil_trainer/core/callbacks.py +188 -0
- distil_trainer/core/config.py +358 -0
- distil_trainer/core/trainer.py +843 -0
- distil_trainer/data/__init__.py +19 -0
- distil_trainer/data/collators.py +240 -0
- distil_trainer/data/datamodule.py +191 -0
- distil_trainer/data/datasets.py +245 -0
- distil_trainer/data/loaders.py +163 -0
- distil_trainer/distillation/__init__.py +21 -0
- distil_trainer/distillation/losses.py +345 -0
- distil_trainer/distillation/multilingual.py +285 -0
- distil_trainer/distillation/strategies.py +211 -0
- distil_trainer/evaluation/__init__.py +19 -0
- distil_trainer/evaluation/benchmarks.py +86 -0
- distil_trainer/evaluation/evaluators.py +343 -0
- distil_trainer/evaluation/metrics.py +75 -0
- distil_trainer/models/__init__.py +5 -0
- distil_trainer/models/layers.py +115 -0
- distil_trainer/pruning/__init__.py +13 -0
- distil_trainer/pruning/combined_pruning.py +122 -0
- distil_trainer/pruning/depth_pruning.py +261 -0
- distil_trainer/pruning/importance.py +365 -0
- distil_trainer/pruning/width_pruning.py +480 -0
- distil_trainer-0.1.10.dist-info/METADATA +443 -0
- distil_trainer-0.1.10.dist-info/RECORD +29 -0
- distil_trainer-0.1.10.dist-info/WHEEL +4 -0
- distil_trainer-0.1.10.dist-info/licenses/LICENSE +21 -0
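
Before the file content, here is a minimal usage sketch assembled from the DistilTrainer docstring in distil_trainer/core/trainer.py shown below. The model names, dataset name, and paths are the illustrative values from that docstring, and the import paths mirror the package layout listed above; they are an assumption about how the classes are exposed (top-level re-exports may also exist but are not shown in this diff).

# Minimal sketch based on the DistilTrainer docstring below; values are illustrative.
from distil_trainer.core.config import DistilTrainerConfig
from distil_trainer.core.trainer import DistilTrainer

config = DistilTrainerConfig(
    teacher_model="sentence-transformers/all-mpnet-base-v2",
    student_model="sentence-transformers/paraphrase-TinyBERT-L6-v2",
    output_dir="./distilled_model",
)

trainer = DistilTrainer(config)
trainer.load_data(train_data="sentence-transformers/all-nli")
trainer.train()
trainer.save_model("./final_model")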
distil_trainer/core/trainer.py
@@ -0,0 +1,843 @@
"""Main DistilTrainer class for knowledge distillation."""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Any, Callable

import torch
from datasets import Dataset, DatasetDict, load_dataset
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    get_scheduler,
)

from sentence_transformers import SentenceTransformer

from distil_trainer.core.config import (
    DistilTrainerConfig,
    DistillationConfig,
    LayerReductionConfig,
    WidthPruningConfig,
)

logger = logging.getLogger(__name__)


class DistilTrainer:
    """
    Main trainer class for knowledge distillation.

    Supports multiple distillation strategies:
    - Classical embedding distillation (MSE/Cosine loss)
    - Layer reduction (depth pruning)
    - Width pruning
    - Combined pruning

    Example:
        >>> config = DistilTrainerConfig(
        ...     teacher_model="sentence-transformers/all-mpnet-base-v2",
        ...     student_model="sentence-transformers/paraphrase-TinyBERT-L6-v2",
        ...     output_dir="./distilled_model"
        ... )
        >>> trainer = DistilTrainer(config)
        >>> trainer.load_data(train_data="sentence-transformers/all-nli")
        >>> trainer.train()
        >>> trainer.save_model("./final_model")
    """

    def __init__(self, config: DistilTrainerConfig):
        """
        Initialize the DistilTrainer.

        Args:
            config: Configuration for distillation training.
        """
        self.config = config
        self.device = self._get_device()

        # Initialize models
        self.teacher_model = self._load_teacher_model()
        self.student_model = self._initialize_student_model()

        # PCA for dimension reduction (if needed)
        self.pca = None
        self.teacher_projection = None

        # Data
        self.train_dataset: Dataset | None = None
        self.eval_dataset: Dataset | None = None
        self.test_dataset: Dataset | None = None

        # Training state
        self.optimizer = None
        self.scheduler = None
        self.global_step = 0
        # Initialize so the first evaluation always registers as an improvement.
        self.best_metric = float("-inf") if config.training_config.greater_is_better else float("inf")

        logger.info(f"Initialized DistilTrainer with device: {self.device}")
        logger.info(f"Teacher model: {self._get_model_info(self.teacher_model)}")
        logger.info(f"Student model: {self._get_model_info(self.student_model)}")

    def _get_device(self) -> torch.device:
        """Determine the device to use for training."""
        if self.config.device == "auto":
            if torch.cuda.is_available():
                return torch.device("cuda")
            elif torch.backends.mps.is_available():
                return torch.device("mps")
            else:
                return torch.device("cpu")
        return torch.device(self.config.device)

    def _load_teacher_model(self) -> SentenceTransformer | PreTrainedModel:
        """Load the teacher model."""
        teacher = self.config.teacher_model

        if isinstance(teacher, str):
            logger.info(f"Loading teacher model from: {teacher}")
            teacher = SentenceTransformer(teacher)

        teacher.to(self.device)
        teacher.eval()
        return teacher

    def _initialize_student_model(self) -> SentenceTransformer | PreTrainedModel:
        """Initialize the student model based on the configured strategy."""
        strategy = self.config.student_init_strategy

        if strategy == "from_pretrained":
            return self._load_pretrained_student()
        elif strategy in ("layer_reduction", "depth_pruning"):
            return self._create_layer_reduced_student()
        elif strategy == "width_pruning":
            return self._create_width_pruned_student()
        elif strategy == "combined_pruning":
            return self._create_combined_pruned_student()
        else:
            raise ValueError(f"Unknown student initialization strategy: {strategy}")

    def _load_pretrained_student(self) -> SentenceTransformer | PreTrainedModel:
        """Load a pretrained student model."""
        student = self.config.student_model

        if student is None:
            raise ValueError("student_model must be provided for 'from_pretrained' strategy")

        if isinstance(student, str):
            logger.info(f"Loading student model from: {student}")
            student = SentenceTransformer(student)

        student.to(self.device)
        return student

    def _create_layer_reduced_student(self) -> SentenceTransformer | PreTrainedModel:
        """Create a student model by removing layers from the teacher."""
        from distil_trainer.pruning import DepthPruner

        pruning_config = self.config.pruning_config
        if not isinstance(pruning_config, LayerReductionConfig):
            raise ValueError("pruning_config must be LayerReductionConfig for layer reduction")

        logger.info("Creating layer-reduced student from teacher")

        # Clone teacher model
        student = SentenceTransformer(self.config.teacher_model)

        # Apply depth pruning
        pruner = DepthPruner(student)
        student = pruner.prune(
            layers_to_keep=pruning_config.layers_to_keep,
            num_layers_to_keep=pruning_config.num_layers_to_keep,
            layers_to_drop=pruning_config.layers_to_drop,
            layer_selection=pruning_config.layer_selection,
        )

        student.to(self.device)
        return student

    def _create_width_pruned_student(self) -> SentenceTransformer | PreTrainedModel:
        """Create a student model by pruning width dimensions."""
        from distil_trainer.pruning import WidthPruner

        pruning_config = self.config.pruning_config
        if not isinstance(pruning_config, WidthPruningConfig):
            raise ValueError("pruning_config must be WidthPruningConfig for width pruning")

        logger.info("Creating width-pruned student from teacher")

        # Clone teacher model
        student = SentenceTransformer(self.config.teacher_model)

        # Apply width pruning
        pruner = WidthPruner(student)
        student = pruner.prune(pruning_config)

        student.to(self.device)
        return student

    def _create_combined_pruned_student(self) -> SentenceTransformer | PreTrainedModel:
        """Create a student model using both depth and width pruning."""
        from distil_trainer.pruning import CombinedPruner

        pruning_config = self.config.pruning_config

        logger.info("Creating combined-pruned student from teacher")

        # Clone teacher model
        student = SentenceTransformer(self.config.teacher_model)

        # Apply combined pruning
        pruner = CombinedPruner(student)
        student = pruner.prune(pruning_config)

        student.to(self.device)
        return student

    def _get_model_info(self, model: SentenceTransformer | PreTrainedModel) -> str:
        """Get a string representation of model info."""
        if isinstance(model, SentenceTransformer):
            num_params = sum(p.numel() for p in model.parameters())
            embedding_dim = model.get_sentence_embedding_dimension()
            return f"SentenceTransformer(params={num_params:,}, dim={embedding_dim})"
        else:
            num_params = sum(p.numel() for p in model.parameters())
            return f"PreTrainedModel(params={num_params:,})"

    def load_data(
        self,
        train_data: str | Dataset | None = None,
        eval_data: str | Dataset | None = None,
        test_data: str | Dataset | None = None,
        text_column: str | None = None,
        max_samples: int | None = None,
    ) -> None:
        """
        Load training, evaluation, and test datasets.

        Args:
            train_data: Path or name of training dataset, or Dataset object.
            eval_data: Path or name of evaluation dataset, or Dataset object.
            test_data: Path or name of test dataset, or Dataset object.
            text_column: Name of the column containing text/sentences.
                Overrides config.data_config.text_column if provided.
            max_samples: Maximum number of samples to use from the dataset.
                Useful for quick testing. Overrides config.data_config.max_samples.
        """
        if text_column is not None:
            self.config.data_config.text_column = text_column
            logger.info(f"Set text_column to: {text_column}")

        if max_samples is not None:
            self.config.data_config.max_samples = max_samples
            logger.info(f"Set max_samples to: {max_samples}")

        if train_data is not None:
            self.train_dataset = self._load_dataset(train_data, "train")
            logger.info(f"Loaded training dataset: {len(self.train_dataset)} samples")

        if eval_data is not None:
            self.eval_dataset = self._load_dataset(eval_data, "validation")
            logger.info(f"Loaded evaluation dataset: {len(self.eval_dataset)} samples")

        if test_data is not None:
            self.test_dataset = self._load_dataset(test_data, "test")
            logger.info(f"Loaded test dataset: {len(self.test_dataset)} samples")

    def _load_dataset(self, data: str | Dataset, split: str = "train") -> Dataset:
        """Load a dataset from a path or name."""
        if isinstance(data, Dataset):
            dataset = data
        else:
            logger.info(f"Loading dataset: {data}")

            try:
                dataset = load_dataset(data, split=split)
            except Exception:
                # Try loading as a DatasetDict and getting the split
                dataset_dict = load_dataset(data)
                if isinstance(dataset_dict, DatasetDict):
                    if split in dataset_dict:
                        dataset = dataset_dict[split]
                    else:
                        # Use the first available split
                        dataset = list(dataset_dict.values())[0]
                else:
                    dataset = dataset_dict

        # Apply max_samples limit if configured
        max_samples = self.config.data_config.max_samples
        if max_samples is not None and max_samples > 0:
            original_size = len(dataset)
            if max_samples < original_size:
                dataset = dataset.select(range(max_samples))
                logger.info(f"Limited dataset from {original_size} to {max_samples} samples")

        return dataset

    def setup_pca_projection(self) -> None:
        """Set up PCA projection if student dimension is smaller than teacher."""
        if not isinstance(self.teacher_model, SentenceTransformer):
            return
        if not isinstance(self.student_model, SentenceTransformer):
            return

        teacher_dim = self.teacher_model.get_sentence_embedding_dimension()
        student_dim = self.student_model.get_sentence_embedding_dimension()

        if student_dim >= teacher_dim:
            logger.info("Student dimension >= teacher dimension, no PCA needed")
            return

        logger.info(f"Setting up PCA projection: {teacher_dim} -> {student_dim}")

        # Collect sample sentences for PCA
        if self.train_dataset is None:
            raise ValueError("Training dataset required for PCA projection")

        text_column = self.config.data_config.text_column
        num_samples = min(
            self.config.distillation_config.pca_num_samples,
            len(self.train_dataset),
        )

        sample_sentences = self.train_dataset[:num_samples][text_column]

        # Compute teacher embeddings
        logger.info(f"Computing teacher embeddings for {num_samples} samples")
        with torch.no_grad():
            embeddings = self.teacher_model.encode(
                sample_sentences,
                convert_to_numpy=True,
                show_progress_bar=True,
                batch_size=self.config.distillation_config.teacher_inference_batch_size,
            )

        # Fit PCA
        logger.info("Fitting PCA...")
        self.pca = PCA(n_components=student_dim)
        self.pca.fit(embeddings)

        # Create projection layer for teacher
        from distil_trainer.models import DenseProjection

        self.teacher_projection = DenseProjection(
            in_features=teacher_dim,
            out_features=student_dim,
            weights=torch.tensor(self.pca.components_, dtype=torch.float32),
        )
        self.teacher_projection.to(self.device)

        logger.info(f"PCA projection ready: explained variance ratio = {sum(self.pca.explained_variance_ratio_):.4f}")

    def precompute_teacher_embeddings(self) -> None:
        """Precompute teacher embeddings for the training dataset."""
        if not self.config.distillation_config.precompute_teacher_embeddings:
            return

        if self.train_dataset is None:
            raise ValueError("Training dataset required")

        logger.info("Precomputing teacher embeddings...")

        text_column = self.config.data_config.text_column
        sentences = self.train_dataset[text_column]

        batch_size = self.config.distillation_config.teacher_inference_batch_size

        with torch.no_grad():
            embeddings = self.teacher_model.encode(
                sentences,
                convert_to_numpy=False,
                convert_to_tensor=True,
                show_progress_bar=True,
                batch_size=batch_size,
            )

        # Apply projection if needed
        if self.teacher_projection is not None:
            embeddings = self.teacher_projection(embeddings)

        # Add embeddings to dataset
        if isinstance(embeddings, torch.Tensor):
            embeddings_list = embeddings.cpu().tolist()
        else:
            # Already a list (e.g., when encode returns list directly)
            embeddings_list = embeddings
        self.train_dataset = self.train_dataset.add_column("label", embeddings_list)

        logger.info("Teacher embeddings precomputed and cached")

    def train(self) -> dict[str, float]:
        """
        Run the distillation training.

        Returns:
            Dictionary of training metrics.
        """
        if self.train_dataset is None:
            raise ValueError("Training dataset required. Call load_data() first.")

        logger.info("Starting distillation training...")

        # Setup WandB
        is_wandb_avail = False
        if "wandb" in self.config.training_config.report_to or (self.config.wandb_config.project is not None):
            try:
                import wandb
                from dataclasses import asdict

                # Check if already initialized
                if wandb.run is None:
                    wandb.init(
                        project=self.config.wandb_config.project,
                        entity=self.config.wandb_config.entity,
                        name=self.config.wandb_config.name or self.config.training_config.run_name,
                        tags=self.config.wandb_config.tags,
                        group=self.config.wandb_config.group,
                        notes=self.config.wandb_config.notes,
                        config=asdict(self.config),
                    )
                is_wandb_avail = True
            except ImportError:
                logger.warning("wandb not installed, skipping logging")

        # Setup PCA if needed
        if self.config.distillation_config.use_pca_projection:
            self.setup_pca_projection()

        # Precompute teacher embeddings if enabled
        self.precompute_teacher_embeddings()

        # Setup optimizer and scheduler
        self._setup_optimizer()

        # Get loss function
        loss_fn = self._get_loss_function()

        # Create data loader
        train_dataloader = self._create_dataloader(self.train_dataset, shuffle=True)

        # Training loop
        training_config = self.config.training_config
        num_epochs = training_config.num_train_epochs
        total_steps = len(train_dataloader) * num_epochs

        if training_config.max_steps > 0:
            total_steps = min(total_steps, training_config.max_steps)

        logger.info(f"Training for {num_epochs} epochs, {total_steps} total steps")

        self.student_model.train()
        self.global_step = 0

        avg_epoch_loss = 0.0

        for epoch in range(num_epochs):
            epoch_loss = 0.0
            num_batches = 0

            progress_bar = tqdm(
                train_dataloader,
                desc=f"Epoch {epoch + 1}/{num_epochs}",
                disable=False,
            )

            for batch in progress_bar:
                loss = self._training_step(batch, loss_fn)

                epoch_loss += loss.item()
                num_batches += 1
                self.global_step += 1

                current_loss = loss.item()
                progress_bar.set_postfix({"loss": f"{current_loss:.4f}"})

                # Logging
                if self.global_step % training_config.logging_steps == 0:
                    avg_loss = epoch_loss / num_batches
                    logger.info(f"Step {self.global_step}: loss = {avg_loss:.4f}")

                    if is_wandb_avail:
                        wandb.log(
                            {
                                "train/loss": current_loss,
                                "train/avg_loss": avg_loss,
                                "train/epoch": epoch + (num_batches / len(train_dataloader)),
                                "train/learning_rate": self.scheduler.get_last_lr()[0],
                            },
                            step=self.global_step,
                        )

                # Evaluation
                if (
                    training_config.eval_strategy == "steps"
                    and self.global_step % training_config.eval_steps == 0
                ):
                    eval_metrics = self.evaluate()
                    logger.info(f"Step {self.global_step}: {eval_metrics}")

                    if is_wandb_avail:
                        wandb_metrics = {f"eval/{k}": v for k, v in eval_metrics.items()}
                        wandb.log(wandb_metrics, step=self.global_step)

                    # Save best model
                    if self._is_better_metric(eval_metrics):
                        self._save_checkpoint("best")

                # Save checkpoint
                if self.global_step % training_config.save_steps == 0:
                    self._save_checkpoint(f"checkpoint-{self.global_step}")

                # Check max steps
                if training_config.max_steps > 0 and self.global_step >= training_config.max_steps:
                    break

            # End of epoch
            avg_epoch_loss = epoch_loss / num_batches
            logger.info(f"Epoch {epoch + 1} completed: avg_loss = {avg_epoch_loss:.4f}")

            if is_wandb_avail:
                wandb.log({"train/epoch_loss": avg_epoch_loss}, step=self.global_step)

            if training_config.max_steps > 0 and self.global_step >= training_config.max_steps:
                break

        logger.info("Training completed!")

        # Load best model if configured
        if training_config.load_best_model_at_end:
            self._load_checkpoint("best")

        # Push to Hub at end if configured
        if self.config.hub_config.push_to_hub:
            self._push_to_hub_with_config()

        if is_wandb_avail:
            wandb.finish()

        return {"train_loss": avg_epoch_loss}

    def _setup_optimizer(self) -> None:
        """Set up optimizer and learning rate scheduler."""
        training_config = self.config.training_config

        # Optimizer
        if training_config.optimizer == "adamw":
            self.optimizer = torch.optim.AdamW(
                self.student_model.parameters(),
                lr=training_config.learning_rate,
                betas=(training_config.adam_beta1, training_config.adam_beta2),
                eps=training_config.adam_epsilon,
                weight_decay=training_config.weight_decay,
            )
        elif training_config.optimizer == "adam":
            self.optimizer = torch.optim.Adam(
                self.student_model.parameters(),
                lr=training_config.learning_rate,
            )
        elif training_config.optimizer == "sgd":
            self.optimizer = torch.optim.SGD(
                self.student_model.parameters(),
                lr=training_config.learning_rate,
                weight_decay=training_config.weight_decay,
            )
        else:
            raise ValueError(f"Unknown optimizer: {training_config.optimizer}")

        # Scheduler
        num_training_steps = self._get_num_training_steps()
        warmup_steps = training_config.warmup_steps
        if warmup_steps == 0 and training_config.warmup_ratio > 0:
            warmup_steps = int(num_training_steps * training_config.warmup_ratio)

        self.scheduler = get_scheduler(
            training_config.lr_scheduler_type,
            optimizer=self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_training_steps,
        )

    def _get_num_training_steps(self) -> int:
        """Calculate the total number of training steps."""
        if self.train_dataset is None:
            return 0

        training_config = self.config.training_config
        num_batches = len(self.train_dataset) // training_config.per_device_train_batch_size
        total_steps = num_batches * training_config.num_train_epochs

        if training_config.max_steps > 0:
            total_steps = min(total_steps, training_config.max_steps)

        return total_steps

    def _get_loss_function(self) -> Callable:
        """Get the loss function based on configuration."""
        from distil_trainer.distillation import DistillationLosses

        loss_type = self.config.distillation_config.loss_type

        if loss_type == "mse":
            return DistillationLosses.mse_loss
        elif loss_type == "cosine":
            return DistillationLosses.cosine_loss
        elif loss_type == "kl_divergence":
            temperature = self.config.distillation_config.temperature
            return lambda s, t: DistillationLosses.kl_divergence_loss(s, t, temperature)
        elif loss_type == "combined":
            from distil_trainer.distillation import CombinedDistillationLoss

            return CombinedDistillationLoss(
                logit_weight=self.config.distillation_config.logit_loss_weight,
                embedding_weight=self.config.distillation_config.embedding_loss_weight,
                intermediate_weight=self.config.distillation_config.intermediate_loss_weight,
                attention_weight=self.config.distillation_config.attention_loss_weight,
                temperature=self.config.distillation_config.temperature,
                layer_mapping=self.config.distillation_config.layer_mapping,
            )
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")

    def _create_dataloader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
        """Create a DataLoader from a dataset."""
        from distil_trainer.data import DistillationCollator

        batch_size = self.config.training_config.per_device_train_batch_size

        # Get tokenizer from student model
        tokenizer = None
        if isinstance(self.student_model, SentenceTransformer):
            tokenizer = self.student_model.tokenizer

        collator = DistillationCollator(
            tokenizer=tokenizer,
            max_length=self.config.data_config.max_seq_length,
            text_column=self.config.data_config.text_column,
        )

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collator,
            num_workers=self.config.data_config.num_workers,
            pin_memory=True,
        )

    def _training_step(self, batch: dict[str, torch.Tensor], loss_fn: Callable) -> torch.Tensor:
        """Perform a single training step."""
        # Move batch to device
        batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

        # Get student embeddings
        student_output = self.student_model(batch)

        # SentenceTransformer returns a dict with 'sentence_embedding' key
        if isinstance(student_output, dict) and "sentence_embedding" in student_output:
            student_output = student_output["sentence_embedding"]

        # Get teacher embeddings (from precomputed or compute on-the-fly)
        if "label" in batch:
            teacher_output = batch["label"]
        else:
            with torch.no_grad():
                teacher_output = self.teacher_model(batch)
                # SentenceTransformer forward also returns a feature dict here
                if isinstance(teacher_output, dict) and "sentence_embedding" in teacher_output:
                    teacher_output = teacher_output["sentence_embedding"]
                if self.teacher_projection is not None:
                    teacher_output = self.teacher_projection(teacher_output)

        # Compute loss
        loss = loss_fn(student_output, teacher_output)

        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        if self.config.training_config.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(
                self.student_model.parameters(),
                self.config.training_config.max_grad_norm,
            )

        self.optimizer.step()
        self.scheduler.step()

        return loss

    def evaluate(self) -> dict[str, float]:
        """
        Evaluate the student model.

        Returns:
            Dictionary of evaluation metrics.
        """
        if self.eval_dataset is None:
            logger.warning("No evaluation dataset provided")
            return {}

        self.student_model.eval()
        eval_dataloader = self._create_dataloader(self.eval_dataset, shuffle=False)

        total_loss = 0.0
        num_batches = 0
        loss_fn = self._get_loss_function()

        with torch.no_grad():
            for batch in tqdm(eval_dataloader, desc="Evaluating"):
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

                student_output = self.student_model(batch)

                # SentenceTransformer returns a dict with 'sentence_embedding' key
                if isinstance(student_output, dict) and "sentence_embedding" in student_output:
                    student_output = student_output["sentence_embedding"]

                if "label" in batch:
                    teacher_output = batch["label"]
                else:
                    teacher_output = self.teacher_model(batch)
                    # SentenceTransformer forward also returns a feature dict here
                    if isinstance(teacher_output, dict) and "sentence_embedding" in teacher_output:
                        teacher_output = teacher_output["sentence_embedding"]
                    if self.teacher_projection is not None:
                        teacher_output = self.teacher_projection(teacher_output)

                loss = loss_fn(student_output, teacher_output)
                total_loss += loss.item()
                num_batches += 1

        self.student_model.train()

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        return {"eval_loss": avg_loss}

    def _is_better_metric(self, metrics: dict[str, float]) -> bool:
        """Check if current metrics are better than the best."""
        metric_name = self.config.training_config.metric_for_best_model
        if metric_name not in metrics:
            return False

        current_value = metrics[metric_name]

        if self.config.training_config.greater_is_better:
            is_better = current_value > self.best_metric
        else:
            is_better = current_value < self.best_metric

        if is_better:
            self.best_metric = current_value

        return is_better

    def _save_checkpoint(self, name: str) -> None:
        """Save a checkpoint."""
        output_dir = Path(self.config.output_dir) / "checkpoints" / name
        output_dir.mkdir(parents=True, exist_ok=True)

        if isinstance(self.student_model, SentenceTransformer):
            self.student_model.save(str(output_dir))
        else:
            self.student_model.save_pretrained(output_dir)

        logger.info(f"Saved checkpoint: {output_dir}")

        # Push to Hub Logic
        if self.config.hub_config.push_to_hub and self.config.hub_config.push_to_hub_interval == "every_save":
            self._push_to_hub_with_config(commit_message=f"Upload checkpoint {name}")

    def _push_to_hub_with_config(self, commit_message: str = "Upload distilled model") -> None:
        """Helper to push to hub using config settings."""
        if not self.config.hub_config.push_to_hub:
            return

        repo_id = self.config.hub_config.hub_model_id
        if not repo_id:
            logger.warning("push_to_hub is True but hub_model_id is not set. Skipping push.")
            return

        try:
            url = self.push_to_hub(
                repo_id=repo_id,
                private=self.config.hub_config.hub_private_repo,
                commit_message=commit_message,
                token=self.config.hub_config.hub_token,
            )
            logger.info(f"Pushed model to Hub: {url}")
        except Exception as e:
            logger.error(f"Failed to push to Hub: {e}")

    def _load_checkpoint(self, name: str) -> None:
        """Load a checkpoint."""
        checkpoint_dir = Path(self.config.output_dir) / "checkpoints" / name

        if not checkpoint_dir.exists():
            logger.warning(f"Checkpoint not found: {checkpoint_dir}")
            return

        if isinstance(self.student_model, SentenceTransformer):
            self.student_model = SentenceTransformer(str(checkpoint_dir))
        else:
            self.student_model = self.student_model.__class__.from_pretrained(checkpoint_dir)

        self.student_model.to(self.device)
        logger.info(f"Loaded checkpoint: {checkpoint_dir}")

    def save_model(self, output_path: str | None = None) -> None:
        """
        Save the trained student model.

        Args:
            output_path: Path to save the model. Defaults to output_dir/final.
        """
        if output_path is None:
            output_path = os.path.join(self.config.output_dir, "final")

        output_dir = Path(output_path)
        output_dir.mkdir(parents=True, exist_ok=True)

        if isinstance(self.student_model, SentenceTransformer):
            self.student_model.save(str(output_dir))
        else:
            self.student_model.save_pretrained(output_dir)

        logger.info(f"Model saved to: {output_dir}")

    def push_to_hub(
        self,
        repo_id: str,
        private: bool = False,
        commit_message: str = "Upload distilled model",
        token: str | None = None,
    ) -> str:
        """
        Push the model to HuggingFace Hub.

        Args:
            repo_id: Repository ID on HuggingFace Hub.
            private: Whether the repository should be private.
            commit_message: Commit message for the upload.
            token: HuggingFace Hub token for authentication.

        Returns:
            URL of the uploaded model.
        """
        if isinstance(self.student_model, SentenceTransformer):
            return self.student_model.push_to_hub(
                repo_id=repo_id,
                private=private,
                commit_message=commit_message,
                token=token,
                exist_ok=True,
            )
        else:
            return self.student_model.push_to_hub(
                repo_id=repo_id,
                private=private,
                commit_message=commit_message,
                token=token,
            )