distil_trainer-0.1.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distil_trainer/__init__.py +31 -0
- distil_trainer/core/__init__.py +23 -0
- distil_trainer/core/callbacks.py +188 -0
- distil_trainer/core/config.py +358 -0
- distil_trainer/core/trainer.py +843 -0
- distil_trainer/data/__init__.py +19 -0
- distil_trainer/data/collators.py +240 -0
- distil_trainer/data/datamodule.py +191 -0
- distil_trainer/data/datasets.py +245 -0
- distil_trainer/data/loaders.py +163 -0
- distil_trainer/distillation/__init__.py +21 -0
- distil_trainer/distillation/losses.py +345 -0
- distil_trainer/distillation/multilingual.py +285 -0
- distil_trainer/distillation/strategies.py +211 -0
- distil_trainer/evaluation/__init__.py +19 -0
- distil_trainer/evaluation/benchmarks.py +86 -0
- distil_trainer/evaluation/evaluators.py +343 -0
- distil_trainer/evaluation/metrics.py +75 -0
- distil_trainer/models/__init__.py +5 -0
- distil_trainer/models/layers.py +115 -0
- distil_trainer/pruning/__init__.py +13 -0
- distil_trainer/pruning/combined_pruning.py +122 -0
- distil_trainer/pruning/depth_pruning.py +261 -0
- distil_trainer/pruning/importance.py +365 -0
- distil_trainer/pruning/width_pruning.py +480 -0
- distil_trainer-0.1.10.dist-info/METADATA +443 -0
- distil_trainer-0.1.10.dist-info/RECORD +29 -0
- distil_trainer-0.1.10.dist-info/WHEEL +4 -0
- distil_trainer-0.1.10.dist-info/licenses/LICENSE +21 -0
distil_trainer-0.1.10.dist-info/METADATA
@@ -0,0 +1,443 @@
Metadata-Version: 2.4
Name: distil-trainer
Version: 0.1.10
Summary: A comprehensive knowledge distillation training framework for transformer models
Project-URL: Homepage, https://github.com/malibayram/distil-trainer
Project-URL: Repository, https://github.com/malibayram/distil-trainer
Project-URL: Issues, https://github.com/malibayram/distil-trainer/issues
Author-email: Ali Bayram <malibayram20@gmail.com>
License: MIT
License-File: LICENSE
Keywords: deep-learning,distillation,knowledge-distillation,model-compression,nlp,pruning,pytorch,sentence-transformers,transformers
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Requires-Dist: accelerate>=0.24.0
Requires-Dist: datasets>=2.14.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: pandas>=2.0.0
Requires-Dist: protobuf
Requires-Dist: scikit-learn>=1.0.0
Requires-Dist: sentence-transformers>=2.2.0
Requires-Dist: torch>=2.0.0
Requires-Dist: tqdm>=4.65.0
Requires-Dist: transformers>=4.35.0
Provides-Extra: all
Requires-Dist: bitsandbytes>=0.41.0; extra == 'all'
Requires-Dist: mteb>=1.0.0; extra == 'all'
Requires-Dist: onnx>=1.14.0; extra == 'all'
Requires-Dist: onnxruntime>=1.16.0; extra == 'all'
Requires-Dist: peft>=0.6.0; extra == 'all'
Requires-Dist: tensorboard>=2.14.0; extra == 'all'
Requires-Dist: wandb>=0.15.0; extra == 'all'
Provides-Extra: dev
Requires-Dist: black>=23.0.0; extra == 'dev'
Requires-Dist: mypy>=1.0.0; extra == 'dev'
Requires-Dist: pre-commit>=3.0.0; extra == 'dev'
Requires-Dist: pytest-cov>=4.0.0; extra == 'dev'
Requires-Dist: pytest>=7.0.0; extra == 'dev'
Requires-Dist: ruff>=0.1.0; extra == 'dev'
Provides-Extra: evaluation
Requires-Dist: mteb>=1.0.0; extra == 'evaluation'
Provides-Extra: export
Requires-Dist: onnx>=1.14.0; extra == 'export'
Requires-Dist: onnxruntime>=1.16.0; extra == 'export'
Provides-Extra: peft
Requires-Dist: peft>=0.6.0; extra == 'peft'
Provides-Extra: quantization
Requires-Dist: bitsandbytes>=0.41.0; extra == 'quantization'
Provides-Extra: tracking
Requires-Dist: tensorboard>=2.14.0; extra == 'tracking'
Requires-Dist: wandb>=0.15.0; extra == 'tracking'
Description-Content-Type: text/markdown
# Distil Trainer

A comprehensive knowledge distillation training framework for transformer models.

[PyPI](https://badge.fury.io/py/distil-trainer)
[License: MIT](https://opensource.org/licenses/MIT)
[Python 3.10+](https://www.python.org/downloads/)

## Features

- **7 Distillation Strategies**:

  - Classical Embedding Distillation (MSE/Cosine loss)
  - Layer Reduction (Depth Pruning)
  - Width Pruning (Hidden size, Attention heads, MLP)
  - Combined Depth-Width Pruning
  - Multilingual Model Extension
  - LLM to Embedding Model Conversion
  - Reasoning/Chain-of-Thought Distillation

- **Flexible Architecture**:

  - Support for SentenceTransformers and HuggingFace models
  - Multiple loss functions (MSE, KL Divergence, Cosine, Ranking)
  - Configurable importance estimation for pruning
  - PCA projection for dimension reduction

- **Production Ready**:
  - Export to HuggingFace Hub
  - ONNX export support
  - Distributed training with Accelerate
  - Comprehensive evaluation framework

## Installation

```bash
pip install distil-trainer
```

With optional dependencies:

```bash
# For experiment tracking
pip install distil-trainer[tracking]

# For model export
pip install distil-trainer[export]

# For MTEB evaluation
pip install distil-trainer[evaluation]

# For all features
pip install distil-trainer[all]
```

## Quick Start

### Basic Embedding Distillation

Distill knowledge from a large teacher model to a smaller student model:

```python
from distil_trainer import DistilTrainer, DistilTrainerConfig, DistillationConfig

config = DistilTrainerConfig(
    teacher_model="sentence-transformers/all-mpnet-base-v2",
    student_model="sentence-transformers/paraphrase-TinyBERT-L6-v2",
    distillation_config=DistillationConfig(
        loss_type="mse",  # Options: mse, kl_divergence, cosine, ranking, combined
        use_pca_projection=True,  # When student dim < teacher dim
    ),
    output_dir="./distilled_model",
)

trainer = DistilTrainer(config)
trainer.load_data(train_data="sentence-transformers/all-nli")
trainer.train()
trainer.save_model("./final_model")
```
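
After training, a quick way to sanity-check the distilled student is to load the saved checkpoint and encode a few sentences. The sketch below assumes `save_model()` writes a standard SentenceTransformers checkpoint (the README does not state the on-disk format, so treat this as an assumption); `sentence-transformers` itself is already a runtime dependency of the package.

```python
from sentence_transformers import SentenceTransformer

# Assumption: "./final_model" is a SentenceTransformers-compatible checkpoint
# produced by trainer.save_model() above.
student = SentenceTransformer("./final_model")

embeddings = student.encode([
    "Knowledge distillation transfers behavior from a teacher to a student.",
    "Pruning removes low-importance weights.",
])
print(embeddings.shape)  # (2, student_embedding_dim)
```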

### Custom Dataset Columns

If your dataset uses different column names (e.g., "text" instead of "sentence"), you can specify this when loading data:

```python
# Load dataset with custom text column
trainer.load_data(
    train_data="alibayram/cosmos-corpus-00-5",
    text_column="text"
)
```

### Layer Reduction (Depth Pruning)

Reduce model depth by keeping only selected layers:

```python
from distil_trainer import DistilTrainer, DistilTrainerConfig
from distil_trainer.core.config import LayerReductionConfig

config = DistilTrainerConfig(
    teacher_model="mixedbread-ai/mxbai-embed-large-v1",
    student_init_strategy="layer_reduction",
    pruning_config=LayerReductionConfig(
        # Explicitly specify layers to keep (0-indexed)
        layers_to_keep=[0, 3, 6, 9, 12, 15, 18, 21],
        # Or use automatic selection
        # num_layers_to_keep=8,
        # layer_selection="importance",  # Options: first, last, even, importance, custom
    ),
    output_dir="./layer_reduced_model",
)

trainer = DistilTrainer(config)
trainer.train()
```

### Width Pruning

Reduce hidden dimensions, attention heads, and intermediate sizes:

```python
from distil_trainer import DistilTrainer, DistilTrainerConfig
from distil_trainer.core.config import WidthPruningConfig

config = DistilTrainerConfig(
    teacher_model="Qwen/Qwen3-8B",
    student_init_strategy="width_pruning",
    pruning_config=WidthPruningConfig(
        # Target absolute dimensions
        target_hidden_size=3072,
        target_intermediate_size=9216,
        target_num_attention_heads=24,
        target_num_key_value_heads=4,
        # Or use ratios (alternative to absolute values)
        # hidden_size_ratio=0.75,
        # intermediate_size_ratio=0.75,
        # Importance estimation method
        importance_method="activation",  # Options: activation, gradient, taylor, wanda, cosine_similarity
        calibration_samples=1024,
    ),
    output_dir="./width_pruned_model",
)

trainer = DistilTrainer(config)
trainer.train()
```

### Combined Depth-Width Pruning

Apply both depth and width pruning for maximum compression:

```python
from distil_trainer import DistilTrainer, DistilTrainerConfig
from distil_trainer.core.config import (
    CombinedPruningConfig,
    LayerReductionConfig,
    WidthPruningConfig,
)

config = DistilTrainerConfig(
    teacher_model="meta-llama/Llama-3.2-3B",
    student_init_strategy="combined_pruning",
    pruning_config=CombinedPruningConfig(
        depth_config=LayerReductionConfig(
            num_layers_to_keep=16,
            layer_selection="importance",
        ),
        width_config=WidthPruningConfig(
            hidden_size_ratio=0.75,
            intermediate_size_ratio=0.75,
        ),
        pruning_order="depth_first",  # Options: depth_first, width_first, interleaved
        num_iterations=1,
    ),
    output_dir="./compressed_model",
)

trainer = DistilTrainer(config)
trainer.train()
```

### Multilingual Model Extension

Extend a monolingual model to support multiple languages:

```python
from distil_trainer.core.config import MultilingualConfig
from distil_trainer.distillation import MultilingualDistillationStrategy

config = MultilingualConfig(
    # Teacher understands these languages
    source_languages=["en"],
    # Student should learn these languages
    target_languages=["de", "es", "fr", "it", "pt", "zh", "ja", "ko"],
    # Student model (multilingual encoder)
    student_model="xlm-roberta-base",
    student_max_seq_length=128,
    # Parallel sentence datasets for training
    parallel_datasets=[
        "sentence-transformers/parallel-sentences-talks",
        "sentence-transformers/parallel-sentences-tatoeba",
    ],
    max_sentences_per_language=500000,
    # Training settings
    num_train_epochs=5,
    evaluation_steps=5000,
)

strategy = MultilingualDistillationStrategy(
    teacher_model="paraphrase-distilroberta-base-v2",
    config=config,
)
strategy.train()
```

## Configuration Reference

### DistilTrainerConfig

Main configuration class for distillation training:

| Parameter               | Type               | Default             | Description                               |
| ----------------------- | ------------------ | ------------------- | ----------------------------------------- |
| `teacher_model`         | str/Model          | Required            | Teacher model name or instance            |
| `student_model`         | str/Model          | None                | Student model (None creates from teacher) |
| `student_init_strategy` | str                | "from_pretrained"   | How to initialize student                 |
| `pruning_config`        | PruningConfig      | None                | Pruning configuration                     |
| `distillation_config`   | DistillationConfig | Default             | Distillation loss settings                |
| `training_config`       | TrainingConfig     | Default             | Training hyperparameters                  |
| `output_dir`            | str                | "./distilled_model" | Output directory                          |
| `device`                | str                | "auto"              | Device to use                             |
| `precision`             | str                | "bf16"              | Training precision                        |

### DistillationConfig

Configuration for distillation losses:

| Parameter                       | Type  | Default | Description                                                   |
| ------------------------------- | ----- | ------- | ------------------------------------------------------------- |
| `loss_type`                     | str   | "mse"   | Loss function: mse, kl_divergence, cosine, ranking, combined  |
| `logit_loss_weight`             | float | 1.0     | Weight for logit distillation loss                            |
| `embedding_loss_weight`         | float | 1.0     | Weight for embedding distillation loss                        |
| `intermediate_loss_weight`      | float | 0.0     | Weight for intermediate layer loss                            |
| `temperature`                   | float | 1.0     | Temperature for KL divergence                                 |
| `use_pca_projection`            | bool  | True    | Use PCA when student dim < teacher dim                        |
| `precompute_teacher_embeddings` | bool  | True    | Cache teacher embeddings                                      |

### TrainingConfig

Training hyperparameters:

| Parameter                     | Type  | Default  | Description               |
| ----------------------------- | ----- | -------- | ------------------------- |
| `num_train_epochs`            | int   | 1        | Number of training epochs |
| `per_device_train_batch_size` | int   | 64       | Batch size per device     |
| `learning_rate`               | float | 1e-4     | Initial learning rate     |
| `lr_scheduler_type`           | str   | "cosine" | LR scheduler type         |
| `warmup_ratio`                | float | 0.1      | Warmup ratio              |
| `weight_decay`                | float | 0.01     | Weight decay              |
| `max_grad_norm`               | float | 1.0      | Gradient clipping         |
| `eval_strategy`               | str   | "steps"  | Evaluation strategy       |
| `eval_steps`                  | int   | 500      | Evaluation frequency      |
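
Putting the three config objects together, here is a minimal sketch of a fully specified setup that uses only parameters listed in the tables above. It assumes `TrainingConfig` is importable from `distil_trainer.core.config` alongside the other config classes (the README shows that module for the pruning configs but never imports `TrainingConfig` directly, so the path is an assumption).

```python
from distil_trainer import DistilTrainer, DistilTrainerConfig, DistillationConfig
from distil_trainer.core.config import TrainingConfig  # assumed import path

config = DistilTrainerConfig(
    teacher_model="sentence-transformers/all-mpnet-base-v2",
    student_model="sentence-transformers/paraphrase-TinyBERT-L6-v2",
    distillation_config=DistillationConfig(
        loss_type="kl_divergence",           # softened-distribution matching
        temperature=2.0,                     # temperature applies to KL divergence
        embedding_loss_weight=1.0,
        precompute_teacher_embeddings=True,  # cache teacher outputs
    ),
    training_config=TrainingConfig(
        num_train_epochs=3,
        per_device_train_batch_size=64,
        learning_rate=1e-4,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        eval_strategy="steps",
        eval_steps=500,
    ),
    output_dir="./distilled_model",
    precision="bf16",
)

trainer = DistilTrainer(config)
trainer.load_data(train_data="sentence-transformers/all-nli")
trainer.train()
```

Any field you omit falls back to the defaults listed in the tables above.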

## Evaluation

The framework includes comprehensive evaluation tools:

```python
from distil_trainer.evaluation import (
    EmbeddingSimilarityEvaluator,
    MSEEvaluator,
    BenchmarkRunner,
)

# Evaluate embedding similarity between teacher and student
evaluator = EmbeddingSimilarityEvaluator(
    teacher_model=teacher,
    student_model=student,
)
results = evaluator.evaluate(test_sentences)

# Run MTEB benchmarks (requires distil-trainer[evaluation])
runner = BenchmarkRunner(
    model="./distilled_model",
    tasks=["STS12", "STS13", "STS14", "STS15", "STS16"],
)
benchmark_results = runner.run()
```

## Data Loading

Support for multiple data formats:

```python
from distil_trainer.data import (
    DistillationDataModule,
    SentenceDistillationDataset,
    TripletDataset,
    ParallelSentencesDataset,
)

# Load from HuggingFace datasets
data_module = DistillationDataModule(
    train_data="sentence-transformers/all-nli",
    eval_data="sentence-transformers/stsb",
    text_column="sentence",
    max_seq_length=512,
    num_workers=4,
)

# Or use triplet format for ranking loss
triplet_dataset = TripletDataset(
    data_path="path/to/triplets.jsonl",
    query_column="query",
    positive_column="positive",
    negative_column="negative",
)
```

## Best Practices

Based on NVIDIA Minitron research:

1. **Sizing**: Train largest model first, then prune and distill iteratively
2. **Pruning**: Prefer width over depth pruning for better accuracy
3. **Retraining**: Use distillation loss exclusively (not conventional training)
4. **Loss Selection**:
   - Logit + intermediate + embedding when depth is reduced significantly
   - Logit-only when depth isn't reduced significantly

### Recommended Workflow

```python
# Step 1: Start with importance estimation
from distil_trainer.pruning import ImportanceEstimator

estimator = ImportanceEstimator(model, method="activation")
importance_scores = estimator.estimate(calibration_data)

# Step 2: Apply pruning based on importance
from distil_trainer.pruning import WidthPruner

pruner = WidthPruner(model, importance_scores)
pruned_model = pruner.prune(target_hidden_size=2048)

# Step 3: Distill knowledge from teacher
config = DistilTrainerConfig(
    teacher_model=original_model,
    student_model=pruned_model,
    distillation_config=DistillationConfig(
        loss_type="combined",
        logit_loss_weight=1.0,
        embedding_loss_weight=0.5,
    ),
)
trainer = DistilTrainer(config)
trainer.train()
```
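
After Step 2 it is worth confirming how much the pruning actually removed before spending compute on distillation. The check below uses plain PyTorch parameter counting only and assumes `model` and `pruned_model` from the workflow above are ordinary `torch.nn.Module` instances (consistent with the framework's PyTorch/Transformers stack, though the README does not state the pruner's return type).

```python
def count_parameters(module) -> int:
    """Total number of parameters in a torch.nn.Module."""
    return sum(p.numel() for p in module.parameters())

teacher_params = count_parameters(model)         # original model from Step 1
student_params = count_parameters(pruned_model)  # pruned model from Step 2

print(f"teacher: {teacher_params / 1e6:.1f}M parameters")
print(f"student: {student_params / 1e6:.1f}M parameters")
print(f"compression ratio: {teacher_params / student_params:.2f}x")
```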

## License

MIT License - see [LICENSE](LICENSE) for details.

## Citation

If you use this library in your research, please cite:

```bibtex
@software{distil_trainer,
  title = {Distil Trainer: A Comprehensive Knowledge Distillation Framework},
  author = {Ali Bayram},
  year = {2025},
  url = {https://github.com/malibayram/distil-trainer}
}
```

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

distil_trainer-0.1.10.dist-info/RECORD
@@ -0,0 +1,29 @@
distil_trainer/__init__.py,sha256=26Hc_zue5KAxVNRgw8EN7lEzyNc8BG7d9HOXoas5tAs,703
distil_trainer/core/__init__.py,sha256=fzZ-V1o8-3WLum8nmV53LE3zcHLcptHUEZql6tmNIps,517
distil_trainer/core/callbacks.py,sha256=VgG_t-Nb2JcRZ9MuqCCh6ruZW75CM2EWQVnLOBVpqrU,6077
distil_trainer/core/config.py,sha256=Fd-HbH0XYf6z5WYQAgEGfxov4KfpsmBj2WpBBxTZT8s,10150
distil_trainer/core/trainer.py,sha256=K2K3F_R3kGjRgBw-rvdBOVX5wC4nH1BFplAOu8BSwnc,31958
distil_trainer/data/__init__.py,sha256=Kk1lPtU_k7wHdgi5K3H8DqrJSV5xop1yImz8t6ax_S4,542
distil_trainer/data/collators.py,sha256=SI6dyhFbPNc-S0iTyA15PVCJ6nC3tPT9lEmK8jiTPVg,7638
distil_trainer/data/datamodule.py,sha256=YGUUTeaah0C3Wb-_VauhboqiXTHagkqacT2na_TJJWg,6527
distil_trainer/data/datasets.py,sha256=6dnnVmoU1Ryeuryj3v0VcFqQ-d2PwnRFGzVZYlkIhNA,7513
distil_trainer/data/loaders.py,sha256=aS7EnfNLi2UttOaOcd82JWuaruMLUqJ5-L532gsredc,4895
distil_trainer/distillation/__init__.py,sha256=CKFYFXenq0S0YQIKTFSmd7if0DoeQwnkA6OE2LqHjfw,555
distil_trainer/distillation/losses.py,sha256=9wAkd7CFwLkBN261HXB2Bze0HekwO67c9ahHaUWaMwc,13088
distil_trainer/distillation/multilingual.py,sha256=6EXMxOablCRvG1UTUWy81sF-rAmkqxJ5vectAoJCJ08,9586
distil_trainer/distillation/strategies.py,sha256=KOY-uDmMdVY7DzwVgFAu1PcfJ6ioKIkY3TsmVnOWKhw,6706
distil_trainer/evaluation/__init__.py,sha256=XOA8vuqL5NmHhfJwqYiYIA0B5ZAa3Uf77petvKLlLWU,515
distil_trainer/evaluation/benchmarks.py,sha256=nvnTly4fUJ8yBtR-dSto9WJwg0ghG1goT48rX0iYMsE,2424
distil_trainer/evaluation/evaluators.py,sha256=2irpyFVR5AZt7SP1K3IaJasF_ghISY0A9tRALoqtjeg,10477
distil_trainer/evaluation/metrics.py,sha256=U_zEKqIjH5RBIuL8ebreMbu391neZc7nx3LDQUM9fUc,2521
distil_trainer/models/__init__.py,sha256=TLF8emf_0hSFAqXobgi4t6NtVeAiLm90oyX4767KKSE,129
distil_trainer/models/layers.py,sha256=k-iH1B8Y3CWhUeO5ELx8628R4RTEtB9jIn3hWALbKak,3429
distil_trainer/pruning/__init__.py,sha256=nXB-5Dk-Dg8T1VxEWgLVeXD5aKO7NpU0I87iDZm4dT8,402
distil_trainer/pruning/combined_pruning.py,sha256=Bw1P9fhwHrS838cLZqpUFOjgNYabVzwvrnQJDOrVrX8,4407
distil_trainer/pruning/depth_pruning.py,sha256=LjqzqBhNdkExlRoEf7cAa1Y2qKDrIhXnC3g8VqmXtSc,10169
distil_trainer/pruning/importance.py,sha256=TVhg7CCoWEglZFNWJvI_oLZFe46-m67qPXIZO-9C9T8,12585
distil_trainer/pruning/width_pruning.py,sha256=L6oU1DJchq54JNtIBd82QlXDRKhCSwrEm-XtqWmvkZE,19117
distil_trainer-0.1.10.dist-info/METADATA,sha256=d-Jk1hKEKl7XRB3Om4O-VaJIaqn0CFq7_HWKJfjgxXw,15305
distil_trainer-0.1.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
distil_trainer-0.1.10.dist-info/licenses/LICENSE,sha256=bhHl_plGK0i19Fz-EJIr8MhSQpS3RqjEJX2VWCFRXLA,1067
distil_trainer-0.1.10.dist-info/RECORD,,

distil_trainer-0.1.10.dist-info/licenses/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Ali Bayram

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.