distil-trainer 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,343 @@
+ """Evaluators for distillation quality assessment."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Any, Callable
+
+ import numpy as np
+ import torch
+ from scipy.stats import pearsonr, spearmanr
+
+ from sentence_transformers import SentenceTransformer
+
+ logger = logging.getLogger(__name__)
+
+
+ class EmbeddingSimilarityEvaluator:
+     """
+     Evaluate sentence embedding quality on STS tasks.
+
+     Computes correlation between model similarity and human scores.
+
+     Example:
+         >>> evaluator = EmbeddingSimilarityEvaluator(
+         ...     sentences1=["Hello", "Hi"],
+         ...     sentences2=["World", "Earth"],
+         ...     scores=[0.5, 0.8],
+         ... )
+         >>> results = evaluator(model)
+     """
+
+     def __init__(
+         self,
+         sentences1: list[str],
+         sentences2: list[str],
+         scores: list[float],
+         batch_size: int = 32,
+         name: str = "sts-eval",
+         main_similarity: str = "cosine",
+         show_progress_bar: bool = True,
+     ):
+         """
+         Initialize the evaluator.
+
+         Args:
+             sentences1: First set of sentences.
+             sentences2: Second set of sentences.
+             scores: Ground truth similarity scores.
+             batch_size: Batch size for encoding.
+             name: Name for logging.
+             main_similarity: Similarity function to use ("cosine" or "euclidean").
+             show_progress_bar: Whether to show a progress bar.
+         """
+         self.sentences1 = sentences1
+         self.sentences2 = sentences2
+         self.scores = scores
+         self.batch_size = batch_size
+         self.name = name
+         self.main_similarity = main_similarity
+         self.show_progress_bar = show_progress_bar
+
+     def __call__(self, model: SentenceTransformer) -> dict[str, float]:
+         """
+         Evaluate the model.
+
+         Args:
+             model: Model to evaluate.
+
+         Returns:
+             Dictionary with correlation scores.
+         """
+         logger.info(f"Running {self.name} evaluation")
+
+         # Encode sentences
+         embeddings1 = model.encode(
+             self.sentences1,
+             batch_size=self.batch_size,
+             show_progress_bar=self.show_progress_bar,
+             convert_to_tensor=True,
+         )
+         embeddings2 = model.encode(
+             self.sentences2,
+             batch_size=self.batch_size,
+             show_progress_bar=self.show_progress_bar,
+             convert_to_tensor=True,
+         )
+
+         # Compute similarities
+         if self.main_similarity == "cosine":
+             similarities = torch.nn.functional.cosine_similarity(
+                 embeddings1, embeddings2
+             ).cpu().numpy()
+         else:
+             # Euclidean distance (convert to similarity)
+             distances = torch.pairwise_distance(embeddings1, embeddings2).cpu().numpy()
+             similarities = 1 / (1 + distances)
+
+         # Compute correlations
+         pearson = pearsonr(similarities, self.scores)[0]
+         spearman = spearmanr(similarities, self.scores)[0]
+
+         results = {
+             f"{self.name}_pearson_{self.main_similarity}": pearson,
+             f"{self.name}_spearman_{self.main_similarity}": spearman,
+         }
+
+         logger.info(f"{self.name}: Pearson={pearson:.4f}, Spearman={spearman:.4f}")
+
+         return results
+
+
+ class MSEEvaluator:
+     """
+     Evaluate MSE between student and teacher embeddings.
+
+     Example:
+         >>> evaluator = MSEEvaluator(
+         ...     source_sentences=["Hello", "World"],
+         ...     target_sentences=["Hello", "World"],
+         ...     teacher_model=teacher,
+         ... )
+         >>> results = evaluator(student_model)
+     """
+
+     def __init__(
+         self,
+         source_sentences: list[str],
+         target_sentences: list[str],
+         teacher_model: SentenceTransformer,
+         batch_size: int = 32,
+         name: str = "mse-eval",
+         show_progress_bar: bool = True,
+     ):
+         """
+         Initialize the evaluator.
+
+         Args:
+             source_sentences: Source sentences.
+             target_sentences: Target sentences (same as source for distillation).
+             teacher_model: Teacher model for comparison.
+             batch_size: Batch size for encoding.
+             name: Name for logging.
+             show_progress_bar: Whether to show a progress bar.
+         """
+         self.source_sentences = source_sentences
+         self.target_sentences = target_sentences
+         self.teacher_model = teacher_model
+         self.batch_size = batch_size
+         self.name = name
+         self.show_progress_bar = show_progress_bar
+
+         # Precompute teacher embeddings
+         with torch.no_grad():
+             self.teacher_embeddings = teacher_model.encode(
+                 source_sentences,
+                 batch_size=batch_size,
+                 show_progress_bar=show_progress_bar,
+                 convert_to_tensor=True,
+             )
+
+     def __call__(self, model: SentenceTransformer) -> dict[str, float]:
+         """
+         Evaluate the model.
+
+         Args:
+             model: Student model to evaluate.
+
+         Returns:
+             Dictionary with MSE score.
+         """
+         logger.info(f"Running {self.name} evaluation")
+
+         # Encode with student
+         student_embeddings = model.encode(
+             self.target_sentences,
+             batch_size=self.batch_size,
+             show_progress_bar=self.show_progress_bar,
+             convert_to_tensor=True,
+         )
+
+         # Compute MSE
+         mse = torch.nn.functional.mse_loss(
+             student_embeddings, self.teacher_embeddings
+         ).item()
+
+         # Compute cosine similarity
+         cosine_sim = torch.nn.functional.cosine_similarity(
+             student_embeddings, self.teacher_embeddings
+         ).mean().item()
+
+         results = {
+             f"{self.name}_mse": mse,
+             f"{self.name}_cosine_similarity": cosine_sim,
+         }
+
+         logger.info(f"{self.name}: MSE={mse:.6f}, Cosine={cosine_sim:.4f}")
+
+         return results
+
+
+ class TranslationEvaluator:
+     """
+     Evaluate multilingual alignment via translation retrieval.
+
+     Checks if source[i] embedding is closest to target[i].
+
+     Example:
+         >>> evaluator = TranslationEvaluator(
+         ...     source_sentences=["Hello", "World"],
+         ...     target_sentences=["Hallo", "Welt"],
+         ... )
+         >>> results = evaluator(model)
+     """
+
+     def __init__(
+         self,
+         source_sentences: list[str],
+         target_sentences: list[str],
+         batch_size: int = 32,
+         name: str = "translation-eval",
+         show_progress_bar: bool = True,
+     ):
+         """
+         Initialize the evaluator.
+
+         Args:
+             source_sentences: Source language sentences.
+             target_sentences: Target language sentences (parallel).
+             batch_size: Batch size for encoding.
+             name: Name for logging.
+             show_progress_bar: Whether to show a progress bar.
+         """
+         self.source_sentences = source_sentences
+         self.target_sentences = target_sentences
+         self.batch_size = batch_size
+         self.name = name
+         self.show_progress_bar = show_progress_bar
+
+     def __call__(self, model: SentenceTransformer) -> dict[str, float]:
+         """
+         Evaluate the model.
+
+         Args:
+             model: Model to evaluate.
+
+         Returns:
+             Dictionary with retrieval accuracy scores.
+         """
+         logger.info(f"Running {self.name} evaluation")
+
+         # Encode both sets
+         source_embeddings = model.encode(
+             self.source_sentences,
+             batch_size=self.batch_size,
+             show_progress_bar=self.show_progress_bar,
+             convert_to_tensor=True,
+         )
+         target_embeddings = model.encode(
+             self.target_sentences,
+             batch_size=self.batch_size,
+             show_progress_bar=self.show_progress_bar,
+             convert_to_tensor=True,
+         )
+
+         # Normalize for cosine similarity
+         source_embeddings = torch.nn.functional.normalize(source_embeddings, p=2, dim=1)
+         target_embeddings = torch.nn.functional.normalize(target_embeddings, p=2, dim=1)
+
+         # Compute similarity matrix
+         similarity_matrix = torch.mm(source_embeddings, target_embeddings.t())
+
+         # Get rankings (rank of the correct target for each source; 0 = retrieved first)
+         correct_indices = torch.arange(len(self.source_sentences), device=similarity_matrix.device)
+         sorted_indices = similarity_matrix.argsort(dim=1, descending=True)
+         rankings = (sorted_indices == correct_indices.unsqueeze(1)).nonzero()[:, 1]
+
+         # Compute metrics
+         acc_at_1 = (rankings == 0).float().mean().item()
+         acc_at_5 = (rankings < 5).float().mean().item()
+         acc_at_10 = (rankings < 10).float().mean().item()
+         mrr = (1.0 / (rankings.float() + 1)).mean().item()
+
+         results = {
+             f"{self.name}_accuracy@1": acc_at_1,
+             f"{self.name}_accuracy@5": acc_at_5,
+             f"{self.name}_accuracy@10": acc_at_10,
+             f"{self.name}_mrr": mrr,
+         }
+
+         logger.info(f"{self.name}: Acc@1={acc_at_1:.4f}, MRR={mrr:.4f}")
+
+         return results
+
+
+ class SequentialEvaluator:
+     """
+     Run multiple evaluators sequentially.
+
+     Example:
+         >>> evaluator = SequentialEvaluator([evaluator1, evaluator2])
+         >>> results = evaluator(model)
+     """
+
+     def __init__(
+         self,
+         evaluators: list[Any],
+         main_score_function: Callable[[dict], float] | None = None,
+     ):
+         """
+         Initialize the sequential evaluator.
+
+         Args:
+             evaluators: List of evaluators to run.
+             main_score_function: Function to compute main score from results.
+         """
+         self.evaluators = evaluators
+         self.main_score_function = main_score_function or (
+             lambda x: np.mean([v for v in x.values() if isinstance(v, (int, float))])
+         )
+
+     def __call__(self, model: SentenceTransformer) -> dict[str, float]:
+         """
+         Run all evaluators.
+
+         Args:
+             model: Model to evaluate.
+
+         Returns:
+             Combined dictionary of all evaluation results.
+         """
+         all_results = {}
+
+         for evaluator in self.evaluators:
+             try:
+                 results = evaluator(model)
+                 all_results.update(results)
+             except Exception as e:
+                 logger.warning(f"Evaluator {evaluator} failed: {e}")
+
+         # Compute main score
+         all_results["main_score"] = self.main_score_function(all_results)
+
+         return all_results
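
For context, a minimal sketch of how these evaluators compose. The model checkpoint and the toy STS pairs below are illustrative, not part of the package:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-MiniLM-L6-v2")  # any SentenceTransformer works

    sts = EmbeddingSimilarityEvaluator(
        sentences1=["A man is eating food.", "A plane takes off.", "Dogs are playing."],
        sentences2=["A man eats a meal.", "A bird is flying.", "Two dogs play in the park."],
        scores=[0.9, 0.2, 0.8],  # illustrative gold similarities in [0, 1]
    )
    evaluator = SequentialEvaluator([sts])
    results = evaluator(model)  # per-metric keys plus an aggregated "main_score"
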
@@ -0,0 +1,79 @@
+ """Metrics for evaluating distillation quality."""
+
+ from __future__ import annotations
+
+ import time
+
+ import torch
+ from transformers import PreTrainedModel
+
+ from sentence_transformers import SentenceTransformer
+
+
+ class DistillationMetrics:
+     """Metrics for evaluating distillation quality."""
+
+     @staticmethod
+     def embedding_mse(
+         student_embeddings: torch.Tensor,
+         teacher_embeddings: torch.Tensor,
+     ) -> float:
+         """Mean squared error between embeddings."""
+         return torch.nn.functional.mse_loss(student_embeddings, teacher_embeddings).item()
+
+     @staticmethod
+     def embedding_cosine_similarity(
+         student_embeddings: torch.Tensor,
+         teacher_embeddings: torch.Tensor,
+     ) -> float:
+         """Average cosine similarity between embeddings."""
+         return torch.nn.functional.cosine_similarity(
+             student_embeddings, teacher_embeddings, dim=-1
+         ).mean().item()
+
+     @staticmethod
+     def compression_ratio(
+         student_model: SentenceTransformer | PreTrainedModel,
+         teacher_model: SentenceTransformer | PreTrainedModel,
+     ) -> float:
+         """Parameter count ratio (teacher / student)."""
+         student_params = sum(p.numel() for p in student_model.parameters())
+         teacher_params = sum(p.numel() for p in teacher_model.parameters())
+         return teacher_params / student_params
+
+     @staticmethod
+     def speedup_factor(
+         student_model: SentenceTransformer | PreTrainedModel,
+         teacher_model: SentenceTransformer | PreTrainedModel,
+         input_batch: dict | list[str],
+         num_runs: int = 10,
+     ) -> float:
+         """Inference speedup of student over teacher."""
+         # Warm up each model once, matching its call convention, so lazy
+         # initialization does not skew the timings
+         for m in (student_model, teacher_model):
+             if isinstance(m, SentenceTransformer):
+                 m.encode(input_batch[:1])
+             else:
+                 m(**input_batch)
+
+         # Time student
+         start = time.perf_counter()
+         with torch.no_grad():
+             for _ in range(num_runs):
+                 if isinstance(student_model, SentenceTransformer):
+                     student_model.encode(input_batch)
+                 else:
+                     student_model(**input_batch)
+         student_time = time.perf_counter() - start
+
+         # Time teacher
+         start = time.perf_counter()
+         with torch.no_grad():
+             for _ in range(num_runs):
+                 if isinstance(teacher_model, SentenceTransformer):
+                     teacher_model.encode(input_batch)
+                 else:
+                     teacher_model(**input_batch)
+         teacher_time = time.perf_counter() - start
+
+         return teacher_time / student_time
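
A short usage sketch for these metrics; the two checkpoint names are illustrative. Note that embedding_mse and embedding_cosine_similarity assume the student and teacher embeddings share a dimension, which these two particular models do not (384 vs. 768), so only the shape-independent metrics are shown:

    from sentence_transformers import SentenceTransformer

    teacher = SentenceTransformer("all-mpnet-base-v2")  # illustrative teacher
    student = SentenceTransformer("all-MiniLM-L6-v2")   # illustrative student
    sentences = ["distillation shrinks models", "students imitate teachers"]

    ratio = DistillationMetrics.compression_ratio(student, teacher)    # > 1.0 here
    speedup = DistillationMetrics.speedup_factor(student, teacher, sentences)
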
@@ -0,0 +1,5 @@
+ """Models module for distil_trainer."""
+
+ from distil_trainer.models.layers import DenseProjection
+
+ __all__ = ["DenseProjection"]
@@ -0,0 +1,115 @@
+ """Custom layers for distillation models."""
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+
+ class DenseProjection(nn.Module):
+     """
+     Dense projection layer for dimension reduction.
+
+     Used to project teacher embeddings to student dimension via PCA weights.
+
+     Example:
+         >>> projection = DenseProjection(in_features=768, out_features=256)
+         >>> reduced = projection(embeddings)
+     """
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = False,
+         weights: torch.Tensor | None = None,
+     ):
+         """
+         Initialize the projection layer.
+
+         Args:
+             in_features: Input dimension (teacher).
+             out_features: Output dimension (student).
+             bias: Whether to include a bias term.
+             weights: Optional initial weights of shape (out_features, in_features), e.g. from PCA.
+         """
+         super().__init__()
+
+         self.in_features = in_features
+         self.out_features = out_features
+
+         self.linear = nn.Linear(in_features, out_features, bias=bias)
+
+         if weights is not None:
+             with torch.no_grad():
+                 self.linear.weight.copy_(weights)  # expects (out_features, in_features)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply projection."""
+         return self.linear(x)
+
+
+ class PoolingLayer(nn.Module):
+     """
+     Pooling layer for converting token embeddings to sentence embeddings.
+
+     Supports multiple pooling strategies.
+
+     Example:
+         >>> pooler = PoolingLayer(pooling_mode="mean")
+         >>> sentence_embedding = pooler(token_embeddings, attention_mask)
+     """
+
+     def __init__(
+         self,
+         pooling_mode: str = "mean",
+     ):
+         """
+         Initialize the pooling layer.
+
+         Args:
+             pooling_mode: One of "mean", "cls", "max", "weighted_mean".
+         """
+         super().__init__()
+         self.pooling_mode = pooling_mode
+
+     def forward(
+         self,
+         token_embeddings: torch.Tensor,
+         attention_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Apply pooling to token embeddings.
+
+         Args:
+             token_embeddings: Token embeddings [batch, seq_len, dim].
+             attention_mask: Attention mask [batch, seq_len].
+
+         Returns:
+             Pooled embeddings [batch, dim].
+         """
+         if self.pooling_mode == "cls":
+             return token_embeddings[:, 0]
+
+         elif self.pooling_mode == "max":
+             # Mask out padding tokens
+             token_embeddings = token_embeddings.masked_fill(
+                 ~attention_mask.unsqueeze(-1).bool(), float("-inf")
+             )
+             return token_embeddings.max(dim=1).values
+
+         elif self.pooling_mode == "weighted_mean":
+             # Position-weighted mean
+             weights = torch.arange(
+                 1, token_embeddings.size(1) + 1,
+                 device=token_embeddings.device,
+                 dtype=token_embeddings.dtype,
+             )
+             weights = weights.unsqueeze(0).unsqueeze(-1) * attention_mask.unsqueeze(-1)
+             return (token_embeddings * weights).sum(dim=1) / weights.sum(dim=1)
+
+         else:  # mean pooling
+             input_mask_expanded = attention_mask.unsqueeze(-1).float()
+             sum_embeddings = (token_embeddings * input_mask_expanded).sum(dim=1)
+             sum_mask = input_mask_expanded.sum(dim=1).clamp(min=1e-9)
+             return sum_embeddings / sum_mask
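
The PCA-initialized use of DenseProjection described in its docstring might look like the following sketch. scikit-learn is an assumption here (the package may derive the weights elsewhere), and the embeddings are random stand-in data:

    import numpy as np
    import torch
    from sklearn.decomposition import PCA

    teacher_embeddings = np.random.randn(1000, 768).astype("float32")
    pca = PCA(n_components=256).fit(teacher_embeddings)

    # nn.Linear stores its weight as (out_features, in_features), which matches
    # PCA's components_ shape (n_components, n_features).
    projection = DenseProjection(
        in_features=768,
        out_features=256,
        weights=torch.from_numpy(pca.components_),
    )
    reduced = projection(torch.from_numpy(teacher_embeddings))  # -> (1000, 256)

    # PoolingLayer works on raw token embeddings plus an attention mask:
    pooler = PoolingLayer(pooling_mode="mean")
    tokens = torch.randn(2, 8, 256)
    mask = torch.ones(2, 8, dtype=torch.long)
    sentence_embeddings = pooler(tokens, mask)  # -> (2, 256)
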
@@ -0,0 +1,13 @@
+ """Pruning module for model compression."""
+
+ from distil_trainer.pruning.depth_pruning import DepthPruner
+ from distil_trainer.pruning.width_pruning import WidthPruner
+ from distil_trainer.pruning.combined_pruning import CombinedPruner
+ from distil_trainer.pruning.importance import ImportanceEstimator
+
+ __all__ = [
+     "DepthPruner",
+     "WidthPruner",
+     "CombinedPruner",
+     "ImportanceEstimator",
+ ]
@@ -0,0 +1,121 @@
+ """Combined depth and width pruning."""
+
+ from __future__ import annotations
+
+ import logging
+
+ from transformers import PreTrainedModel
+
+ from sentence_transformers import SentenceTransformer
+
+ from distil_trainer.core.config import CombinedPruningConfig, LayerReductionConfig, WidthPruningConfig
+ from distil_trainer.pruning.depth_pruning import DepthPruner
+ from distil_trainer.pruning.width_pruning import WidthPruner
+
+ logger = logging.getLogger(__name__)
+
+
+ class CombinedPruner:
+     """
+     Combines depth and width pruning for maximum compression.
+
+     Example:
+         >>> config = CombinedPruningConfig(
+         ...     depth_config=LayerReductionConfig(num_layers_to_keep=8),
+         ...     width_config=WidthPruningConfig(target_hidden_size=3072),
+         ...     pruning_order="depth_first"
+         ... )
+         >>> pruner = CombinedPruner(model)
+         >>> pruned_model = pruner.prune(config)
+     """
+
+     def __init__(
+         self,
+         model: SentenceTransformer | PreTrainedModel,
+         calibration_data: list[str] | None = None,
+     ):
+         """
+         Initialize the CombinedPruner.
+
+         Args:
+             model: The model to prune.
+             calibration_data: Optional calibration data for importance estimation.
+         """
+         self.model = model
+         self.calibration_data = calibration_data
+
+     def prune(
+         self,
+         config: CombinedPruningConfig,
+         calibration_data: list[str] | None = None,
+     ) -> SentenceTransformer | PreTrainedModel:
+         """
+         Apply combined pruning based on configuration.
+
+         Args:
+             config: Combined pruning configuration.
+             calibration_data: Optional calibration data.
+
+         Returns:
+             New model with both depth and width reduced.
+         """
+         data = calibration_data or self.calibration_data
+         model = self.model
+
+         if config.pruning_order == "depth_first":
+             model = self._apply_depth_pruning(model, config.depth_config, data)
+             model = self._apply_width_pruning(model, config.width_config, data)
+         elif config.pruning_order == "width_first":
+             model = self._apply_width_pruning(model, config.width_config, data)
+             model = self._apply_depth_pruning(model, config.depth_config, data)
+         elif config.pruning_order == "interleaved":
+             # Apply in iterations, alternating
+             for i in range(config.num_iterations):
+                 logger.info(f"Pruning iteration {i + 1}/{config.num_iterations}")
+                 if i % 2 == 0:
+                     model = self._apply_depth_pruning(model, config.depth_config, data)
+                 else:
+                     model = self._apply_width_pruning(model, config.width_config, data)
+         else:
+             raise ValueError(f"Unknown pruning order: {config.pruning_order}")
+
+         # Log final statistics
+         original_params = sum(p.numel() for p in self.model.parameters())
+         final_params = sum(p.numel() for p in model.parameters())
+         reduction = 1 - (final_params / original_params)
+         logger.info(f"Total parameter reduction: {original_params:,} -> {final_params:,} ({reduction:.1%})")
+
+         return model
+
+     def _apply_depth_pruning(
+         self,
+         model: SentenceTransformer | PreTrainedModel,
+         config: LayerReductionConfig | None,
+         calibration_data: list[str] | None,
+     ) -> SentenceTransformer | PreTrainedModel:
+         """Apply depth pruning if configured."""
+         if config is None:
+             return model
+
+         logger.info("Applying depth pruning...")
+         pruner = DepthPruner(model, calibration_data)
+         return pruner.prune(
+             layers_to_keep=config.layers_to_keep,
+             num_layers_to_keep=config.num_layers_to_keep,
+             layers_to_drop=config.layers_to_drop,
+             layer_selection=config.layer_selection,
+         )
+
+     def _apply_width_pruning(
+         self,
+         model: SentenceTransformer | PreTrainedModel,
+         config: WidthPruningConfig | None,
+         calibration_data: list[str] | None,
+     ) -> SentenceTransformer | PreTrainedModel:
+         """Apply width pruning if configured."""
+         if config is None:
+             return model
+
+         logger.info("Applying width pruning...")
+         pruner = WidthPruner(model, calibration_data)
+         return pruner.prune(config, calibration_data)
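
Driving combined pruning end to end might look like this sketch. The checkpoint name and config values are illustrative, reusing the field names shown in the class docstring (any unset config fields are assumed to have defaults); width pruning is simply skipped when width_config is None:

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("all-mpnet-base-v2")  # illustrative 12-layer model

    config = CombinedPruningConfig(
        depth_config=LayerReductionConfig(num_layers_to_keep=6),
        width_config=None,  # depth-only run
        pruning_order="depth_first",
    )
    pruner = CombinedPruner(model, calibration_data=["a few calibration sentences"])
    pruned = pruner.prune(config)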