distil-trainer 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,31 @@
+ """
+ Distil Trainer - A comprehensive knowledge distillation training framework.
+
+ This package provides tools for:
+ - Classical embedding distillation
+ - Model pruning (depth and width)
+ - Multilingual model extension
+ - LLM to embedding model conversion
+ - Reasoning/Chain-of-thought distillation
+ """
+
+ from distil_trainer.core.config import (
+     DistilTrainerConfig,
+     DistillationConfig,
+     TrainingConfig,
+ )
+ from distil_trainer.core.trainer import DistilTrainer
+
+ __version__ = "0.1.10"
+ __author__ = "Ali Bayram"
+ __email__ = "malibayram@gmail.com"
+
+ __all__ = [
+     # Core
+     "DistilTrainer",
+     "DistilTrainerConfig",
+     "DistillationConfig",
+     "TrainingConfig",
+     # Version
+     "__version__",
+ ]
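
For orientation, a minimal usage sketch of the public API exported above. The config fields follow the dataclasses defined in distil_trainer.core.config later in this diff; the DistilTrainer constructor and the train() call are assumptions, since trainer.py is not part of this diff, and the model identifiers are illustrative placeholders.

    from distil_trainer import DistilTrainer, DistilTrainerConfig, TrainingConfig

    # Field names follow the dataclasses in distil_trainer.core.config;
    # the model identifiers below are illustrative placeholders.
    config = DistilTrainerConfig(
        teacher_model="sentence-transformers/all-mpnet-base-v2",
        student_model="distilbert-base-uncased",
        training_config=TrainingConfig(num_train_epochs=3, learning_rate=5e-5),
        output_dir="./distilled_model",
    )

    # Hypothetical entry point: DistilTrainer's signature is not shown in this diff.
    trainer = DistilTrainer(config)
    trainer.train()
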
@@ -0,0 +1,23 @@
+ """Core module for distillation training."""
+
+ from distil_trainer.core.config import (
+     DistilTrainerConfig,
+     DistillationConfig,
+     TrainingConfig,
+ )
+ from distil_trainer.core.trainer import DistilTrainer
+ from distil_trainer.core.callbacks import (
+     DistillationCallback,
+     EvaluationCallback,
+     LoggingCallback,
+ )
+
+ __all__ = [
+     "DistilTrainer",
+     "DistilTrainerConfig",
+     "DistillationConfig",
+     "TrainingConfig",
+     "DistillationCallback",
+     "EvaluationCallback",
+     "LoggingCallback",
+ ]
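
CheckpointCallback and EarlyStoppingCallback are defined in distil_trainer.core.callbacks (next hunk) but are not re-exported here, so they have to be imported from the submodule directly. A small sketch follows; how a callback list is actually handed to the trainer is not shown in this diff, so the list below is illustrative only.

    from distil_trainer.core.callbacks import CheckpointCallback, EarlyStoppingCallback

    # Constructor arguments match the __init__ signatures defined in callbacks.py.
    callbacks = [
        CheckpointCallback(save_every_n_steps=1000, save_total_limit=2),
        EarlyStoppingCallback(metric_name="eval_loss", patience=3, mode="min"),
    ]
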
@@ -0,0 +1,188 @@
+ """Training callbacks for distillation."""
+
+ from __future__ import annotations
+
+ import logging
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+ logger = logging.getLogger(__name__)
+
+
+ class DistillationCallback(ABC):
+     """Base class for training callbacks."""
+
+     @abstractmethod
+     def on_train_begin(self, trainer: Any) -> None:
+         """Called at the beginning of training."""
+         pass
+
+     @abstractmethod
+     def on_train_end(self, trainer: Any) -> None:
+         """Called at the end of training."""
+         pass
+
+     @abstractmethod
+     def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+         """Called at the beginning of each epoch."""
+         pass
+
+     @abstractmethod
+     def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+         """Called at the end of each epoch."""
+         pass
+
+     @abstractmethod
+     def on_step_begin(self, trainer: Any, step: int) -> None:
+         """Called at the beginning of each training step."""
+         pass
+
+     @abstractmethod
+     def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+         """Called at the end of each training step."""
+         pass
+
+
+ class LoggingCallback(DistillationCallback):
+     """Callback for logging training progress."""
+
+     def __init__(self, log_every_n_steps: int = 100):
+         self.log_every_n_steps = log_every_n_steps
+
+     def on_train_begin(self, trainer: Any) -> None:
+         logger.info("Training started")
+
+     def on_train_end(self, trainer: Any) -> None:
+         logger.info("Training completed")
+
+     def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+         logger.info(f"Epoch {epoch + 1} started")
+
+     def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+         logger.info(f"Epoch {epoch + 1} completed: {metrics}")
+
+     def on_step_begin(self, trainer: Any, step: int) -> None:
+         pass
+
+     def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+         if step % self.log_every_n_steps == 0:
+             logger.info(f"Step {step}: loss = {loss:.4f}")
+
+
+ class EvaluationCallback(DistillationCallback):
+     """Callback for running evaluation during training."""
+
+     def __init__(self, eval_every_n_steps: int = 500):
+         self.eval_every_n_steps = eval_every_n_steps
+
+     def on_train_begin(self, trainer: Any) -> None:
+         pass
+
+     def on_train_end(self, trainer: Any) -> None:
+         pass
+
+     def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+         pass
+
+     def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+         eval_metrics = trainer.evaluate()
+         logger.info(f"End of epoch {epoch + 1} evaluation: {eval_metrics}")
+
+     def on_step_begin(self, trainer: Any, step: int) -> None:
+         pass
+
+     def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+         if step % self.eval_every_n_steps == 0:
+             eval_metrics = trainer.evaluate()
+             logger.info(f"Step {step} evaluation: {eval_metrics}")
+
+
+ class CheckpointCallback(DistillationCallback):
+     """Callback for saving checkpoints during training."""
+
+     def __init__(self, save_every_n_steps: int = 500, save_total_limit: int = 3):
+         self.save_every_n_steps = save_every_n_steps
+         self.save_total_limit = save_total_limit
+         self.saved_checkpoints: list[str] = []
+
+     def on_train_begin(self, trainer: Any) -> None:
+         pass
+
+     def on_train_end(self, trainer: Any) -> None:
+         trainer.save_model()
+
+     def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+         pass
+
+     def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+         pass
+
+     def on_step_begin(self, trainer: Any, step: int) -> None:
+         pass
+
+     def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+         if step % self.save_every_n_steps == 0:
+             checkpoint_name = f"checkpoint-{step}"
+             trainer._save_checkpoint(checkpoint_name)
+             self.saved_checkpoints.append(checkpoint_name)
+
+             # Remove old checkpoints
+             while len(self.saved_checkpoints) > self.save_total_limit:
+                 old_checkpoint = self.saved_checkpoints.pop(0)
+                 # Could add logic to delete the old checkpoint directory
+
+
+ class EarlyStoppingCallback(DistillationCallback):
+     """Callback for early stopping based on evaluation metrics."""
+
+     def __init__(
+         self,
+         metric_name: str = "eval_loss",
+         patience: int = 3,
+         min_delta: float = 0.0,
+         mode: str = "min",
+     ):
+         self.metric_name = metric_name
+         self.patience = patience
+         self.min_delta = min_delta
+         self.mode = mode
+         self.best_value = float("inf") if mode == "min" else float("-inf")
+         self.counter = 0
+         self.should_stop = False
+
+     def on_train_begin(self, trainer: Any) -> None:
+         self.best_value = float("inf") if self.mode == "min" else float("-inf")
+         self.counter = 0
+         self.should_stop = False
+
+     def on_train_end(self, trainer: Any) -> None:
+         pass
+
+     def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+         pass
+
+     def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+         if self.metric_name not in metrics:
+             return
+
+         current_value = metrics[self.metric_name]
+
+         if self.mode == "min":
+             is_improvement = current_value < self.best_value - self.min_delta
+         else:
+             is_improvement = current_value > self.best_value + self.min_delta
+
+         if is_improvement:
+             self.best_value = current_value
+             self.counter = 0
+         else:
+             self.counter += 1
+             if self.counter >= self.patience:
+                 self.should_stop = True
+                 logger.info(f"Early stopping triggered after {epoch + 1} epochs")
+
+     def on_step_begin(self, trainer: Any, step: int) -> None:
+         pass
+
+     def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+         pass
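
DistillationCallback fixes a six-method protocol (train, epoch, and step begin/end hooks). As a hedged illustration of extending it, here is a minimal custom callback that relies only on the signatures defined in this file and makes no assumptions about the trainer object: it records the lowest per-step loss seen during a run.

    import logging
    from typing import Any

    from distil_trainer.core.callbacks import DistillationCallback

    logger = logging.getLogger(__name__)


    class BestLossCallback(DistillationCallback):
        """Example callback: remembers the lowest per-step loss observed."""

        def __init__(self) -> None:
            self.best_loss = float("inf")
            self.best_step = -1

        def on_train_begin(self, trainer: Any) -> None:
            self.best_loss = float("inf")
            self.best_step = -1

        def on_train_end(self, trainer: Any) -> None:
            logger.info(f"Best loss {self.best_loss:.4f} at step {self.best_step}")

        def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
            pass

        def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
            pass

        def on_step_begin(self, trainer: Any, step: int) -> None:
            pass

        def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
            if loss < self.best_loss:
                self.best_loss = loss
                self.best_step = step
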
@@ -0,0 +1,358 @@
+ """Configuration dataclasses for distillation training."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any, Literal
+
+ from transformers import PreTrainedModel, PreTrainedTokenizer
+
+ from sentence_transformers import SentenceTransformer
+
+
+ @dataclass
+ class TrainingConfig:
+     """Training hyperparameters configuration."""
+
+     # Basic Training
+     num_train_epochs: int = 1
+     max_steps: int = -1  # -1 means use epochs
+     per_device_train_batch_size: int = 64
+     per_device_eval_batch_size: int = 64
+     gradient_accumulation_steps: int = 1
+
+     # Learning Rate
+     learning_rate: float = 1e-4
+     min_learning_rate: float = 1e-5
+     weight_decay: float = 0.01
+
+     # Scheduler
+     lr_scheduler_type: Literal[
+         "linear",
+         "cosine",
+         "cosine_with_restarts",
+         "polynomial",
+         "constant",
+         "constant_with_warmup",
+     ] = "cosine"
+     warmup_ratio: float = 0.1
+     warmup_steps: int = 0
+
+     # Optimization
+     optimizer: Literal["adamw", "adam", "sgd", "adafactor"] = "adamw"
+     adam_beta1: float = 0.9
+     adam_beta2: float = 0.999
+     adam_epsilon: float = 1e-8
+     max_grad_norm: float = 1.0
+
+     # Logging & Evaluation
+     logging_steps: int = 100
+     eval_strategy: Literal["steps", "epoch", "no"] = "steps"
+     eval_steps: int = 500
+
+     # Checkpointing
+     save_steps: int = 500
+     save_total_limit: int = 2
+     load_best_model_at_end: bool = True
+     metric_for_best_model: str = "eval_loss"
+     greater_is_better: bool = False
+
+     # Tracking
+     run_name: str | None = None
+     report_to: list[str] = field(default_factory=lambda: ["tensorboard"])
+
+
+ @dataclass
+ class DistillationConfig:
+     """Configuration for distillation losses and strategies."""
+
+     # Loss Type
+     loss_type: Literal[
+         "mse",  # Mean Squared Error on embeddings
+         "kl_divergence",  # KL divergence on logits
+         "cosine",  # Cosine similarity loss
+         "ranking",  # Ranking loss for embedding models
+         "combined",  # Combination of multiple losses
+     ] = "mse"
+
+     # Loss Weights (for combined loss)
+     logit_loss_weight: float = 1.0
+     embedding_loss_weight: float = 1.0
+     intermediate_loss_weight: float = 0.0
+     attention_loss_weight: float = 0.0
+
+     # Temperature for KL divergence
+     temperature: float = 1.0
+
+     # Embedding Distillation Options
+     use_pca_projection: bool = True  # When student dim < teacher dim
+     pca_num_samples: int = 20000
+
+     # Intermediate Layer Mapping
+     layer_mapping: dict[int, int] | None = None
+
+     # Teacher Inference Settings
+     teacher_inference_batch_size: int = 128
+     precompute_teacher_embeddings: bool = True
+     cache_teacher_embeddings: bool = True
+     teacher_embeddings_cache_dir: str | None = None
+
+     # Ranking Loss Options
+     in_batch_negatives: bool = True
+     hard_negatives_per_sample: int = 5
+     margin: float = 0.5
+
+
+ @dataclass
+ class DataConfig:
+     """Configuration for data loading and preprocessing."""
+
+     # Dataset paths or names
+     train_data: str | None = None
+     eval_data: str | None = None
+     test_data: str | None = None
+
+     # Dataset options
+     dataset_name: str | None = None
+     dataset_config: str | None = None
+     text_column: str = "sentence"
+     max_samples: int | None = None
+
+     # Preprocessing
+     max_seq_length: int = 512
+     num_workers: int = 4
+     remove_columns: list[str] | None = None
+
+     # Data format
+     data_format: Literal["single", "pair", "triplet"] = "single"
+
+
+ @dataclass
+ class PruningConfig:
+     """Base configuration for pruning."""
+
+     # Pruning method
+     method: Literal["depth", "width", "combined"] = "depth"
+
+     # Importance estimation
+     importance_method: Literal[
+         "activation", "gradient", "taylor", "wanda", "cosine_similarity"
+     ] = "activation"
+     calibration_samples: int = 1024
+
+
+ @dataclass
+ class LayerReductionConfig(PruningConfig):
+     """Configuration for layer reduction (depth pruning)."""
+
+     method: Literal["depth", "width", "combined"] = "depth"
+
+     # Layers to keep (0-indexed)
+     layers_to_keep: list[int] | None = None
+
+     # Alternative: specify how many layers to keep
+     num_layers_to_keep: int | None = None
+
+     # Alternative: specify layers to drop
+     layers_to_drop: list[int] | None = None
+
+     # Layer selection strategy
+     layer_selection: Literal[
+         "first", "last", "even", "importance", "custom"
+     ] = "importance"
+
+
+ @dataclass
+ class WidthPruningConfig(PruningConfig):
+     """Configuration for width-based pruning."""
+
+     method: Literal["depth", "width", "combined"] = "width"
+
+     # Target dimensions (set to None to keep original)
+     target_hidden_size: int | None = None
+     target_intermediate_size: int | None = None
+     target_num_attention_heads: int | None = None
+     target_num_key_value_heads: int | None = None
+
+     # Alternative: specify reduction ratios
+     hidden_size_ratio: float | None = None
+     intermediate_size_ratio: float | None = None
+     attention_head_ratio: float | None = None
+
+
+ @dataclass
+ class CombinedPruningConfig(PruningConfig):
+     """Configuration for combined depth and width pruning."""
+
+     method: Literal["depth", "width", "combined"] = "combined"
+
+     # Target model size (parameters)
+     target_params: int | None = None
+
+     # Individual configs
+     depth_config: LayerReductionConfig | None = None
+     width_config: WidthPruningConfig | None = None
+
+     # Pruning order
+     pruning_order: Literal["depth_first", "width_first", "interleaved"] = "depth_first"
+
+     # Iterative pruning
+     num_iterations: int = 1
+     prune_ratio_per_iteration: float = 0.5
+
+
+ @dataclass
+ class MultilingualConfig:
+     """Configuration for multilingual knowledge distillation."""
+
+     # Source languages (teacher understands these)
+     source_languages: list[str] = field(default_factory=lambda: ["en"])
+
+     # Target languages (student should learn these)
+     target_languages: list[str] = field(default_factory=list)
+
+     # Parallel sentence datasets
+     parallel_datasets: list[str] = field(
+         default_factory=lambda: [
+             "sentence-transformers/parallel-sentences-talks",
+             "sentence-transformers/parallel-sentences-tatoeba",
+         ]
+     )
+
+     # Maximum sentences per language pair
+     max_sentences_per_language: int = 500000
+
+     # Student model configuration
+     student_model: str = "xlm-roberta-base"
+     student_max_seq_length: int = 128
+
+     # Training settings
+     num_train_epochs: int = 5
+     evaluation_steps: int = 5000
+
+
+ @dataclass
+ class EmbeddingConversionConfig:
+     """Configuration for converting LLM to embedding model."""
+
+     # Source LLM
+     source_model: str = ""
+
+     # Attention modification
+     convert_to_bidirectional: bool = True
+
+     # Pooling strategy
+     pooling_mode: Literal["mean", "cls", "last_token", "weighted_mean"] = "mean"
+
+     # Training data format
+     data_format: Literal["triplet", "pair", "single"] = "triplet"
+
+     # Prefix configuration
+     query_prefix: str = "query: "
+     passage_prefix: str = "passage: "
+
+     # Loss function
+     loss_type: Literal["ranking", "contrastive", "cosine"] = "ranking"
+     in_batch_negatives: bool = True
+     hard_negatives_count: int = 5
+
+
+ @dataclass
+ class ReasoningDistillationConfig:
+     """Configuration for reasoning/CoT distillation."""
+
+     # Teacher model (reasoning model)
+     teacher_model: str = ""
+
+     # Student model
+     student_model: str = ""
+
+     # Data generation
+     generate_reasoning_data: bool = True
+     reasoning_api: str | None = None
+     num_reasoning_samples: int = 10000
+
+     # Reasoning format
+     reasoning_format: Literal["cot", "step_by_step", "scratchpad"] = "cot"
+
+     # Special tokens
+     thought_start_token: str = "<|begin_of_thought|>"
+     thought_end_token: str = "<|end_of_thought|>"
+     solution_start_token: str = "<|begin_of_solution|>"
+     solution_end_token: str = "<|end_of_solution|>"
+
+     # Training settings
+     max_seq_length: int = 16384
+     mask_reasoning: bool = False
+
+
+ @dataclass
+ class WandbConfig:
+     """Configuration for Weights & Biases logging."""
+     project: str = "distil-trainer"
+     entity: str | None = None
+     name: str | None = None
+     tags: list[str] = field(default_factory=list)
+     group: str | None = None
+     notes: str | None = None
+
+
+ @dataclass
+ class HubConfig:
+     """Configuration for HuggingFace Hub integration."""
+     push_to_hub: bool = False
+     hub_model_id: str | None = None
+     hub_token: str | None = None
+     hub_private_repo: bool = False
+     push_to_hub_interval: Literal["every_save", "end"] = "end"
+
+
+ @dataclass
+ class DistilTrainerConfig:
+     """Main configuration for distillation training."""
+
+     # Model Configuration
+     teacher_model: str | SentenceTransformer | PreTrainedModel = ""
+     student_model: str | SentenceTransformer | PreTrainedModel | None = None
+     student_model_name: str | None = None
+
+     # Student Initialization Strategy
+     student_init_strategy: Literal[
+         "from_pretrained",
+         "layer_reduction",
+         "width_pruning",
+         "depth_pruning",
+         "combined_pruning",
+     ] = "from_pretrained"
+
+     # Pruning Configuration (if applicable)
+     pruning_config: PruningConfig | LayerReductionConfig | WidthPruningConfig | CombinedPruningConfig | None = None
+
+     # Distillation Configuration
+     distillation_config: DistillationConfig = field(default_factory=DistillationConfig)
+
+     # Training Configuration
+     training_config: TrainingConfig = field(default_factory=TrainingConfig)
+
+     # Data Configuration
+     data_config: DataConfig = field(default_factory=DataConfig)
+
+     # Output Configuration
+     output_dir: str = "./distilled_model"
+     save_strategy: Literal["steps", "epoch", "best"] = "best"
+     save_total_limit: int = 2
+
+     # Hardware Configuration
+     device: str = "auto"
+     precision: Literal["fp32", "fp16", "bf16", "int8"] = "bf16"
+     distributed: bool = False
+     tensor_parallel_size: int = 1
+     pipeline_parallel_size: int = 1
+
+     # Logging
+     logging_dir: str | None = None
+     seed: int = 42
+
+     # Integrations
+     wandb_config: WandbConfig = field(default_factory=WandbConfig)
+     hub_config: HubConfig = field(default_factory=HubConfig)
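
To make the nesting of these dataclasses concrete, here is a sketch that wires a depth-pruned student setup into one DistilTrainerConfig. Every field name and allowed value comes from this file; the teacher model identifier and the numeric values are illustrative only, and how the pruned student is actually materialized is handled by code outside this diff.

    from distil_trainer.core.config import (
        DistillationConfig,
        DistilTrainerConfig,
        LayerReductionConfig,
        TrainingConfig,
    )

    config = DistilTrainerConfig(
        teacher_model="BAAI/bge-base-en-v1.5",  # illustrative teacher checkpoint
        student_init_strategy="layer_reduction",
        # Keep 6 layers, selected by activation-based importance scores.
        pruning_config=LayerReductionConfig(
            num_layers_to_keep=6,
            layer_selection="importance",
            importance_method="activation",
        ),
        distillation_config=DistillationConfig(
            loss_type="combined",
            embedding_loss_weight=1.0,
            logit_loss_weight=0.0,
            temperature=2.0,
        ),
        training_config=TrainingConfig(
            num_train_epochs=3,
            learning_rate=5e-5,
            lr_scheduler_type="cosine",
        ),
        output_dir="./distilled_model",
        precision="bf16",
    )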