distil-trainer 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,285 @@
+ """Multilingual distillation for extending models to new languages."""
+
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Any
+
+ import torch
+ from datasets import DatasetDict, load_dataset
+ from tqdm import tqdm
+
+ from sentence_transformers import SentenceTransformer
+
+ from distil_trainer.distillation.losses import DistillationLosses
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class MultilingualDistillationStrategy:
+     """
+     Strategy for extending a monolingual model to multiple languages.
+
+     Uses parallel sentences to train a multilingual student where:
+     - Student source language embeddings match teacher source embeddings
+     - Student target language embeddings match teacher source embeddings
+
+     Example:
+         >>> strategy = MultilingualDistillationStrategy(
+         ...     source_languages=["en"],
+         ...     target_languages=["de", "es", "fr"]
+         ... )
+         >>> strategy.prepare(teacher, student)
+     """
+
+     # Source languages (teacher understands these)
+     source_languages: list[str] = field(default_factory=lambda: ["en"])
+
+     # Target languages (student should learn these)
+     target_languages: list[str] = field(default_factory=list)
+
+     # Parallel sentence datasets
+     parallel_datasets: list[str] = field(
+         default_factory=lambda: [
+             "sentence-transformers/parallel-sentences-talks",
+         ]
+     )
+
+     # Maximum sentences per language pair
+     max_sentences_per_language: int = 500000
+
+     # Loss function
+     loss_fn: str = "mse"
+
+     def __post_init__(self):
+         self.teacher = None
+         self.student = None
+         self.batch_size = 64  # default; overwritten when prepare() is called
+         self._train_datasets = None
+         self._eval_datasets = None
+
+     def prepare(
+         self,
+         teacher: SentenceTransformer,
+         student: SentenceTransformer,
+         batch_size: int = 64,
+     ) -> None:
+         """
+         Prepare for multilingual distillation.
+
+         Args:
+             teacher: Teacher model (monolingual).
+             student: Student model (multilingual).
+             batch_size: Batch size for encoding.
+         """
+         self.teacher = teacher
+         self.student = student
+         self.batch_size = batch_size
+
+     def load_parallel_data(self) -> tuple[DatasetDict, DatasetDict]:
+         """
+         Load parallel sentence data for training.
+
+         Returns:
+             Tuple of (train_datasets, eval_datasets) as DatasetDict.
+         """
+         train_datasets = DatasetDict()
+         eval_datasets = DatasetDict()
+
+         for source_lang in self.source_languages:
+             for target_lang in self.target_languages:
+                 subset = f"{source_lang}-{target_lang}"
+
+                 for dataset_name in self.parallel_datasets:
+                     try:
+                         train_data = load_dataset(dataset_name, subset, split="train")
+
+                         # Limit size
+                         if len(train_data) > self.max_sentences_per_language:
+                             train_data = train_data.select(range(self.max_sentences_per_language))
+
+                         # Try to get eval split
+                         try:
+                             eval_data = load_dataset(dataset_name, subset, split="dev")
+                             if len(eval_data) > 1000:
+                                 eval_data = eval_data.select(range(1000))
+                         except Exception:
+                             # Split from train
+                             split_data = train_data.train_test_split(test_size=1000, shuffle=True)
+                             train_data = split_data["train"]
+                             eval_data = split_data["test"]
+
+                         train_datasets[subset] = train_data
+                         eval_datasets[subset] = eval_data
+
+                         logger.info(f"Loaded {len(train_data)} training samples for {subset}")
+
+                     except Exception as e:
+                         logger.warning(f"Could not load {dataset_name}/{subset}: {e}")
+
+         self._train_datasets = train_datasets
+         self._eval_datasets = eval_datasets
+
+         return train_datasets, eval_datasets
+
+     def prepare_dataset(
+         self,
+         dataset: Any,
+         source_col: str = "english",
+         target_col: str = "non_english",
+     ) -> Any:
+         """
+         Prepare dataset with teacher embeddings.
+
+         Args:
+             dataset: Dataset with parallel sentences.
+             source_col: Column name for source language sentences.
+             target_col: Column name for target language sentences.
+
+         Returns:
+             Dataset with teacher embeddings added.
+         """
+
+         def add_teacher_embeddings(batch):
+             source_sentences = batch[source_col]
+             with torch.no_grad():
+                 embeddings = self.teacher.encode(
+                     source_sentences,
+                     batch_size=self.batch_size,
+                     show_progress_bar=False,
+                     convert_to_numpy=True,
+                 )
+             return {"label": embeddings.tolist()}
+
+         return dataset.map(add_teacher_embeddings, batched=True, batch_size=10000)
+
+     def get_loss_function(self):
+         """Get the loss function."""
+         if self.loss_fn == "mse":
+             return DistillationLosses.mse_loss
+         elif self.loss_fn == "cosine":
+             return DistillationLosses.cosine_loss
+         else:
+             return DistillationLosses.mse_loss
+
+     def compute_loss(
+         self,
+         student_output: torch.Tensor,
+         teacher_output: torch.Tensor,
+         batch: dict,
+     ) -> torch.Tensor:
+         """
+         Compute multilingual distillation loss.
+
+         The loss encourages:
+         1. Student source embeddings to match teacher source embeddings
+         2. Student target embeddings to match teacher source embeddings
+
+         Args:
+             student_output: Student embeddings (for source or target text).
+             teacher_output: Teacher embeddings (for source text).
+             batch: Batch containing source and target sentences.
+
+         Returns:
+             Loss value.
+         """
+         loss_fn = self.get_loss_function()
+         return loss_fn(student_output, teacher_output)
+
+
+ class MultilingualDistilTrainer:
+     """
+     Trainer for multilingual distillation.
+
+     Extends a monolingual teacher to multiple languages via parallel sentence training.
+
+     Example:
+         >>> trainer = MultilingualDistilTrainer(
+         ...     teacher_model="paraphrase-distilroberta-base-v2",
+         ...     student_model="xlm-roberta-base",
+         ... )
+         >>> trainer.add_languages(["de", "es", "fr"])
+         >>> trainer.train()
+     """
+
+     def __init__(
+         self,
+         teacher_model: str | SentenceTransformer,
+         student_model: str | SentenceTransformer,
+         source_languages: list[str] | None = None,
+         target_languages: list[str] | None = None,
+         output_dir: str = "./multilingual_model",
+     ):
+         """
+         Initialize the trainer.
+
+         Args:
+             teacher_model: Teacher model (monolingual, e.g., English).
+             student_model: Student model (multilingual base).
+             source_languages: Languages the teacher understands.
+             target_languages: Languages to extend to.
+             output_dir: Output directory for the trained model.
+         """
+         if isinstance(teacher_model, str):
+             self.teacher = SentenceTransformer(teacher_model)
+         else:
+             self.teacher = teacher_model
+
+         if isinstance(student_model, str):
+             self.student = SentenceTransformer(student_model)
+         else:
+             self.student = student_model
+
+         self.source_languages = source_languages or ["en"]
+         self.target_languages = target_languages or []
+         self.output_dir = output_dir
+
+         self.strategy = MultilingualDistillationStrategy(
+             source_languages=self.source_languages,
+             target_languages=self.target_languages,
+         )
+
+     def add_languages(self, languages: list[str]) -> None:
+         """Add target languages to extend to."""
+         self.target_languages.extend(languages)
+         self.strategy.target_languages = self.target_languages
+
+     def train(
+         self,
+         num_epochs: int = 5,
+         batch_size: int = 64,
+         learning_rate: float = 2e-5,
+     ) -> None:
+         """
+         Prepare the multilingual training data and teacher embeddings.
+
+         The actual optimization loop is delegated to DistilTrainer; this method
+         loads the parallel datasets, attaches teacher embeddings, and stores the
+         prepared datasets on the trainer.
+
+         Args:
+             num_epochs: Number of training epochs.
+             batch_size: Training batch size.
+             learning_rate: Learning rate.
+         """
+         logger.info("Starting multilingual distillation training...")
+
+         # Prepare strategy
+         self.strategy.prepare(self.teacher, self.student, batch_size)
+
+         # Load data
+         train_datasets, eval_datasets = self.strategy.load_parallel_data()
+
+         # Prepare datasets with teacher embeddings
+         for subset in train_datasets:
+             logger.info(f"Preparing {subset} with teacher embeddings...")
+             train_datasets[subset] = self.strategy.prepare_dataset(train_datasets[subset])
+
+         logger.info("Training prepared. Use DistilTrainer for actual training.")
+
+         # Save the strategy and datasets for use with DistilTrainer
+         self._train_datasets = train_datasets
+         self._eval_datasets = eval_datasets
+
+     def save_model(self, path: str | None = None) -> None:
+         """Save the trained model."""
+         save_path = path or self.output_dir
+         self.student.save(save_path)
+         logger.info(f"Model saved to {save_path}")
@@ -0,0 +1,211 @@
+ """Distillation strategies."""
+
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from typing import Any, Callable
+
+ import torch
+ from torch.utils.data import DataLoader
+
+ from sentence_transformers import SentenceTransformer
+
+ from distil_trainer.distillation.losses import DistillationLosses
+
+
+ class DistillationStrategy(ABC):
+     """Base class for distillation strategies."""
+
+     @abstractmethod
+     def prepare(self, teacher: Any, student: Any, data: Any) -> None:
+         """Prepare for distillation."""
+         pass
+
+     @abstractmethod
+     def get_loss_function(self) -> Callable:
+         """Get the loss function for this strategy."""
+         pass
+
+     @abstractmethod
+     def compute_loss(
+         self,
+         student_output: torch.Tensor,
+         teacher_output: torch.Tensor,
+         batch: dict,
+     ) -> torch.Tensor:
+         """Compute the distillation loss."""
+         pass
+
+
+ @dataclass
+ class EmbeddingDistillationStrategy(DistillationStrategy):
+     """
+     Strategy for distilling embedding models.
+
+     The student learns to produce embeddings similar to the teacher.
+     Uses MSE or cosine loss between student and teacher embeddings.
+
+     Example:
+         >>> strategy = EmbeddingDistillationStrategy(loss_fn="mse")
+         >>> strategy.prepare(teacher, student, train_data)
+         >>> loss_fn = strategy.get_loss_function()
+     """
+
+     # Loss function
+     loss_fn: str = "mse"  # "mse", "cosine", "combined"
+
+     # Whether to precompute teacher embeddings
+     precompute_embeddings: bool = True
+
+     # PCA for dimension reduction
+     use_pca: bool = True
+     pca_components: int | None = None
+
+     # Dataset requirements
+     required_columns: list[str] = field(default_factory=lambda: ["sentence"])
+
+     def prepare(self, teacher: Any, student: Any, data: Any) -> None:
+         """Prepare for embedding distillation."""
+         self.teacher = teacher
+         self.student = student
+
+     def get_loss_function(self) -> Callable:
+         """Get the loss function."""
+         if self.loss_fn == "mse":
+             return DistillationLosses.mse_loss
+         elif self.loss_fn == "cosine":
+             return DistillationLosses.cosine_loss
+         elif self.loss_fn == "combined":
+             def combined_loss(student_out, teacher_out):
+                 mse = DistillationLosses.mse_loss(student_out, teacher_out)
+                 cosine = DistillationLosses.cosine_loss(student_out, teacher_out)
+                 return 0.5 * mse + 0.5 * cosine
+             return combined_loss
+         else:
+             raise ValueError(f"Unknown loss function: {self.loss_fn}")
+
+     def compute_loss(
+         self,
+         student_output: torch.Tensor,
+         teacher_output: torch.Tensor,
+         batch: dict,
+     ) -> torch.Tensor:
+         """Compute embedding distillation loss."""
+         loss_fn = self.get_loss_function()
+         return loss_fn(student_output, teacher_output)
+
+
+ @dataclass
+ class LogitDistillationStrategy(DistillationStrategy):
+     """
+     Strategy for distilling via logit matching.
+
+     Uses KL divergence between softmax distributions of student and teacher.
+
+     Example:
+         >>> strategy = LogitDistillationStrategy(temperature=2.0)
+         >>> loss = strategy.compute_loss(student_logits, teacher_logits, batch)
+     """
+
+     # Temperature for softmax
+     temperature: float = 1.0
+
+     # Relative weight of the soft (teacher) loss; (1 - alpha) weights the hard-label loss
+     alpha: float = 0.5
+
+     def prepare(self, teacher: Any, student: Any, data: Any) -> None:
+         """Prepare for logit distillation."""
+         self.teacher = teacher
+         self.student = student
+
+     def get_loss_function(self) -> Callable:
+         """Get the loss function."""
+         def loss_fn(student_logits, teacher_logits):
+             return DistillationLosses.kl_divergence_loss(
+                 student_logits, teacher_logits, self.temperature
+             )
+         return loss_fn
+
+     def compute_loss(
+         self,
+         student_output: torch.Tensor,
+         teacher_output: torch.Tensor,
+         batch: dict,
+     ) -> torch.Tensor:
+         """Compute logit distillation loss."""
+         soft_loss = DistillationLosses.kl_divergence_loss(
+             student_output, teacher_output, self.temperature
+         )
+
+         # If hard labels available, combine with cross-entropy
+         if "labels" in batch and self.alpha < 1.0:
+             import torch.nn.functional as F
+             hard_loss = F.cross_entropy(student_output, batch["labels"])
+             return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
+
+         return soft_loss
+
+
+ @dataclass
+ class RankingDistillationStrategy(DistillationStrategy):
+     """
+     Strategy for distilling retrieval models with ranking loss.
+
+     Example:
+         >>> strategy = RankingDistillationStrategy(in_batch_negatives=True)
+         >>> loss = strategy.compute_loss(query_emb, pos_emb, batch)
+     """
+
+     # Whether to use in-batch negatives
+     in_batch_negatives: bool = True
+
+     # Number of hard negatives per sample
+     hard_negatives: int = 5
+
+     # Margin for triplet loss
+     margin: float = 0.5
+
+     def prepare(self, teacher: Any, student: Any, data: Any) -> None:
+         """Prepare for ranking distillation."""
+         self.teacher = teacher
+         self.student = student
+
+     def get_loss_function(self) -> Callable:
+         """Get the loss function."""
+         def loss_fn(query_emb, pos_emb, neg_emb=None):
+             return DistillationLosses.ranking_loss(
+                 query_emb, pos_emb, neg_emb,
+                 margin=self.margin,
+                 in_batch_negatives=self.in_batch_negatives,
+             )
+         return loss_fn
+
+     def compute_loss(
+         self,
+         student_output: torch.Tensor,
+         teacher_output: torch.Tensor,
+         batch: dict,
+     ) -> torch.Tensor:
+         """Compute ranking distillation loss."""
+         # For ranking, student_output contains query embeddings
+         # teacher_output contains positive embeddings (from teacher)
+         query_embeddings = student_output
+
+         # Get positive embeddings from batch or compute
+         if "positive_embeddings" in batch:
+             positive_embeddings = batch["positive_embeddings"]
+         else:
+             # Use teacher embeddings as targets
+             positive_embeddings = teacher_output
+
+         # Get negative embeddings if available
+         negative_embeddings = batch.get("negative_embeddings", None)
+
+         return DistillationLosses.ranking_loss(
+             query_embeddings,
+             positive_embeddings,
+             negative_embeddings,
+             margin=self.margin,
+             in_batch_negatives=self.in_batch_negatives,
+         )
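LogitDistillationStrategy delegates to DistillationLosses.kl_divergence_loss, which lives in a module not included in this diff. For reference, a standalone sketch of the standard temperature-scaled soft-target objective it presumably wraps; the package's exact signature and scaling may differ:

    import torch
    import torch.nn.functional as F

    def kl_soft_target_loss(student_logits: torch.Tensor,
                            teacher_logits: torch.Tensor,
                            temperature: float = 2.0) -> torch.Tensor:
        # Soften both distributions with the temperature, then compare with KL.
        # The T**2 factor keeps gradient magnitudes comparable across temperatures.
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2

    # Smoke test with random logits.
    loss = kl_soft_target_loss(torch.randn(8, 10), torch.randn(8, 10), temperature=2.0)
    print(loss)

With alpha < 1.0 and hard labels in the batch, compute_loss blends this soft loss with a plain cross-entropy term, weighted alpha and (1 - alpha) respectively.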
@@ -0,0 +1,19 @@
+ """Evaluation module for distillation training."""
+
+ from distil_trainer.evaluation.evaluators import (
+     EmbeddingSimilarityEvaluator,
+     MSEEvaluator,
+     TranslationEvaluator,
+     SequentialEvaluator,
+ )
+ from distil_trainer.evaluation.metrics import DistillationMetrics
+ from distil_trainer.evaluation.benchmarks import BenchmarkRunner
+
+ __all__ = [
+     "EmbeddingSimilarityEvaluator",
+     "MSEEvaluator",
+     "TranslationEvaluator",
+     "SequentialEvaluator",
+     "DistillationMetrics",
+     "BenchmarkRunner",
+ ]
@@ -0,0 +1,86 @@
+ """Benchmark runners for model evaluation."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Any
+
+ from sentence_transformers import SentenceTransformer
+
+ logger = logging.getLogger(__name__)
+
+
+ class BenchmarkRunner:
+     """Run standard benchmarks for evaluation."""
+
+     def __init__(self, model: SentenceTransformer | Any):
+         """
+         Initialize the benchmark runner.
+
+         Args:
+             model: Model to evaluate.
+         """
+         self.model = model
+
+     def run_stsb(self, split: str = "test") -> dict[str, float]:
+         """Run STS Benchmark evaluation."""
+         from datasets import load_dataset
+         from distil_trainer.evaluation.evaluators import EmbeddingSimilarityEvaluator
+
+         dataset = load_dataset("sentence-transformers/stsb", split=split)
+
+         evaluator = EmbeddingSimilarityEvaluator(
+             sentences1=dataset["sentence1"],
+             sentences2=dataset["sentence2"],
+             scores=dataset["score"],
+             name=f"stsb-{split}",
+         )
+
+         return evaluator(self.model)
+
+     def run_mteb(
+         self,
+         tasks: list[str] | None = None,
+         languages: list[str] | None = None,
+     ) -> dict[str, float]:
+         """
+         Run MTEB benchmark suite.
+
+         Requires mteb package: pip install mteb
+         """
+         try:
+             from mteb import MTEB
+         except ImportError:
+             logger.warning("MTEB not installed. Install with: pip install mteb")
+             return {}
+
+         if tasks is None:
+             tasks = ["STS12", "STS13", "STS14", "STS15", "STS16"]
+
+         evaluation = MTEB(tasks=tasks)
+         results = evaluation.run(self.model, output_folder=None)
+
+         # Flatten results
+         flat_results = {}
+         for task_result in results:
+             task_name = task_result.task_name
+             for metric, value in task_result.scores.items():
+                 flat_results[f"{task_name}_{metric}"] = value
+
+         return flat_results
+
+     def run_all(self) -> dict[str, dict[str, float]]:
+         """Run all available benchmarks."""
+         results = {}
+
+         try:
+             results["stsb"] = self.run_stsb()
+         except Exception as e:
+             logger.warning(f"STSB evaluation failed: {e}")
+
+         try:
+             results["mteb"] = self.run_mteb()
+         except Exception as e:
+             logger.warning(f"MTEB evaluation failed: {e}")
+
+         return results
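A short evaluation sketch tying the benchmark runner back to a distilled checkpoint. The local model path is hypothetical, and the shape of the returned score dictionaries depends on the evaluator implementations, which are not part of this diff:

    from sentence_transformers import SentenceTransformer
    from distil_trainer.evaluation import BenchmarkRunner  # re-exported in the package __init__ above

    student = SentenceTransformer("./multilingual_model")  # hypothetical local checkpoint
    runner = BenchmarkRunner(student)

    scores = runner.run_all()  # {"stsb": {...}, "mteb": {...}} when both benchmarks succeed
    print(scores.get("stsb", {}))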