distil-trainer 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distil_trainer/__init__.py +31 -0
- distil_trainer/core/__init__.py +23 -0
- distil_trainer/core/callbacks.py +188 -0
- distil_trainer/core/config.py +358 -0
- distil_trainer/core/trainer.py +843 -0
- distil_trainer/data/__init__.py +19 -0
- distil_trainer/data/collators.py +240 -0
- distil_trainer/data/datamodule.py +191 -0
- distil_trainer/data/datasets.py +245 -0
- distil_trainer/data/loaders.py +163 -0
- distil_trainer/distillation/__init__.py +21 -0
- distil_trainer/distillation/losses.py +345 -0
- distil_trainer/distillation/multilingual.py +285 -0
- distil_trainer/distillation/strategies.py +211 -0
- distil_trainer/evaluation/__init__.py +19 -0
- distil_trainer/evaluation/benchmarks.py +86 -0
- distil_trainer/evaluation/evaluators.py +343 -0
- distil_trainer/evaluation/metrics.py +75 -0
- distil_trainer/models/__init__.py +5 -0
- distil_trainer/models/layers.py +115 -0
- distil_trainer/pruning/__init__.py +13 -0
- distil_trainer/pruning/combined_pruning.py +122 -0
- distil_trainer/pruning/depth_pruning.py +261 -0
- distil_trainer/pruning/importance.py +365 -0
- distil_trainer/pruning/width_pruning.py +480 -0
- distil_trainer-0.1.10.dist-info/METADATA +443 -0
- distil_trainer-0.1.10.dist-info/RECORD +29 -0
- distil_trainer-0.1.10.dist-info/WHEEL +4 -0
- distil_trainer-0.1.10.dist-info/licenses/LICENSE +21 -0

distil_trainer/__init__.py
@@ -0,0 +1,31 @@
+"""
+Distil Trainer - A comprehensive knowledge distillation training framework.
+
+This package provides tools for:
+- Classical embedding distillation
+- Model pruning (depth and width)
+- Multilingual model extension
+- LLM to embedding model conversion
+- Reasoning/Chain-of-thought distillation
+"""
+
+from distil_trainer.core.config import (
+    DistilTrainerConfig,
+    DistillationConfig,
+    TrainingConfig,
+)
+from distil_trainer.core.trainer import DistilTrainer
+
+__version__ = "0.1.10"
+__author__ = "Ali Bayram"
+__email__ = "malibayram@gmail.com"
+
+__all__ = [
+    # Core
+    "DistilTrainer",
+    "DistilTrainerConfig",
+    "DistillationConfig",
+    "TrainingConfig",
+    # Version
+    "__version__",
+]
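
The top-level __init__.py above defines the package's public surface (DistilTrainer, DistilTrainerConfig, DistillationConfig, TrainingConfig). For orientation only, a minimal usage sketch might look like the following; the DistilTrainer(config) constructor and train() call are assumptions about distil_trainer/core/trainer.py, which is not shown in this diff, and the model names are placeholders rather than package defaults.

# Sketch of the exported API; constructor and train() are assumed, not confirmed by this diff.
from distil_trainer import DistilTrainer, DistilTrainerConfig, TrainingConfig

config = DistilTrainerConfig(
    teacher_model="sentence-transformers/all-mpnet-base-v2",  # placeholder teacher checkpoint
    student_model="distilbert-base-uncased",                  # placeholder student checkpoint
    training_config=TrainingConfig(num_train_epochs=1),
    output_dir="./distilled_model",
)

trainer = DistilTrainer(config)  # assumed: the trainer wraps the config object
trainer.train()                  # assumed: a train() entry point exists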

distil_trainer/core/__init__.py
@@ -0,0 +1,23 @@
+"""Core module for distillation training."""
+
+from distil_trainer.core.config import (
+    DistilTrainerConfig,
+    DistillationConfig,
+    TrainingConfig,
+)
+from distil_trainer.core.trainer import DistilTrainer
+from distil_trainer.core.callbacks import (
+    DistillationCallback,
+    EvaluationCallback,
+    LoggingCallback,
+)
+
+__all__ = [
+    "DistilTrainer",
+    "DistilTrainerConfig",
+    "DistillationConfig",
+    "TrainingConfig",
+    "DistillationCallback",
+    "EvaluationCallback",
+    "LoggingCallback",
+]

distil_trainer/core/callbacks.py
@@ -0,0 +1,188 @@
+"""Training callbacks for distillation."""
+
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+class DistillationCallback(ABC):
+    """Base class for training callbacks."""
+
+    @abstractmethod
+    def on_train_begin(self, trainer: Any) -> None:
+        """Called at the beginning of training."""
+        pass
+
+    @abstractmethod
+    def on_train_end(self, trainer: Any) -> None:
+        """Called at the end of training."""
+        pass
+
+    @abstractmethod
+    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+        """Called at the beginning of each epoch."""
+        pass
+
+    @abstractmethod
+    def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+        """Called at the end of each epoch."""
+        pass
+
+    @abstractmethod
+    def on_step_begin(self, trainer: Any, step: int) -> None:
+        """Called at the beginning of each training step."""
+        pass
+
+    @abstractmethod
+    def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+        """Called at the end of each training step."""
+        pass
+
+
+class LoggingCallback(DistillationCallback):
+    """Callback for logging training progress."""
+
+    def __init__(self, log_every_n_steps: int = 100):
+        self.log_every_n_steps = log_every_n_steps
+
+    def on_train_begin(self, trainer: Any) -> None:
+        logger.info("Training started")
+
+    def on_train_end(self, trainer: Any) -> None:
+        logger.info("Training completed")
+
+    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+        logger.info(f"Epoch {epoch + 1} started")
+
+    def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+        logger.info(f"Epoch {epoch + 1} completed: {metrics}")
+
+    def on_step_begin(self, trainer: Any, step: int) -> None:
+        pass
+
+    def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+        if step % self.log_every_n_steps == 0:
+            logger.info(f"Step {step}: loss = {loss:.4f}")
+
+
+class EvaluationCallback(DistillationCallback):
+    """Callback for running evaluation during training."""
+
+    def __init__(self, eval_every_n_steps: int = 500):
+        self.eval_every_n_steps = eval_every_n_steps
+
+    def on_train_begin(self, trainer: Any) -> None:
+        pass
+
+    def on_train_end(self, trainer: Any) -> None:
+        pass
+
+    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+        pass
+
+    def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+        eval_metrics = trainer.evaluate()
+        logger.info(f"End of epoch {epoch + 1} evaluation: {eval_metrics}")
+
+    def on_step_begin(self, trainer: Any, step: int) -> None:
+        pass
+
+    def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+        if step % self.eval_every_n_steps == 0:
+            eval_metrics = trainer.evaluate()
+            logger.info(f"Step {step} evaluation: {eval_metrics}")
+
+
+class CheckpointCallback(DistillationCallback):
+    """Callback for saving checkpoints during training."""
+
+    def __init__(self, save_every_n_steps: int = 500, save_total_limit: int = 3):
+        self.save_every_n_steps = save_every_n_steps
+        self.save_total_limit = save_total_limit
+        self.saved_checkpoints: list[str] = []
+
+    def on_train_begin(self, trainer: Any) -> None:
+        pass
+
+    def on_train_end(self, trainer: Any) -> None:
+        trainer.save_model()
+
+    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+        pass
+
+    def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+        pass
+
+    def on_step_begin(self, trainer: Any, step: int) -> None:
+        pass
+
+    def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+        if step % self.save_every_n_steps == 0:
+            checkpoint_name = f"checkpoint-{step}"
+            trainer._save_checkpoint(checkpoint_name)
+            self.saved_checkpoints.append(checkpoint_name)
+
+            # Remove old checkpoints
+            while len(self.saved_checkpoints) > self.save_total_limit:
+                old_checkpoint = self.saved_checkpoints.pop(0)
+                # Could add logic to delete the old checkpoint directory
+
+
+class EarlyStoppingCallback(DistillationCallback):
+    """Callback for early stopping based on evaluation metrics."""
+
+    def __init__(
+        self,
+        metric_name: str = "eval_loss",
+        patience: int = 3,
+        min_delta: float = 0.0,
+        mode: str = "min",
+    ):
+        self.metric_name = metric_name
+        self.patience = patience
+        self.min_delta = min_delta
+        self.mode = mode
+        self.best_value = float("inf") if mode == "min" else float("-inf")
+        self.counter = 0
+        self.should_stop = False
+
+    def on_train_begin(self, trainer: Any) -> None:
+        self.best_value = float("inf") if self.mode == "min" else float("-inf")
+        self.counter = 0
+        self.should_stop = False
+
+    def on_train_end(self, trainer: Any) -> None:
+        pass
+
+    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
+        pass
+
+    def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
+        if self.metric_name not in metrics:
+            return
+
+        current_value = metrics[self.metric_name]
+
+        if self.mode == "min":
+            is_improvement = current_value < self.best_value - self.min_delta
+        else:
+            is_improvement = current_value > self.best_value + self.min_delta
+
+        if is_improvement:
+            self.best_value = current_value
+            self.counter = 0
+        else:
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.should_stop = True
+                logger.info(f"Early stopping triggered after {epoch + 1} epochs")
+
+    def on_step_begin(self, trainer: Any, step: int) -> None:
+        pass
+
+    def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
+        pass
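
Because DistillationCallback declares every hook as abstract, a user-defined callback must implement all six methods, even if most are no-ops, exactly as the built-in callbacks above do. The following sketch is a hypothetical custom callback (not part of the package) written against the interface shown in this hunk.

# Hypothetical custom callback: remembers the lowest step-level loss seen during training.
from typing import Any

from distil_trainer.core.callbacks import DistillationCallback


class BestLossCallback(DistillationCallback):
    """Tracks the best (lowest) training loss observed at any step."""

    def __init__(self) -> None:
        self.best_loss = float("inf")
        self.best_step = -1

    def on_train_begin(self, trainer: Any) -> None:
        self.best_loss = float("inf")
        self.best_step = -1

    def on_train_end(self, trainer: Any) -> None:
        print(f"Best loss {self.best_loss:.4f} at step {self.best_step}")

    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
        pass

    def on_epoch_end(self, trainer: Any, epoch: int, metrics: dict[str, float]) -> None:
        pass

    def on_step_begin(self, trainer: Any, step: int) -> None:
        pass

    def on_step_end(self, trainer: Any, step: int, loss: float) -> None:
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_step = step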

distil_trainer/core/config.py
@@ -0,0 +1,358 @@
+"""Configuration dataclasses for distillation training."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Literal
+
+from transformers import PreTrainedModel, PreTrainedTokenizer
+
+from sentence_transformers import SentenceTransformer
+
+
+@dataclass
+class TrainingConfig:
+    """Training hyperparameters configuration."""
+
+    # Basic Training
+    num_train_epochs: int = 1
+    max_steps: int = -1  # -1 means use epochs
+    per_device_train_batch_size: int = 64
+    per_device_eval_batch_size: int = 64
+    gradient_accumulation_steps: int = 1
+
+    # Learning Rate
+    learning_rate: float = 1e-4
+    min_learning_rate: float = 1e-5
+    weight_decay: float = 0.01
+
+    # Scheduler
+    lr_scheduler_type: Literal[
+        "linear",
+        "cosine",
+        "cosine_with_restarts",
+        "polynomial",
+        "constant",
+        "constant_with_warmup",
+    ] = "cosine"
+    warmup_ratio: float = 0.1
+    warmup_steps: int = 0
+
+    # Optimization
+    optimizer: Literal["adamw", "adam", "sgd", "adafactor"] = "adamw"
+    adam_beta1: float = 0.9
+    adam_beta2: float = 0.999
+    adam_epsilon: float = 1e-8
+    max_grad_norm: float = 1.0
+
+    # Logging & Evaluation
+    logging_steps: int = 100
+    eval_strategy: Literal["steps", "epoch", "no"] = "steps"
+    eval_steps: int = 500
+
+    # Checkpointing
+    save_steps: int = 500
+    save_total_limit: int = 2
+    load_best_model_at_end: bool = True
+    metric_for_best_model: str = "eval_loss"
+    greater_is_better: bool = False
+
+    # Tracking
+    run_name: str | None = None
+    report_to: list[str] = field(default_factory=lambda: ["tensorboard"])
+
+
+@dataclass
+class DistillationConfig:
+    """Configuration for distillation losses and strategies."""
+
+    # Loss Type
+    loss_type: Literal[
+        "mse",  # Mean Squared Error on embeddings
+        "kl_divergence",  # KL divergence on logits
+        "cosine",  # Cosine similarity loss
+        "ranking",  # Ranking loss for embedding models
+        "combined",  # Combination of multiple losses
+    ] = "mse"
+
+    # Loss Weights (for combined loss)
+    logit_loss_weight: float = 1.0
+    embedding_loss_weight: float = 1.0
+    intermediate_loss_weight: float = 0.0
+    attention_loss_weight: float = 0.0
+
+    # Temperature for KL divergence
+    temperature: float = 1.0
+
+    # Embedding Distillation Options
+    use_pca_projection: bool = True  # When student dim < teacher dim
+    pca_num_samples: int = 20000
+
+    # Intermediate Layer Mapping
+    layer_mapping: dict[int, int] | None = None
+
+    # Teacher Inference Settings
+    teacher_inference_batch_size: int = 128
+    precompute_teacher_embeddings: bool = True
+    cache_teacher_embeddings: bool = True
+    teacher_embeddings_cache_dir: str | None = None
+
+    # Ranking Loss Options
+    in_batch_negatives: bool = True
+    hard_negatives_per_sample: int = 5
+    margin: float = 0.5
+
+
+@dataclass
+class DataConfig:
+    """Configuration for data loading and preprocessing."""
+
+    # Dataset paths or names
+    train_data: str | None = None
+    eval_data: str | None = None
+    test_data: str | None = None
+
+    # Dataset options
+    dataset_name: str | None = None
+    dataset_config: str | None = None
+    text_column: str = "sentence"
+    max_samples: int | None = None
+
+    # Preprocessing
+    max_seq_length: int = 512
+    num_workers: int = 4
+    remove_columns: list[str] | None = None
+
+    # Data format
+    data_format: Literal["single", "pair", "triplet"] = "single"
+
+
+@dataclass
+class PruningConfig:
+    """Base configuration for pruning."""
+
+    # Pruning method
+    method: Literal["depth", "width", "combined"] = "depth"
+
+    # Importance estimation
+    importance_method: Literal[
+        "activation", "gradient", "taylor", "wanda", "cosine_similarity"
+    ] = "activation"
+    calibration_samples: int = 1024
+
+
+@dataclass
+class LayerReductionConfig(PruningConfig):
+    """Configuration for layer reduction (depth pruning)."""
+
+    method: Literal["depth", "width", "combined"] = "depth"
+
+    # Layers to keep (0-indexed)
+    layers_to_keep: list[int] | None = None
+
+    # Alternative: specify how many layers to keep
+    num_layers_to_keep: int | None = None
+
+    # Alternative: specify layers to drop
+    layers_to_drop: list[int] | None = None
+
+    # Layer selection strategy
+    layer_selection: Literal[
+        "first", "last", "even", "importance", "custom"
+    ] = "importance"
+
+
+@dataclass
+class WidthPruningConfig(PruningConfig):
+    """Configuration for width-based pruning."""
+
+    method: Literal["depth", "width", "combined"] = "width"
+
+    # Target dimensions (set to None to keep original)
+    target_hidden_size: int | None = None
+    target_intermediate_size: int | None = None
+    target_num_attention_heads: int | None = None
+    target_num_key_value_heads: int | None = None
+
+    # Alternative: specify reduction ratios
+    hidden_size_ratio: float | None = None
+    intermediate_size_ratio: float | None = None
+    attention_head_ratio: float | None = None
+
+
+@dataclass
+class CombinedPruningConfig(PruningConfig):
+    """Configuration for combined depth and width pruning."""
+
+    method: Literal["depth", "width", "combined"] = "combined"
+
+    # Target model size (parameters)
+    target_params: int | None = None
+
+    # Individual configs
+    depth_config: LayerReductionConfig | None = None
+    width_config: WidthPruningConfig | None = None
+
+    # Pruning order
+    pruning_order: Literal["depth_first", "width_first", "interleaved"] = "depth_first"
+
+    # Iterative pruning
+    num_iterations: int = 1
+    prune_ratio_per_iteration: float = 0.5
+
+
+@dataclass
+class MultilingualConfig:
+    """Configuration for multilingual knowledge distillation."""
+
+    # Source languages (teacher understands these)
+    source_languages: list[str] = field(default_factory=lambda: ["en"])
+
+    # Target languages (student should learn these)
+    target_languages: list[str] = field(default_factory=list)
+
+    # Parallel sentence datasets
+    parallel_datasets: list[str] = field(
+        default_factory=lambda: [
+            "sentence-transformers/parallel-sentences-talks",
+            "sentence-transformers/parallel-sentences-tatoeba",
+        ]
+    )
+
+    # Maximum sentences per language pair
+    max_sentences_per_language: int = 500000
+
+    # Student model configuration
+    student_model: str = "xlm-roberta-base"
+    student_max_seq_length: int = 128
+
+    # Training settings
+    num_train_epochs: int = 5
+    evaluation_steps: int = 5000
+
+
+@dataclass
+class EmbeddingConversionConfig:
+    """Configuration for converting LLM to embedding model."""
+
+    # Source LLM
+    source_model: str = ""
+
+    # Attention modification
+    convert_to_bidirectional: bool = True
+
+    # Pooling strategy
+    pooling_mode: Literal["mean", "cls", "last_token", "weighted_mean"] = "mean"
+
+    # Training data format
+    data_format: Literal["triplet", "pair", "single"] = "triplet"
+
+    # Prefix configuration
+    query_prefix: str = "query: "
+    passage_prefix: str = "passage: "
+
+    # Loss function
+    loss_type: Literal["ranking", "contrastive", "cosine"] = "ranking"
+    in_batch_negatives: bool = True
+    hard_negatives_count: int = 5
+
+
+@dataclass
+class ReasoningDistillationConfig:
+    """Configuration for reasoning/CoT distillation."""
+
+    # Teacher model (reasoning model)
+    teacher_model: str = ""
+
+    # Student model
+    student_model: str = ""
+
+    # Data generation
+    generate_reasoning_data: bool = True
+    reasoning_api: str | None = None
+    num_reasoning_samples: int = 10000
+
+    # Reasoning format
+    reasoning_format: Literal["cot", "step_by_step", "scratchpad"] = "cot"
+
+    # Special tokens
+    thought_start_token: str = "<|begin_of_thought|>"
+    thought_end_token: str = "<|end_of_thought|>"
+    solution_start_token: str = "<|begin_of_solution|>"
+    solution_end_token: str = "<|end_of_solution|>"
+
+    # Training settings
+    max_seq_length: int = 16384
+    mask_reasoning: bool = False
+
+
+@dataclass
+class WandbConfig:
+    """Configuration for Weights & Biases logging."""
+    project: str = "distil-trainer"
+    entity: str | None = None
+    name: str | None = None
+    tags: list[str] = field(default_factory=list)
+    group: str | None = None
+    notes: str | None = None
+
+
+@dataclass
+class HubConfig:
+    """Configuration for HuggingFace Hub integration."""
+    push_to_hub: bool = False
+    hub_model_id: str | None = None
+    hub_token: str | None = None
+    hub_private_repo: bool = False
+    push_to_hub_interval: Literal["every_save", "end"] = "end"
+
+
+@dataclass
+class DistilTrainerConfig:
+    """Main configuration for distillation training."""
+
+    # Model Configuration
+    teacher_model: str | SentenceTransformer | PreTrainedModel = ""
+    student_model: str | SentenceTransformer | PreTrainedModel | None = None
+    student_model_name: str | None = None
+
+    # Student Initialization Strategy
+    student_init_strategy: Literal[
+        "from_pretrained",
+        "layer_reduction",
+        "width_pruning",
+        "depth_pruning",
+        "combined_pruning",
+    ] = "from_pretrained"
+
+    # Pruning Configuration (if applicable)
+    pruning_config: PruningConfig | LayerReductionConfig | WidthPruningConfig | CombinedPruningConfig | None = None
+
+    # Distillation Configuration
+    distillation_config: DistillationConfig = field(default_factory=DistillationConfig)
+
+    # Training Configuration
+    training_config: TrainingConfig = field(default_factory=TrainingConfig)
+
+    # Data Configuration
+    data_config: DataConfig = field(default_factory=DataConfig)
+
+    # Output Configuration
+    output_dir: str = "./distilled_model"
+    save_strategy: Literal["steps", "epoch", "best"] = "best"
+    save_total_limit: int = 2
+
+    # Hardware Configuration
+    device: str = "auto"
+    precision: Literal["fp32", "fp16", "bf16", "int8"] = "bf16"
+    distributed: bool = False
+    tensor_parallel_size: int = 1
+    pipeline_parallel_size: int = 1
+
+    # Logging
+    logging_dir: str | None = None
+    seed: int = 42
+
+    # Integrations
+    wandb_config: WandbConfig = field(default_factory=WandbConfig)
+    hub_config: HubConfig = field(default_factory=HubConfig)