distil_trainer-0.1.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distil_trainer/__init__.py +31 -0
- distil_trainer/core/__init__.py +23 -0
- distil_trainer/core/callbacks.py +188 -0
- distil_trainer/core/config.py +358 -0
- distil_trainer/core/trainer.py +843 -0
- distil_trainer/data/__init__.py +19 -0
- distil_trainer/data/collators.py +240 -0
- distil_trainer/data/datamodule.py +191 -0
- distil_trainer/data/datasets.py +245 -0
- distil_trainer/data/loaders.py +163 -0
- distil_trainer/distillation/__init__.py +21 -0
- distil_trainer/distillation/losses.py +345 -0
- distil_trainer/distillation/multilingual.py +285 -0
- distil_trainer/distillation/strategies.py +211 -0
- distil_trainer/evaluation/__init__.py +19 -0
- distil_trainer/evaluation/benchmarks.py +86 -0
- distil_trainer/evaluation/evaluators.py +343 -0
- distil_trainer/evaluation/metrics.py +75 -0
- distil_trainer/models/__init__.py +5 -0
- distil_trainer/models/layers.py +115 -0
- distil_trainer/pruning/__init__.py +13 -0
- distil_trainer/pruning/combined_pruning.py +122 -0
- distil_trainer/pruning/depth_pruning.py +261 -0
- distil_trainer/pruning/importance.py +365 -0
- distil_trainer/pruning/width_pruning.py +480 -0
- distil_trainer-0.1.10.dist-info/METADATA +443 -0
- distil_trainer-0.1.10.dist-info/RECORD +29 -0
- distil_trainer-0.1.10.dist-info/WHEEL +4 -0
- distil_trainer-0.1.10.dist-info/licenses/LICENSE +21 -0
distil_trainer/distillation/multilingual.py
@@ -0,0 +1,285 @@
"""Multilingual distillation for extending models to new languages."""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any

import torch
from datasets import DatasetDict, load_dataset
from tqdm import tqdm

from sentence_transformers import SentenceTransformer

from distil_trainer.distillation.losses import DistillationLosses

logger = logging.getLogger(__name__)


@dataclass
class MultilingualDistillationStrategy:
    """
    Strategy for extending a monolingual model to multiple languages.

    Uses parallel sentences to train a multilingual student where:
    - Student source language embeddings match teacher source embeddings
    - Student target language embeddings match teacher source embeddings

    Example:
        >>> strategy = MultilingualDistillationStrategy(
        ...     source_languages=["en"],
        ...     target_languages=["de", "es", "fr"]
        ... )
        >>> strategy.prepare(teacher, student, batch_size=64)
    """

    # Source languages (teacher understands these)
    source_languages: list[str] = field(default_factory=lambda: ["en"])

    # Target languages (student should learn these)
    target_languages: list[str] = field(default_factory=list)

    # Parallel sentence datasets
    parallel_datasets: list[str] = field(
        default_factory=lambda: [
            "sentence-transformers/parallel-sentences-talks",
        ]
    )

    # Maximum sentences per language pair
    max_sentences_per_language: int = 500000

    # Loss function
    loss_fn: str = "mse"

    def __post_init__(self):
        self.teacher = None
        self.student = None
        self._train_datasets = None
        self._eval_datasets = None

    def prepare(
        self,
        teacher: SentenceTransformer,
        student: SentenceTransformer,
        batch_size: int = 64,
    ) -> None:
        """
        Prepare for multilingual distillation.

        Args:
            teacher: Teacher model (monolingual).
            student: Student model (multilingual).
            batch_size: Batch size for encoding.
        """
        self.teacher = teacher
        self.student = student
        self.batch_size = batch_size

    def load_parallel_data(self) -> tuple[DatasetDict, DatasetDict]:
        """
        Load parallel sentence data for training.

        Returns:
            Tuple of (train_datasets, eval_datasets) as DatasetDict.
        """
        train_datasets = DatasetDict()
        eval_datasets = DatasetDict()

        for source_lang in self.source_languages:
            for target_lang in self.target_languages:
                subset = f"{source_lang}-{target_lang}"

                for dataset_name in self.parallel_datasets:
                    try:
                        train_data = load_dataset(dataset_name, subset, split="train")

                        # Limit size
                        if len(train_data) > self.max_sentences_per_language:
                            train_data = train_data.select(range(self.max_sentences_per_language))

                        # Try to get eval split
                        try:
                            eval_data = load_dataset(dataset_name, subset, split="dev")
                            if len(eval_data) > 1000:
                                eval_data = eval_data.select(range(1000))
                        except Exception:
                            # Split from train
                            split_data = train_data.train_test_split(test_size=1000, shuffle=True)
                            train_data = split_data["train"]
                            eval_data = split_data["test"]

                        train_datasets[subset] = train_data
                        eval_datasets[subset] = eval_data

                        logger.info(f"Loaded {len(train_data)} training samples for {subset}")

                    except Exception as e:
                        logger.warning(f"Could not load {dataset_name}/{subset}: {e}")

        self._train_datasets = train_datasets
        self._eval_datasets = eval_datasets

        return train_datasets, eval_datasets

    def prepare_dataset(
        self,
        dataset: Any,
        source_col: str = "english",
        target_col: str = "non_english",
    ) -> Any:
        """
        Prepare dataset with teacher embeddings.

        Args:
            dataset: Dataset with parallel sentences.
            source_col: Column name for source language sentences.
            target_col: Column name for target language sentences.

        Returns:
            Dataset with teacher embeddings added.
        """

        def add_teacher_embeddings(batch):
            source_sentences = batch[source_col]
            with torch.no_grad():
                embeddings = self.teacher.encode(
                    source_sentences,
                    batch_size=self.batch_size,
                    show_progress_bar=False,
                    convert_to_numpy=True,
                )
            return {"label": embeddings.tolist()}

        return dataset.map(add_teacher_embeddings, batched=True, batch_size=10000)

    def get_loss_function(self):
        """Get the loss function."""
        if self.loss_fn == "mse":
            return DistillationLosses.mse_loss
        elif self.loss_fn == "cosine":
            return DistillationLosses.cosine_loss
        else:
            return DistillationLosses.mse_loss

    def compute_loss(
        self,
        student_output: torch.Tensor,
        teacher_output: torch.Tensor,
        batch: dict,
    ) -> torch.Tensor:
        """
        Compute multilingual distillation loss.

        The loss encourages:
        1. Student source embeddings to match teacher source embeddings
        2. Student target embeddings to match teacher source embeddings

        Args:
            student_output: Student embeddings (for source or target text).
            teacher_output: Teacher embeddings (for source text).
            batch: Batch containing source and target sentences.

        Returns:
            Loss value.
        """
        loss_fn = self.get_loss_function()
        return loss_fn(student_output, teacher_output)


class MultilingualDistilTrainer:
    """
    Trainer for multilingual distillation.

    Extends a monolingual teacher to multiple languages via parallel sentence training.

    Example:
        >>> trainer = MultilingualDistilTrainer(
        ...     teacher_model="paraphrase-distilroberta-base-v2",
        ...     student_model="xlm-roberta-base",
        ... )
        >>> trainer.add_languages(["de", "es", "fr"])
        >>> trainer.train()
    """

    def __init__(
        self,
        teacher_model: str | SentenceTransformer,
        student_model: str | SentenceTransformer,
        source_languages: list[str] | None = None,
        target_languages: list[str] | None = None,
        output_dir: str = "./multilingual_model",
    ):
        """
        Initialize the trainer.

        Args:
            teacher_model: Teacher model (monolingual, e.g., English).
            student_model: Student model (multilingual base).
            source_languages: Languages the teacher understands.
            target_languages: Languages to extend to.
            output_dir: Output directory for the trained model.
        """
        if isinstance(teacher_model, str):
            self.teacher = SentenceTransformer(teacher_model)
        else:
            self.teacher = teacher_model

        if isinstance(student_model, str):
            self.student = SentenceTransformer(student_model)
        else:
            self.student = student_model

        self.source_languages = source_languages or ["en"]
        self.target_languages = target_languages or []
        self.output_dir = output_dir

        self.strategy = MultilingualDistillationStrategy(
            source_languages=self.source_languages,
            target_languages=self.target_languages,
        )

    def add_languages(self, languages: list[str]) -> None:
        """Add target languages to extend to."""
        self.target_languages.extend(languages)
        self.strategy.target_languages = self.target_languages

    def train(
        self,
        num_epochs: int = 5,
        batch_size: int = 64,
        learning_rate: float = 2e-5,
    ) -> None:
        """
        Train the multilingual model.

        Args:
            num_epochs: Number of training epochs.
            batch_size: Training batch size.
            learning_rate: Learning rate.
        """
        logger.info("Starting multilingual distillation training...")

        # Prepare strategy
        self.strategy.prepare(self.teacher, self.student, batch_size)

        # Load data
        train_datasets, eval_datasets = self.strategy.load_parallel_data()

        # Prepare datasets with teacher embeddings
        for subset in train_datasets:
            logger.info(f"Preparing {subset} with teacher embeddings...")
            train_datasets[subset] = self.strategy.prepare_dataset(train_datasets[subset])

        logger.info("Training prepared. Use DistilTrainer for actual training.")

        # Save the strategy and datasets for use with DistilTrainer
        self._train_datasets = train_datasets
        self._eval_datasets = eval_datasets

    def save_model(self, path: str | None = None) -> None:
        """Save the trained model."""
        save_path = path or self.output_dir
        self.student.save(save_path)
        logger.info(f"Model saved to {save_path}")
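The two classes above are meant to be driven together, as their docstrings suggest. The sketch below is not part of the wheel; it strings the public calls end to end under the assumption that the package is installed, reusing the model names from the docstring examples. Note that the module itself stops after data preparation and hands optimization off to DistilTrainer.

# Usage sketch only -- not shipped in the package; model names come from the
# docstring examples above and are not a recommendation.
from distil_trainer.distillation.multilingual import MultilingualDistilTrainer

trainer = MultilingualDistilTrainer(
    teacher_model="paraphrase-distilroberta-base-v2",  # monolingual teacher
    student_model="xlm-roberta-base",                  # multilingual student base
    output_dir="./multilingual_model",
)

# Register the languages the student should learn.
trainer.add_languages(["de", "es", "fr"])

# Loads parallel sentences per language pair, caps their size, attaches teacher
# embeddings under the "label" column, and stores the prepared datasets on the
# trainer; per the log message above, the optimization loop itself belongs to
# DistilTrainer.
trainer.train(batch_size=64)

# Persists the student model to output_dir.
trainer.save_model()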
distil_trainer/distillation/strategies.py
@@ -0,0 +1,211 @@
"""Distillation strategies."""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable

import torch
from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer

from distil_trainer.distillation.losses import DistillationLosses


class DistillationStrategy(ABC):
    """Base class for distillation strategies."""

    @abstractmethod
    def prepare(self, teacher: Any, student: Any, data: Any) -> None:
        """Prepare for distillation."""
        pass

    @abstractmethod
    def get_loss_function(self) -> Callable:
        """Get the loss function for this strategy."""
        pass

    @abstractmethod
    def compute_loss(
        self,
        student_output: torch.Tensor,
        teacher_output: torch.Tensor,
        batch: dict,
    ) -> torch.Tensor:
        """Compute the distillation loss."""
        pass


@dataclass
class EmbeddingDistillationStrategy(DistillationStrategy):
    """
    Strategy for distilling embedding models.

    The student learns to produce embeddings similar to the teacher.
    Uses MSE or cosine loss between student and teacher embeddings.

    Example:
        >>> strategy = EmbeddingDistillationStrategy(loss_fn="mse")
        >>> strategy.prepare(teacher, student, train_data)
        >>> loss_fn = strategy.get_loss_function()
    """

    # Loss function
    loss_fn: str = "mse"  # "mse", "cosine", "combined"

    # Whether to precompute teacher embeddings
    precompute_embeddings: bool = True

    # PCA for dimension reduction
    use_pca: bool = True
    pca_components: int | None = None

    # Dataset requirements
    required_columns: list[str] = field(default_factory=lambda: ["sentence"])

    def prepare(self, teacher: Any, student: Any, data: Any) -> None:
        """Prepare for embedding distillation."""
        self.teacher = teacher
        self.student = student

    def get_loss_function(self) -> Callable:
        """Get the loss function."""
        if self.loss_fn == "mse":
            return DistillationLosses.mse_loss
        elif self.loss_fn == "cosine":
            return DistillationLosses.cosine_loss
        elif self.loss_fn == "combined":
            def combined_loss(student_out, teacher_out):
                mse = DistillationLosses.mse_loss(student_out, teacher_out)
                cosine = DistillationLosses.cosine_loss(student_out, teacher_out)
                return 0.5 * mse + 0.5 * cosine
            return combined_loss
        else:
            raise ValueError(f"Unknown loss function: {self.loss_fn}")

    def compute_loss(
        self,
        student_output: torch.Tensor,
        teacher_output: torch.Tensor,
        batch: dict,
    ) -> torch.Tensor:
        """Compute embedding distillation loss."""
        loss_fn = self.get_loss_function()
        return loss_fn(student_output, teacher_output)


@dataclass
class LogitDistillationStrategy(DistillationStrategy):
    """
    Strategy for distilling via logit matching.

    Uses KL divergence between softmax distributions of student and teacher.

    Example:
        >>> strategy = LogitDistillationStrategy(temperature=2.0)
        >>> loss = strategy.compute_loss(student_logits, teacher_logits, batch)
    """

    # Temperature for softmax
    temperature: float = 1.0

    # Whether to also include hard label loss
    alpha: float = 0.5  # Weight for soft targets vs hard targets

    def prepare(self, teacher: Any, student: Any, data: Any) -> None:
        """Prepare for logit distillation."""
        self.teacher = teacher
        self.student = student

    def get_loss_function(self) -> Callable:
        """Get the loss function."""
        def loss_fn(student_logits, teacher_logits):
            return DistillationLosses.kl_divergence_loss(
                student_logits, teacher_logits, self.temperature
            )
        return loss_fn

    def compute_loss(
        self,
        student_output: torch.Tensor,
        teacher_output: torch.Tensor,
        batch: dict,
    ) -> torch.Tensor:
        """Compute logit distillation loss."""
        soft_loss = DistillationLosses.kl_divergence_loss(
            student_output, teacher_output, self.temperature
        )

        # If hard labels available, combine with cross-entropy
        if "labels" in batch and self.alpha < 1.0:
            import torch.nn.functional as F
            hard_loss = F.cross_entropy(student_output, batch["labels"])
            return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

        return soft_loss


@dataclass
class RankingDistillationStrategy(DistillationStrategy):
    """
    Strategy for distilling retrieval models with ranking loss.

    Example:
        >>> strategy = RankingDistillationStrategy(in_batch_negatives=True)
        >>> loss = strategy.compute_loss(query_emb, pos_emb, batch)
    """

    # Whether to use in-batch negatives
    in_batch_negatives: bool = True

    # Number of hard negatives per sample
    hard_negatives: int = 5

    # Margin for triplet loss
    margin: float = 0.5

    def prepare(self, teacher: Any, student: Any, data: Any) -> None:
        """Prepare for ranking distillation."""
        self.teacher = teacher
        self.student = student

    def get_loss_function(self) -> Callable:
        """Get the loss function."""
        def loss_fn(query_emb, pos_emb, neg_emb=None):
            return DistillationLosses.ranking_loss(
                query_emb, pos_emb, neg_emb,
                margin=self.margin,
                in_batch_negatives=self.in_batch_negatives,
            )
        return loss_fn

    def compute_loss(
        self,
        student_output: torch.Tensor,
        teacher_output: torch.Tensor,
        batch: dict,
    ) -> torch.Tensor:
        """Compute ranking distillation loss."""
        # For ranking, student_output contains query embeddings
        # teacher_output contains positive embeddings (from teacher)
        query_embeddings = student_output

        # Get positive embeddings from batch or compute
        if "positive_embeddings" in batch:
            positive_embeddings = batch["positive_embeddings"]
        else:
            # Use teacher embeddings as targets
            positive_embeddings = teacher_output

        # Get negative embeddings if available
        negative_embeddings = batch.get("negative_embeddings", None)

        return DistillationLosses.ranking_loss(
            query_embeddings,
            positive_embeddings,
            negative_embeddings,
            margin=self.margin,
            in_batch_negatives=self.in_batch_negatives,
        )
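Every strategy above reduces to a call of compute_loss(student_output, teacher_output, batch). The short sketch below is not part of the wheel; it exercises two of the strategies with random tensors standing in for real model outputs, and the batch size and embedding/logit dimensions are illustrative assumptions.

# Sketch only -- random tensors stand in for real student/teacher outputs.
import torch

from distil_trainer.distillation.strategies import (
    EmbeddingDistillationStrategy,
    LogitDistillationStrategy,
)

student_emb = torch.randn(8, 384)   # assumed batch of 8, embedding dim 384
teacher_emb = torch.randn(8, 384)

# Embedding matching with the "combined" option: 50/50 blend of MSE and cosine loss.
embedding_strategy = EmbeddingDistillationStrategy(loss_fn="combined")
emb_loss = embedding_strategy.compute_loss(student_emb, teacher_emb, batch={})

# Logit matching: KL divergence at temperature 2.0, blended with cross-entropy
# because the batch carries hard labels and alpha < 1.0.
student_logits = torch.randn(8, 10)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

logit_strategy = LogitDistillationStrategy(temperature=2.0, alpha=0.5)
logit_loss = logit_strategy.compute_loss(
    student_logits, teacher_logits, batch={"labels": labels}
)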
distil_trainer/evaluation/__init__.py
@@ -0,0 +1,19 @@
"""Evaluation module for distillation training."""

from distil_trainer.evaluation.evaluators import (
    EmbeddingSimilarityEvaluator,
    MSEEvaluator,
    TranslationEvaluator,
    SequentialEvaluator,
)
from distil_trainer.evaluation.metrics import DistillationMetrics
from distil_trainer.evaluation.benchmarks import BenchmarkRunner

__all__ = [
    "EmbeddingSimilarityEvaluator",
    "MSEEvaluator",
    "TranslationEvaluator",
    "SequentialEvaluator",
    "DistillationMetrics",
    "BenchmarkRunner",
]
distil_trainer/evaluation/benchmarks.py
@@ -0,0 +1,86 @@
"""Benchmark runners for model evaluation."""

from __future__ import annotations

import logging
from typing import Any

from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class BenchmarkRunner:
    """Run standard benchmarks for evaluation."""

    def __init__(self, model: SentenceTransformer | Any):
        """
        Initialize the benchmark runner.

        Args:
            model: Model to evaluate.
        """
        self.model = model

    def run_stsb(self, split: str = "test") -> dict[str, float]:
        """Run STS Benchmark evaluation."""
        from datasets import load_dataset
        from distil_trainer.evaluation.evaluators import EmbeddingSimilarityEvaluator

        dataset = load_dataset("sentence-transformers/stsb", split=split)

        evaluator = EmbeddingSimilarityEvaluator(
            sentences1=dataset["sentence1"],
            sentences2=dataset["sentence2"],
            scores=dataset["score"],
            name=f"stsb-{split}",
        )

        return evaluator(self.model)

    def run_mteb(
        self,
        tasks: list[str] | None = None,
        languages: list[str] | None = None,
    ) -> dict[str, float]:
        """
        Run MTEB benchmark suite.

        Requires mteb package: pip install mteb
        """
        try:
            from mteb import MTEB
        except ImportError:
            logger.warning("MTEB not installed. Install with: pip install mteb")
            return {}

        if tasks is None:
            tasks = ["STS12", "STS13", "STS14", "STS15", "STS16"]

        evaluation = MTEB(tasks=tasks)
        results = evaluation.run(self.model, output_folder=None)

        # Flatten results
        flat_results = {}
        for task_result in results:
            task_name = task_result.task_name
            for metric, value in task_result.scores.items():
                flat_results[f"{task_name}_{metric}"] = value

        return flat_results

    def run_all(self) -> dict[str, dict[str, float]]:
        """Run all available benchmarks."""
        results = {}

        try:
            results["stsb"] = self.run_stsb()
        except Exception as e:
            logger.warning(f"STSB evaluation failed: {e}")

        try:
            results["mteb"] = self.run_mteb()
        except Exception as e:
            logger.warning(f"MTEB evaluation failed: {e}")

        return results
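A short sketch of driving BenchmarkRunner follows; it is not part of the wheel. The model name is an arbitrary illustrative choice, and run_mteb only does real work when the optional mteb package is installed.

# Sketch only; "all-MiniLM-L6-v2" is an arbitrary example model.
from sentence_transformers import SentenceTransformer

from distil_trainer.evaluation.benchmarks import BenchmarkRunner

runner = BenchmarkRunner(SentenceTransformer("all-MiniLM-L6-v2"))

stsb_scores = runner.run_stsb(split="test")     # STS Benchmark via the evaluators module
mteb_scores = runner.run_mteb(tasks=["STS12"])  # returns {} if mteb is not installed
all_scores = runner.run_all()                   # both; failures are logged, not raised
print(all_scores)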