distil-trainer 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,163 @@
+"""Built-in dataset loaders."""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+from datasets import Dataset, load_dataset
+
+logger = logging.getLogger(__name__)
+
+
+class DatasetLoaders:
+    """Built-in loaders for common datasets."""
+
+    @staticmethod
+    def load_allnli(split: str = "train", config: str = "pair-score") -> Dataset:
+        """
+        Load AllNLI dataset for sentence distillation.
+
+        Args:
+            split: Dataset split ("train", "dev", "test").
+            config: Configuration name.
+
+        Returns:
+            Dataset with sentence pairs.
+        """
+        logger.info(f"Loading AllNLI dataset ({split})")
+        return load_dataset("sentence-transformers/all-nli", config, split=split)
+
+    @staticmethod
+    def load_wikipedia_sentences(
+        language: str = "en",
+        max_samples: int | None = None,
+    ) -> Dataset:
+        """
+        Load Wikipedia sentences dataset.
+
+        Args:
+            language: Language code (only "en" currently supported).
+            max_samples: Maximum number of samples.
+
+        Returns:
+            Dataset with sentences.
+        """
+        logger.info("Loading Wikipedia sentences dataset")
+        dataset = load_dataset(
+            "sentence-transformers/wikipedia-en-sentences",
+            split="train",
+        )
+
+        if max_samples is not None and len(dataset) > max_samples:
+            dataset = dataset.select(range(max_samples))
+
+        return dataset
+
+    @staticmethod
+    def load_stsb(split: str = "validation") -> Dataset:
+        """
+        Load STS Benchmark for evaluation.
+
+        Args:
+            split: Dataset split ("train", "validation", "test").
+
+        Returns:
+            Dataset with sentence pairs and similarity scores.
+        """
+        logger.info(f"Loading STS Benchmark ({split})")
+        return load_dataset("sentence-transformers/stsb", split=split)
+
+    @staticmethod
+    def load_parallel_sentences(
+        source_lang: str = "en",
+        target_lang: str = "de",
+        dataset: str = "talks",
+        split: str = "train",
+        max_samples: int | None = None,
+    ) -> Dataset:
+        """
+        Load parallel sentences for multilingual training.
+
+        Args:
+            source_lang: Source language code.
+            target_lang: Target language code.
+            dataset: Dataset name (talks, europarl, tatoeba, etc.).
+            split: Dataset split.
+            max_samples: Maximum number of samples.
+
+        Returns:
+            Dataset with parallel sentences.
+        """
+        dataset_name = f"sentence-transformers/parallel-sentences-{dataset}"
+        subset = f"{source_lang}-{target_lang}"
+
+        logger.info(f"Loading parallel sentences: {dataset_name}/{subset}")
+
+        try:
+            data = load_dataset(dataset_name, subset, split=split)
+        except Exception:
+            # Try reversed language pair
+            subset = f"{target_lang}-{source_lang}"
+            data = load_dataset(dataset_name, subset, split=split)
+
+        if max_samples is not None and len(data) > max_samples:
+            data = data.select(range(max_samples))
+
+        return data
+
+    @staticmethod
+    def load_specter() -> tuple[Dataset, Dataset, Dataset]:
+        """
+        Load Specter triplet dataset for retrieval training.
+
+        Returns:
+            Tuple of (train, validation, test) datasets.
+        """
+        logger.info("Loading Specter dataset")
+        train = load_dataset("allenai/specter", split="train")
+        val = load_dataset("allenai/specter", split="validation")
+        test = load_dataset("allenai/specter", split="test")
+        return train, val, test
+
+    @staticmethod
+    def load_bespoke_stratos(max_samples: int | None = None) -> Dataset:
+        """
+        Load Bespoke-Stratos reasoning dataset.
+
+        Args:
+            max_samples: Maximum number of samples.
+
+        Returns:
+            Dataset with reasoning chains.
+        """
+        logger.info("Loading Bespoke-Stratos dataset")
+        dataset = load_dataset("bespokelabs/Bespoke-Stratos-17k", split="train")
+
+        if max_samples is not None and len(dataset) > max_samples:
+            dataset = dataset.select(range(max_samples))
+
+        return dataset
+
+    @staticmethod
+    def load_msmarco(
+        split: str = "train",
+        max_samples: int | None = None,
+    ) -> Dataset:
+        """
+        Load MS MARCO passage ranking dataset.
+
+        Args:
+            split: Dataset split.
+            max_samples: Maximum number of samples.
+
+        Returns:
+            Dataset with queries and passages.
+        """
+        logger.info(f"Loading MS MARCO dataset ({split})")
+        dataset = load_dataset("ms_marco", "v2.1", split=split)
+
+        if max_samples is not None and len(dataset) > max_samples:
+            dataset = dataset.select(range(max_samples))
+
+        return dataset
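A minimal usage sketch for the loaders above. The Hugging Face dataset IDs come straight from this file; the module path `distil_trainer.data.loaders`, the sample count, and the language pair are illustrative assumptions rather than facts taken from the package metadata.

    # Assumed import path; adjust to wherever DatasetLoaders lives in the installed package
    from distil_trainer.data.loaders import DatasetLoaders

    # Small slices keep the downloads and the smoke test quick
    train_pairs = DatasetLoaders.load_allnli(split="train")
    parallel = DatasetLoaders.load_parallel_sentences("en", "de", dataset="talks", max_samples=10_000)
    stsb_eval = DatasetLoaders.load_stsb(split="validation")

    print(train_pairs.column_names, len(parallel), len(stsb_eval))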
@@ -0,0 +1,21 @@
+"""Distillation module for knowledge transfer."""
+
+from distil_trainer.distillation.losses import (
+    DistillationLosses,
+    CombinedDistillationLoss,
+)
+from distil_trainer.distillation.strategies import (
+    EmbeddingDistillationStrategy,
+    LogitDistillationStrategy,
+)
+from distil_trainer.distillation.multilingual import (
+    MultilingualDistillationStrategy,
+)
+
+__all__ = [
+    "DistillationLosses",
+    "CombinedDistillationLoss",
+    "EmbeddingDistillationStrategy",
+    "LogitDistillationStrategy",
+    "MultilingualDistillationStrategy",
+]
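The re-exports above let downstream code import every loss and strategy from the subpackage root instead of the individual modules, for example:

    from distil_trainer.distillation import (
        CombinedDistillationLoss,
        LogitDistillationStrategy,
    )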
@@ -0,0 +1,345 @@
+"""Distillation loss functions."""
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Callable
+
+
+class DistillationLosses:
+    """Collection of distillation loss functions."""
+
+    @staticmethod
+    def mse_loss(
+        student_output: torch.Tensor,
+        teacher_output: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Mean Squared Error loss between embeddings.
+
+        Args:
+            student_output: Student model embeddings [batch_size, dim].
+            teacher_output: Teacher model embeddings [batch_size, dim].
+
+        Returns:
+            MSE loss value.
+        """
+        return F.mse_loss(student_output, teacher_output)
+
+    @staticmethod
+    def cosine_loss(
+        student_output: torch.Tensor,
+        teacher_output: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Cosine similarity loss (1 - cosine_similarity).
+
+        Args:
+            student_output: Student model embeddings [batch_size, dim].
+            teacher_output: Teacher model embeddings [batch_size, dim].
+
+        Returns:
+            Cosine loss value.
+        """
+        similarity = F.cosine_similarity(student_output, teacher_output, dim=-1)
+        return (1 - similarity).mean()
+
+    @staticmethod
+    def kl_divergence_loss(
+        student_logits: torch.Tensor,
+        teacher_logits: torch.Tensor,
+        temperature: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        KL divergence loss on softmax distributions.
+
+        Args:
+            student_logits: Student model logits.
+            teacher_logits: Teacher model logits.
+            temperature: Temperature for softmax (higher = softer).
+
+        Returns:
+            KL divergence loss value.
+        """
+        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
+        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
+        loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
+        return loss * (temperature ** 2)
+
+    @staticmethod
+    def ranking_loss(
+        query_embeddings: torch.Tensor,
+        positive_embeddings: torch.Tensor,
+        negative_embeddings: torch.Tensor | None = None,
+        margin: float = 0.5,
+        in_batch_negatives: bool = True,
+    ) -> torch.Tensor:
+        """
+        Ranking loss for retrieval models.
+
+        Args:
+            query_embeddings: Query embeddings [batch_size, dim].
+            positive_embeddings: Positive document embeddings [batch_size, dim].
+            negative_embeddings: Negative document embeddings [batch_size, num_neg, dim].
+            margin: Margin for triplet loss.
+            in_batch_negatives: Whether to use in-batch negatives.
+
+        Returns:
+            Ranking loss value.
+        """
+        batch_size = query_embeddings.size(0)
+
+        # Positive scores
+        pos_scores = F.cosine_similarity(query_embeddings, positive_embeddings, dim=-1)
+
+        if in_batch_negatives:
+            # Use other positives in batch as negatives
+            # Similarity matrix: [batch_size, batch_size]
+            sim_matrix = torch.mm(
+                F.normalize(query_embeddings, p=2, dim=-1),
+                F.normalize(positive_embeddings, p=2, dim=-1).t()
+            )
+
+            # Diagonal contains positive similarities
+            # Off-diagonal contains negative similarities
+            labels = torch.arange(batch_size, device=query_embeddings.device)
+
+            # Cross-entropy loss with in-batch negatives
+            loss = F.cross_entropy(sim_matrix / 0.05, labels)  # Temperature-scaled
+
+        elif negative_embeddings is not None:
+            # Use provided negatives
+            num_negatives = negative_embeddings.size(1)
+
+            # Compute negative scores as cosine similarities so they share
+            # the scale of pos_scores (and of the margin)
+            neg_scores = torch.bmm(
+                F.normalize(query_embeddings, p=2, dim=-1).unsqueeze(1),
+                F.normalize(negative_embeddings, p=2, dim=-1).transpose(1, 2)
+            ).squeeze(1)  # [batch_size, num_neg]
+
+            # Triplet margin loss
+            pos_scores_expanded = pos_scores.unsqueeze(1).expand(-1, num_negatives)
+            loss = F.relu(margin - pos_scores_expanded + neg_scores).mean()
+
+        else:
+            raise ValueError("Either in_batch_negatives or negative_embeddings required")
+
+        return loss
+
+    @staticmethod
+    def contrastive_loss(
+        embeddings1: torch.Tensor,
+        embeddings2: torch.Tensor,
+        labels: torch.Tensor,
+        margin: float = 0.5,
+    ) -> torch.Tensor:
+        """
+        Contrastive loss for pairs.
+
+        Args:
+            embeddings1: First set of embeddings [batch_size, dim].
+            embeddings2: Second set of embeddings [batch_size, dim].
+            labels: Binary labels (1 = similar, 0 = dissimilar) [batch_size].
+            margin: Margin for dissimilar pairs.
+
+        Returns:
+            Contrastive loss value.
+        """
+        distances = F.pairwise_distance(embeddings1, embeddings2)
+
+        # Similar pairs: minimize distance
+        # Dissimilar pairs: maximize distance (up to margin)
+        loss = labels * distances.pow(2) + (1 - labels) * F.relu(margin - distances).pow(2)
+
+        return loss.mean()
+
+    @staticmethod
+    def intermediate_layer_loss(
+        student_hidden_states: tuple[torch.Tensor, ...],
+        teacher_hidden_states: tuple[torch.Tensor, ...],
+        layer_mapping: dict[int, int],
+    ) -> torch.Tensor:
+        """
+        Loss on intermediate layer representations.
+
+        Args:
+            student_hidden_states: Tuple of student hidden states per layer.
+            teacher_hidden_states: Tuple of teacher hidden states per layer.
+            layer_mapping: Maps student layer indices to teacher layer indices.
+
+        Returns:
+            Intermediate layer loss value.
+        """
+        total_loss = 0.0
+        num_layers = 0
+
+        for student_idx, teacher_idx in layer_mapping.items():
+            if student_idx < len(student_hidden_states) and teacher_idx < len(teacher_hidden_states):
+                student_hidden = student_hidden_states[student_idx]
+                teacher_hidden = teacher_hidden_states[teacher_idx]
+
+                # Handle hidden-size mismatch between student and teacher
+                if student_hidden.size(-1) != teacher_hidden.size(-1):
+                    # Truncate the wider hidden state to the smaller dimension
+                    if student_hidden.size(-1) < teacher_hidden.size(-1):
+                        teacher_hidden = teacher_hidden[..., :student_hidden.size(-1)]
+                    else:
+                        student_hidden = student_hidden[..., :teacher_hidden.size(-1)]
+
+                total_loss += F.mse_loss(student_hidden, teacher_hidden)
+                num_layers += 1
+
+        return total_loss / max(num_layers, 1)
+
+    @staticmethod
+    def attention_transfer_loss(
+        student_attentions: tuple[torch.Tensor, ...],
+        teacher_attentions: tuple[torch.Tensor, ...],
+        layer_mapping: dict[int, int] | None = None,
+    ) -> torch.Tensor:
+        """
+        Loss on attention patterns.
+
+        Args:
+            student_attentions: Tuple of student attention weights per layer.
+            teacher_attentions: Tuple of teacher attention weights per layer.
+            layer_mapping: Maps student layer indices to teacher layer indices.
+
+        Returns:
+            Attention transfer loss value.
+        """
+        if layer_mapping is None:
+            # Default: map corresponding layers
+            min_layers = min(len(student_attentions), len(teacher_attentions))
+            layer_mapping = {i: i for i in range(min_layers)}
+
+        total_loss = 0.0
+        num_layers = 0
+
+        for student_idx, teacher_idx in layer_mapping.items():
+            if student_idx < len(student_attentions) and teacher_idx < len(teacher_attentions):
+                student_attn = student_attentions[student_idx]
+                teacher_attn = teacher_attentions[teacher_idx]
+
+                # Attention matrices: [batch, heads, seq, seq]
+                # Average over heads if different number of heads
+                if student_attn.size(1) != teacher_attn.size(1):
+                    student_attn = student_attn.mean(dim=1, keepdim=True)
+                    teacher_attn = teacher_attn.mean(dim=1, keepdim=True)
+
+                total_loss += F.mse_loss(student_attn, teacher_attn)
+                num_layers += 1
+
+        return total_loss / max(num_layers, 1)
+
+
+class CombinedDistillationLoss(nn.Module):
+    """Combine multiple distillation losses with weights."""
+
+    def __init__(
+        self,
+        logit_weight: float = 1.0,
+        embedding_weight: float = 1.0,
+        intermediate_weight: float = 0.0,
+        attention_weight: float = 0.0,
+        temperature: float = 1.0,
+        layer_mapping: dict[int, int] | None = None,
+    ):
+        """
+        Initialize combined loss.
+
+        Args:
+            logit_weight: Weight for logit distillation loss.
+            embedding_weight: Weight for embedding MSE loss.
+            intermediate_weight: Weight for intermediate layer loss.
+            attention_weight: Weight for attention transfer loss.
+            temperature: Temperature for KL divergence.
+            layer_mapping: Layer mapping for intermediate/attention losses.
+        """
+        super().__init__()
+        self.logit_weight = logit_weight
+        self.embedding_weight = embedding_weight
+        self.intermediate_weight = intermediate_weight
+        self.attention_weight = attention_weight
+        self.temperature = temperature
+        self.layer_mapping = layer_mapping or {}
+
+    def forward(
+        self,
+        student_output: torch.Tensor | dict,
+        teacher_output: torch.Tensor | dict,
+    ) -> torch.Tensor:
+        """
+        Compute weighted combination of losses.
+
+        Args:
+            student_output: Student model output (tensor or dict with multiple outputs).
+            teacher_output: Teacher model output (tensor or dict with multiple outputs).
+
+        Returns:
+            Combined loss value.
+        """
+        total_loss = torch.tensor(0.0, device=self._get_device(student_output))
+
+        # Handle simple tensor outputs (embedding distillation)
+        if isinstance(student_output, torch.Tensor) and isinstance(teacher_output, torch.Tensor):
+            if self.embedding_weight > 0:
+                total_loss = total_loss + self.embedding_weight * DistillationLosses.mse_loss(
+                    student_output, teacher_output
+                )
+            return total_loss
+
+        # Handle dict outputs with multiple components
+        if isinstance(student_output, dict) and isinstance(teacher_output, dict):
+            # Logit loss
+            if self.logit_weight > 0 and "logits" in student_output and "logits" in teacher_output:
+                total_loss = total_loss + self.logit_weight * DistillationLosses.kl_divergence_loss(
+                    student_output["logits"],
+                    teacher_output["logits"],
+                    self.temperature,
+                )
+
+            # Embedding loss
+            if self.embedding_weight > 0:
+                if "last_hidden_state" in student_output and "last_hidden_state" in teacher_output:
+                    total_loss = total_loss + self.embedding_weight * DistillationLosses.mse_loss(
+                        student_output["last_hidden_state"],
+                        teacher_output["last_hidden_state"],
+                    )
+                elif "embeddings" in student_output and "embeddings" in teacher_output:
+                    total_loss = total_loss + self.embedding_weight * DistillationLosses.mse_loss(
+                        student_output["embeddings"],
+                        teacher_output["embeddings"],
+                    )
+
+            # Intermediate layer loss
+            if self.intermediate_weight > 0 and self.layer_mapping:
+                if "hidden_states" in student_output and "hidden_states" in teacher_output:
+                    total_loss = total_loss + self.intermediate_weight * DistillationLosses.intermediate_layer_loss(
+                        student_output["hidden_states"],
+                        teacher_output["hidden_states"],
+                        self.layer_mapping,
+                    )
+
+            # Attention loss
+            if self.attention_weight > 0:
+                if "attentions" in student_output and "attentions" in teacher_output:
+                    total_loss = total_loss + self.attention_weight * DistillationLosses.attention_transfer_loss(
+                        student_output["attentions"],
+                        teacher_output["attentions"],
+                        self.layer_mapping,
+                    )
+
+        return total_loss
+
+    def _get_device(self, output) -> torch.device:
+        """Get device from output."""
+        if isinstance(output, torch.Tensor):
+            return output.device
+        if isinstance(output, dict):
+            for v in output.values():
+                if isinstance(v, torch.Tensor):
+                    return v.device
+        return torch.device("cpu")
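A short end-to-end sketch of CombinedDistillationLoss with dict outputs. The tensor shapes and loss weights are illustrative stand-ins; real code would pass the dicts produced by the student and teacher forward passes (e.g. logits and last_hidden_state from transformers-style models).

    import torch

    from distil_trainer.distillation import CombinedDistillationLoss

    loss_fn = CombinedDistillationLoss(logit_weight=1.0, embedding_weight=0.5, temperature=2.0)

    # Stand-in outputs: batch of 4, sequence length 16, hidden size 384, 2 classes
    student_out = {
        "logits": torch.randn(4, 2, requires_grad=True),
        "last_hidden_state": torch.randn(4, 16, 384, requires_grad=True),
    }
    teacher_out = {
        "logits": torch.randn(4, 2),
        "last_hidden_state": torch.randn(4, 16, 384),
    }

    loss = loss_fn(student_out, teacher_out)  # temperature-scaled KL on logits + MSE on hidden states
    loss.backward()
    print(loss.item())

Note that the embedding MSE term requires the student and teacher hidden sizes to match; mismatched widths are only handled by the intermediate-layer loss, which truncates the wider hidden state according to layer_mapping.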