distil-trainer 0.1.10 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distil_trainer/__init__.py +31 -0
- distil_trainer/core/__init__.py +23 -0
- distil_trainer/core/callbacks.py +188 -0
- distil_trainer/core/config.py +358 -0
- distil_trainer/core/trainer.py +843 -0
- distil_trainer/data/__init__.py +19 -0
- distil_trainer/data/collators.py +240 -0
- distil_trainer/data/datamodule.py +191 -0
- distil_trainer/data/datasets.py +245 -0
- distil_trainer/data/loaders.py +163 -0
- distil_trainer/distillation/__init__.py +21 -0
- distil_trainer/distillation/losses.py +345 -0
- distil_trainer/distillation/multilingual.py +285 -0
- distil_trainer/distillation/strategies.py +211 -0
- distil_trainer/evaluation/__init__.py +19 -0
- distil_trainer/evaluation/benchmarks.py +86 -0
- distil_trainer/evaluation/evaluators.py +343 -0
- distil_trainer/evaluation/metrics.py +75 -0
- distil_trainer/models/__init__.py +5 -0
- distil_trainer/models/layers.py +115 -0
- distil_trainer/pruning/__init__.py +13 -0
- distil_trainer/pruning/combined_pruning.py +122 -0
- distil_trainer/pruning/depth_pruning.py +261 -0
- distil_trainer/pruning/importance.py +365 -0
- distil_trainer/pruning/width_pruning.py +480 -0
- distil_trainer-0.1.10.dist-info/METADATA +443 -0
- distil_trainer-0.1.10.dist-info/RECORD +29 -0
- distil_trainer-0.1.10.dist-info/WHEEL +4 -0
- distil_trainer-0.1.10.dist-info/licenses/LICENSE +21 -0

distil_trainer/evaluation/evaluators.py
@@ -0,0 +1,343 @@
"""Evaluators for distillation quality assessment."""

from __future__ import annotations

import logging
from typing import Any, Callable

import numpy as np
import torch
from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm

from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class EmbeddingSimilarityEvaluator:
    """
    Evaluate sentence embedding quality on STS tasks.

    Computes the correlation between model similarity scores and human judgments.

    Example:
        >>> evaluator = EmbeddingSimilarityEvaluator(
        ...     sentences1=["Hello", "Hi"],
        ...     sentences2=["World", "Earth"],
        ...     scores=[0.5, 0.8],
        ... )
        >>> results = evaluator(model)
    """

    def __init__(
        self,
        sentences1: list[str],
        sentences2: list[str],
        scores: list[float],
        batch_size: int = 32,
        name: str = "sts-eval",
        main_similarity: str = "cosine",
        show_progress_bar: bool = True,
    ):
        """
        Initialize the evaluator.

        Args:
            sentences1: First set of sentences.
            sentences2: Second set of sentences.
            scores: Ground-truth similarity scores.
            batch_size: Batch size for encoding.
            name: Name for logging.
            main_similarity: Similarity function to use ("cosine" or "euclidean").
            show_progress_bar: Whether to show a progress bar.
        """
        self.sentences1 = sentences1
        self.sentences2 = sentences2
        self.scores = scores
        self.batch_size = batch_size
        self.name = name
        self.main_similarity = main_similarity
        self.show_progress_bar = show_progress_bar

    def __call__(self, model: SentenceTransformer) -> dict[str, float]:
        """
        Evaluate the model.

        Args:
            model: Model to evaluate.

        Returns:
            Dictionary with correlation scores.
        """
        logger.info(f"Running {self.name} evaluation")

        # Encode sentences
        embeddings1 = model.encode(
            self.sentences1,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_tensor=True,
        )
        embeddings2 = model.encode(
            self.sentences2,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_tensor=True,
        )

        # Compute similarities
        if self.main_similarity == "cosine":
            similarities = torch.nn.functional.cosine_similarity(
                embeddings1, embeddings2
            ).cpu().numpy()
        else:
            # Euclidean distance (converted to a similarity)
            distances = torch.pairwise_distance(embeddings1, embeddings2).cpu().numpy()
            similarities = 1 / (1 + distances)

        # Compute correlations
        pearson = pearsonr(similarities, self.scores)[0]
        spearman = spearmanr(similarities, self.scores)[0]

        results = {
            f"{self.name}_pearson_{self.main_similarity}": pearson,
            f"{self.name}_spearman_{self.main_similarity}": spearman,
        }

        logger.info(f"{self.name}: Pearson={pearson:.4f}, Spearman={spearman:.4f}")

        return results


class MSEEvaluator:
    """
    Evaluate MSE between student and teacher embeddings.

    Example:
        >>> evaluator = MSEEvaluator(
        ...     source_sentences=["Hello", "World"],
        ...     target_sentences=["Hello", "World"],
        ...     teacher_model=teacher,
        ... )
        >>> results = evaluator(student_model)
    """

    def __init__(
        self,
        source_sentences: list[str],
        target_sentences: list[str],
        teacher_model: SentenceTransformer,
        batch_size: int = 32,
        name: str = "mse-eval",
        show_progress_bar: bool = True,
    ):
        """
        Initialize the evaluator.

        Args:
            source_sentences: Source sentences.
            target_sentences: Target sentences (same as source for distillation).
            teacher_model: Teacher model for comparison.
            batch_size: Batch size for encoding.
            name: Name for logging.
            show_progress_bar: Whether to show a progress bar.
        """
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.teacher_model = teacher_model
        self.batch_size = batch_size
        self.name = name
        self.show_progress_bar = show_progress_bar

        # Precompute teacher embeddings once; they are reused on every call
        with torch.no_grad():
            self.teacher_embeddings = teacher_model.encode(
                source_sentences,
                batch_size=batch_size,
                show_progress_bar=show_progress_bar,
                convert_to_tensor=True,
            )

    def __call__(self, model: SentenceTransformer) -> dict[str, float]:
        """
        Evaluate the model.

        Args:
            model: Student model to evaluate.

        Returns:
            Dictionary with MSE and cosine-similarity scores.
        """
        logger.info(f"Running {self.name} evaluation")

        # Encode with the student
        student_embeddings = model.encode(
            self.target_sentences,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_tensor=True,
        )

        # Move the cached teacher embeddings to the student's device so the
        # comparison also works when the two models live on different devices
        teacher_embeddings = self.teacher_embeddings.to(student_embeddings.device)

        # Compute MSE
        mse = torch.nn.functional.mse_loss(
            student_embeddings, teacher_embeddings
        ).item()

        # Compute average cosine similarity
        cosine_sim = torch.nn.functional.cosine_similarity(
            student_embeddings, teacher_embeddings
        ).mean().item()

        results = {
            f"{self.name}_mse": mse,
            f"{self.name}_cosine_similarity": cosine_sim,
        }

        logger.info(f"{self.name}: MSE={mse:.6f}, Cosine={cosine_sim:.4f}")

        return results


class TranslationEvaluator:
    """
    Evaluate multilingual alignment via translation retrieval.

    Checks if source[i] embedding is closest to target[i].

    Example:
        >>> evaluator = TranslationEvaluator(
        ...     source_sentences=["Hello", "World"],
        ...     target_sentences=["Hallo", "Welt"],
        ... )
        >>> results = evaluator(model)
    """

    def __init__(
        self,
        source_sentences: list[str],
        target_sentences: list[str],
        batch_size: int = 32,
        name: str = "translation-eval",
        show_progress_bar: bool = True,
    ):
        """
        Initialize the evaluator.

        Args:
            source_sentences: Source language sentences.
            target_sentences: Target language sentences (parallel).
            batch_size: Batch size for encoding.
            name: Name for logging.
            show_progress_bar: Whether to show a progress bar.
        """
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.batch_size = batch_size
        self.name = name
        self.show_progress_bar = show_progress_bar

    def __call__(self, model: SentenceTransformer) -> dict[str, float]:
        """
        Evaluate the model.

        Args:
            model: Model to evaluate.

        Returns:
            Dictionary with retrieval accuracy scores.
        """
        logger.info(f"Running {self.name} evaluation")

        # Encode both sets
        source_embeddings = model.encode(
            self.source_sentences,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_tensor=True,
        )
        target_embeddings = model.encode(
            self.target_sentences,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_tensor=True,
        )

        # Normalize for cosine similarity
        source_embeddings = torch.nn.functional.normalize(source_embeddings, p=2, dim=1)
        target_embeddings = torch.nn.functional.normalize(target_embeddings, p=2, dim=1)

        # Compute similarity matrix
        similarity_matrix = torch.mm(source_embeddings, target_embeddings.t())

        # Get the rank of the correct target for each source sentence
        correct_indices = torch.arange(
            len(self.source_sentences), device=similarity_matrix.device
        )
        rankings = (
            similarity_matrix.argsort(dim=1, descending=True)
            == correct_indices.unsqueeze(1)
        ).nonzero()[:, 1]

        # Compute metrics
        acc_at_1 = (rankings == 0).float().mean().item()
        acc_at_5 = (rankings < 5).float().mean().item()
        acc_at_10 = (rankings < 10).float().mean().item()
        mrr = (1.0 / (rankings.float() + 1)).mean().item()

        results = {
            f"{self.name}_accuracy@1": acc_at_1,
            f"{self.name}_accuracy@5": acc_at_5,
            f"{self.name}_accuracy@10": acc_at_10,
            f"{self.name}_mrr": mrr,
        }

        logger.info(f"{self.name}: Acc@1={acc_at_1:.4f}, MRR={mrr:.4f}")

        return results


class SequentialEvaluator:
    """
    Run multiple evaluators sequentially.

    Example:
        >>> evaluator = SequentialEvaluator([evaluator1, evaluator2])
        >>> results = evaluator(model)
    """

    def __init__(
        self,
        evaluators: list[Any],
        main_score_function: Callable[[dict], float] | None = None,
    ):
        """
        Initialize the sequential evaluator.

        Args:
            evaluators: List of evaluators to run.
            main_score_function: Function to compute main score from results.
        """
        self.evaluators = evaluators
        self.main_score_function = main_score_function or (
            lambda x: np.mean([v for v in x.values() if isinstance(v, (int, float))])
        )

    def __call__(self, model: SentenceTransformer) -> dict[str, float]:
        """
        Run all evaluators.

        Args:
            model: Model to evaluate.

        Returns:
            Combined dictionary of all evaluation results.
        """
        all_results = {}

        for evaluator in self.evaluators:
            try:
                results = evaluator(model)
                all_results.update(results)
            except Exception as e:
                logger.warning(f"Evaluator {evaluator} failed: {e}")

        # Compute main score
        all_results["main_score"] = self.main_score_function(all_results)

        return all_results
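Taken together, these classes give correlation, regression, and retrieval views of distillation quality. Below is a minimal sketch of wiring them into one SequentialEvaluator; the teacher checkpoint is a placeholder, the student path is hypothetical (it assumes a distilled model whose embedding dimension matches the teacher's, as MSEEvaluator requires), and the sentence pairs and scores are made up:

from sentence_transformers import SentenceTransformer

from distil_trainer.evaluation.evaluators import (
    EmbeddingSimilarityEvaluator,
    MSEEvaluator,
    SequentialEvaluator,
    TranslationEvaluator,
)

teacher = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")  # placeholder
student = SentenceTransformer("path/to/distilled-student")  # hypothetical distilled checkpoint

sts = EmbeddingSimilarityEvaluator(
    sentences1=["A man is eating food.", "A plane is taking off.", "A dog runs."],
    sentences2=["Someone is eating.", "An airplane departs.", "A cat sleeps."],
    scores=[0.9, 0.85, 0.1],  # made-up human similarity judgments
)
mse = MSEEvaluator(
    source_sentences=["A man is eating food.", "A dog runs."],
    target_sentences=["A man is eating food.", "A dog runs."],
    teacher_model=teacher,
)
translation = TranslationEvaluator(
    source_sentences=["Hello", "Thank you", "Good morning"],
    target_sentences=["Hallo", "Danke", "Guten Morgen"],
)

evaluator = SequentialEvaluator([sts, mse, translation])
results = evaluator(student)
print(f"main_score: {results['main_score']:.4f}")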

distil_trainer/evaluation/metrics.py
@@ -0,0 +1,75 @@
"""Metrics for evaluating distillation quality."""

from __future__ import annotations

import torch
from transformers import PreTrainedModel

from sentence_transformers import SentenceTransformer


class DistillationMetrics:
    """Metrics for evaluating distillation quality."""

    @staticmethod
    def embedding_mse(
        student_embeddings: torch.Tensor,
        teacher_embeddings: torch.Tensor,
    ) -> float:
        """Mean squared error between embeddings."""
        return torch.nn.functional.mse_loss(student_embeddings, teacher_embeddings).item()

    @staticmethod
    def embedding_cosine_similarity(
        student_embeddings: torch.Tensor,
        teacher_embeddings: torch.Tensor,
    ) -> float:
        """Average cosine similarity between embeddings."""
        return torch.nn.functional.cosine_similarity(
            student_embeddings, teacher_embeddings, dim=-1
        ).mean().item()

    @staticmethod
    def compression_ratio(
        student_model: SentenceTransformer | PreTrainedModel,
        teacher_model: SentenceTransformer | PreTrainedModel,
    ) -> float:
        """Parameter count ratio (teacher / student)."""
        student_params = sum(p.numel() for p in student_model.parameters())
        teacher_params = sum(p.numel() for p in teacher_model.parameters())
        return teacher_params / student_params

    @staticmethod
    def speedup_factor(
        student_model: SentenceTransformer | PreTrainedModel,
        teacher_model: SentenceTransformer | PreTrainedModel,
        input_batch: dict | list[str],
        num_runs: int = 10,
    ) -> float:
        """Inference speedup of student over teacher."""
        import time

        def _forward(model: SentenceTransformer | PreTrainedModel) -> None:
            if isinstance(model, SentenceTransformer):
                model.encode(input_batch)
            else:
                with torch.no_grad():
                    model(**input_batch)

        # Warm up both models so one-time setup cost is excluded from timing
        _forward(student_model)
        _forward(teacher_model)

        # Time student
        start = time.time()
        for _ in range(num_runs):
            _forward(student_model)
        student_time = time.time() - start

        # Time teacher
        start = time.time()
        for _ in range(num_runs):
            _forward(teacher_model)
        teacher_time = time.time() - start

        return teacher_time / student_time
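For the tensor-level helpers above, here is a self-contained sanity check on random data; the tensors are stand-ins for real student/teacher encodings:

import torch

from distil_trainer.evaluation.metrics import DistillationMetrics

# Stand-in embeddings: 8 sentences in a 384-dim space; the "student"
# output is a slightly perturbed copy of the "teacher" output
teacher_emb = torch.randn(8, 384)
student_emb = teacher_emb + 0.05 * torch.randn(8, 384)

print(DistillationMetrics.embedding_mse(student_emb, teacher_emb))                # close to 0
print(DistillationMetrics.embedding_cosine_similarity(student_emb, teacher_emb))  # close to 1

compression_ratio and speedup_factor take the two models directly. Note that speedup_factor measures wall-clock time, so on GPU the result is only indicative unless the device is synchronized before and after timing.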

distil_trainer/models/layers.py
@@ -0,0 +1,115 @@
"""Custom layers for distillation models."""

from __future__ import annotations

import torch
import torch.nn as nn


class DenseProjection(nn.Module):
    """
    Dense projection layer for dimension reduction.

    Used to project teacher embeddings to student dimension via PCA weights.

    Example:
        >>> projection = DenseProjection(in_features=768, out_features=256)
        >>> reduced = projection(embeddings)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        weights: torch.Tensor | None = None,
    ):
        """
        Initialize the projection layer.

        Args:
            in_features: Input dimension (teacher).
            out_features: Output dimension (student).
            bias: Whether to include a bias term.
            weights: Optional initial weights (e.g., from PCA).
        """
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.linear = nn.Linear(in_features, out_features, bias=bias)

        if weights is not None:
            with torch.no_grad():
                self.linear.weight.copy_(weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply projection."""
        return self.linear(x)


class PoolingLayer(nn.Module):
    """
    Pooling layer for converting token embeddings to sentence embeddings.

    Supports multiple pooling strategies.

    Example:
        >>> pooler = PoolingLayer(pooling_mode="mean")
        >>> sentence_embedding = pooler(token_embeddings, attention_mask)
    """

    def __init__(
        self,
        pooling_mode: str = "mean",
    ):
        """
        Initialize the pooling layer.

        Args:
            pooling_mode: One of "mean", "cls", "max", "weighted_mean".
        """
        super().__init__()
        self.pooling_mode = pooling_mode

    def forward(
        self,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply pooling to token embeddings.

        Args:
            token_embeddings: Token embeddings [batch, seq_len, dim].
            attention_mask: Attention mask [batch, seq_len].

        Returns:
            Pooled embeddings [batch, dim].
        """
        if self.pooling_mode == "cls":
            return token_embeddings[:, 0]

        elif self.pooling_mode == "max":
            # Mask out padding tokens
            token_embeddings = token_embeddings.masked_fill(
                ~attention_mask.unsqueeze(-1).bool(), float("-inf")
            )
            return token_embeddings.max(dim=1).values

        elif self.pooling_mode == "weighted_mean":
            # Position-weighted mean
            weights = torch.arange(
                1, token_embeddings.size(1) + 1,
                device=token_embeddings.device,
                dtype=token_embeddings.dtype,
            )
            weights = weights.unsqueeze(0).unsqueeze(-1) * attention_mask.unsqueeze(-1)
            return (token_embeddings * weights).sum(dim=1) / weights.sum(dim=1)

        else:  # mean pooling
            input_mask_expanded = attention_mask.unsqueeze(-1).float()
            sum_embeddings = (token_embeddings * input_mask_expanded).sum(dim=1)
            sum_mask = input_mask_expanded.sum(dim=1).clamp(min=1e-9)
            return sum_embeddings / sum_mask
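A quick shape check for both layers, with random tensors standing in for a real encoder's output:

import torch

from distil_trainer.models.layers import DenseProjection, PoolingLayer

token_embeddings = torch.randn(2, 8, 768)            # [batch, seq_len, dim]
attention_mask = torch.ones(2, 8, dtype=torch.long)
attention_mask[1, 5:] = 0                            # second sequence has 3 padding tokens

pooler = PoolingLayer(pooling_mode="mean")
sentence_embeddings = pooler(token_embeddings, attention_mask)   # -> [2, 768]

projection = DenseProjection(in_features=768, out_features=256)
reduced = projection(sentence_embeddings)                        # -> [2, 256]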

distil_trainer/pruning/__init__.py
@@ -0,0 +1,13 @@
"""Pruning module for model compression."""

from distil_trainer.pruning.depth_pruning import DepthPruner
from distil_trainer.pruning.width_pruning import WidthPruner
from distil_trainer.pruning.combined_pruning import CombinedPruner
from distil_trainer.pruning.importance import ImportanceEstimator

__all__ = [
    "DepthPruner",
    "WidthPruner",
    "CombinedPruner",
    "ImportanceEstimator",
]

distil_trainer/pruning/combined_pruning.py
@@ -0,0 +1,122 @@
"""Combined depth and width pruning."""

from __future__ import annotations

import logging
from typing import Literal

from transformers import PreTrainedModel

from sentence_transformers import SentenceTransformer

from distil_trainer.core.config import CombinedPruningConfig, LayerReductionConfig, WidthPruningConfig
from distil_trainer.pruning.depth_pruning import DepthPruner
from distil_trainer.pruning.width_pruning import WidthPruner

logger = logging.getLogger(__name__)


class CombinedPruner:
    """
    Combines depth and width pruning for maximum compression.

    Example:
        >>> config = CombinedPruningConfig(
        ...     depth_config=LayerReductionConfig(num_layers_to_keep=8),
        ...     width_config=WidthPruningConfig(target_hidden_size=3072),
        ...     pruning_order="depth_first",
        ... )
        >>> pruner = CombinedPruner(model)
        >>> pruned_model = pruner.prune(config)
    """

    def __init__(
        self,
        model: SentenceTransformer | PreTrainedModel,
        calibration_data: list[str] | None = None,
    ):
        """
        Initialize the CombinedPruner.

        Args:
            model: The model to prune.
            calibration_data: Optional calibration data for importance estimation.
        """
        self.model = model
        self.calibration_data = calibration_data

    def prune(
        self,
        config: CombinedPruningConfig,
        calibration_data: list[str] | None = None,
    ) -> SentenceTransformer | PreTrainedModel:
        """
        Apply combined pruning based on configuration.

        Args:
            config: Combined pruning configuration.
            calibration_data: Optional calibration data.

        Returns:
            New model with both depth and width reduced.
        """
        data = calibration_data or self.calibration_data
        model = self.model

        if config.pruning_order == "depth_first":
            model = self._apply_depth_pruning(model, config.depth_config, data)
            model = self._apply_width_pruning(model, config.width_config, data)
        elif config.pruning_order == "width_first":
            model = self._apply_width_pruning(model, config.width_config, data)
            model = self._apply_depth_pruning(model, config.depth_config, data)
        elif config.pruning_order == "interleaved":
            # Apply in iterations, alternating
            for i in range(config.num_iterations):
                logger.info(f"Pruning iteration {i + 1}/{config.num_iterations}")
                if i % 2 == 0:
                    model = self._apply_depth_pruning(model, config.depth_config, data)
                else:
                    model = self._apply_width_pruning(model, config.width_config, data)
        else:
            raise ValueError(f"Unknown pruning order: {config.pruning_order}")

        # Log final statistics
        original_params = sum(p.numel() for p in self.model.parameters())
        final_params = sum(p.numel() for p in model.parameters())
        reduction = 1 - (final_params / original_params)
        logger.info(f"Total parameter reduction: {original_params:,} -> {final_params:,} ({reduction:.1%})")

        return model

    def _apply_depth_pruning(
        self,
        model: SentenceTransformer | PreTrainedModel,
        config: LayerReductionConfig | None,
        calibration_data: list[str] | None,
    ) -> SentenceTransformer | PreTrainedModel:
        """Apply depth pruning if configured."""
        if config is None:
            return model

        logger.info("Applying depth pruning...")
        pruner = DepthPruner(model, calibration_data)
        return pruner.prune(
            layers_to_keep=config.layers_to_keep,
            num_layers_to_keep=config.num_layers_to_keep,
            layers_to_drop=config.layers_to_drop,
            layer_selection=config.layer_selection,
        )

    def _apply_width_pruning(
        self,
        model: SentenceTransformer | PreTrainedModel,
        config: WidthPruningConfig | None,
        calibration_data: list[str] | None,
    ) -> SentenceTransformer | PreTrainedModel:
        """Apply width pruning if configured."""
        if config is None:
            return model

        logger.info("Applying width pruning...")
        pruner = WidthPruner(model, calibration_data)
        return pruner.prune(config, calibration_data)
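Usage follows the class docstring. The sketch below assumes the config classes accept the fields referenced in this file (depth_config, width_config, pruning_order, num_iterations, and the LayerReductionConfig/WidthPruningConfig fields read by the two helpers); their full signatures live in distil_trainer/core/config.py, which this diff section does not show. The checkpoint name is a placeholder:

from sentence_transformers import SentenceTransformer

from distil_trainer.core.config import (
    CombinedPruningConfig,
    LayerReductionConfig,
    WidthPruningConfig,
)
from distil_trainer.pruning import CombinedPruner

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")  # placeholder

# Calibration sentences drive the importance estimates used to pick layers and widths
calibration = [
    "A handful of sentences representative of the target domain.",
    "They are passed through the model to score layer and width importance.",
]

config = CombinedPruningConfig(
    depth_config=LayerReductionConfig(num_layers_to_keep=8),
    width_config=WidthPruningConfig(target_hidden_size=3072),
    pruning_order="depth_first",
)

pruner = CombinedPruner(model, calibration_data=calibration)
pruned_model = pruner.prune(config)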