distil-trainer 0.1.10 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distil_trainer/__init__.py +31 -0
- distil_trainer/core/__init__.py +23 -0
- distil_trainer/core/callbacks.py +188 -0
- distil_trainer/core/config.py +358 -0
- distil_trainer/core/trainer.py +843 -0
- distil_trainer/data/__init__.py +19 -0
- distil_trainer/data/collators.py +240 -0
- distil_trainer/data/datamodule.py +191 -0
- distil_trainer/data/datasets.py +245 -0
- distil_trainer/data/loaders.py +163 -0
- distil_trainer/distillation/__init__.py +21 -0
- distil_trainer/distillation/losses.py +345 -0
- distil_trainer/distillation/multilingual.py +285 -0
- distil_trainer/distillation/strategies.py +211 -0
- distil_trainer/evaluation/__init__.py +19 -0
- distil_trainer/evaluation/benchmarks.py +86 -0
- distil_trainer/evaluation/evaluators.py +343 -0
- distil_trainer/evaluation/metrics.py +75 -0
- distil_trainer/models/__init__.py +5 -0
- distil_trainer/models/layers.py +115 -0
- distil_trainer/pruning/__init__.py +13 -0
- distil_trainer/pruning/combined_pruning.py +122 -0
- distil_trainer/pruning/depth_pruning.py +261 -0
- distil_trainer/pruning/importance.py +365 -0
- distil_trainer/pruning/width_pruning.py +480 -0
- distil_trainer-0.1.10.dist-info/METADATA +443 -0
- distil_trainer-0.1.10.dist-info/RECORD +29 -0
- distil_trainer-0.1.10.dist-info/WHEEL +4 -0
- distil_trainer-0.1.10.dist-info/licenses/LICENSE +21 -0

distil_trainer/data/loaders.py

@@ -0,0 +1,163 @@
"""Built-in dataset loaders."""

from __future__ import annotations

import logging
from typing import Any

from datasets import Dataset, load_dataset

logger = logging.getLogger(__name__)


class DatasetLoaders:
    """Built-in loaders for common datasets."""

    @staticmethod
    def load_allnli(split: str = "train", config: str = "pair-score") -> Dataset:
        """
        Load AllNLI dataset for sentence distillation.

        Args:
            split: Dataset split ("train", "dev", "test").
            config: Configuration name.

        Returns:
            Dataset with sentence pairs.
        """
        logger.info(f"Loading AllNLI dataset ({split})")
        return load_dataset("sentence-transformers/all-nli", config, split=split)

    @staticmethod
    def load_wikipedia_sentences(
        language: str = "en",
        max_samples: int | None = None,
    ) -> Dataset:
        """
        Load Wikipedia sentences dataset.

        Args:
            language: Language code (only "en" currently supported).
            max_samples: Maximum number of samples.

        Returns:
            Dataset with sentences.
        """
        logger.info("Loading Wikipedia sentences dataset")
        dataset = load_dataset(
            "sentence-transformers/wikipedia-en-sentences",
            split="train",
        )

        if max_samples is not None and len(dataset) > max_samples:
            dataset = dataset.select(range(max_samples))

        return dataset

    @staticmethod
    def load_stsb(split: str = "validation") -> Dataset:
        """
        Load STS Benchmark for evaluation.

        Args:
            split: Dataset split ("train", "validation", "test").

        Returns:
            Dataset with sentence pairs and similarity scores.
        """
        logger.info(f"Loading STS Benchmark ({split})")
        return load_dataset("sentence-transformers/stsb", split=split)

    @staticmethod
    def load_parallel_sentences(
        source_lang: str = "en",
        target_lang: str = "de",
        dataset: str = "talks",
        split: str = "train",
        max_samples: int | None = None,
    ) -> Dataset:
        """
        Load parallel sentences for multilingual training.

        Args:
            source_lang: Source language code.
            target_lang: Target language code.
            dataset: Dataset name (talks, europarl, tatoeba, etc.).
            split: Dataset split.
            max_samples: Maximum number of samples.

        Returns:
            Dataset with parallel sentences.
        """
        dataset_name = f"sentence-transformers/parallel-sentences-{dataset}"
        subset = f"{source_lang}-{target_lang}"

        logger.info(f"Loading parallel sentences: {dataset_name}/{subset}")

        try:
            data = load_dataset(dataset_name, subset, split=split)
        except Exception:
            # Try reversed language pair
            subset = f"{target_lang}-{source_lang}"
            data = load_dataset(dataset_name, subset, split=split)

        if max_samples is not None and len(data) > max_samples:
            data = data.select(range(max_samples))

        return data

    @staticmethod
    def load_specter() -> tuple[Dataset, Dataset, Dataset]:
        """
        Load Specter triplet dataset for retrieval training.

        Returns:
            Tuple of (train, validation, test) datasets.
        """
        logger.info("Loading Specter dataset")
        train = load_dataset("allenai/specter", split="train")
        val = load_dataset("allenai/specter", split="validation")
        test = load_dataset("allenai/specter", split="test")
        return train, val, test

    @staticmethod
    def load_bespoke_stratos(max_samples: int | None = None) -> Dataset:
        """
        Load Bespoke-Stratos reasoning dataset.

        Args:
            max_samples: Maximum number of samples.

        Returns:
            Dataset with reasoning chains.
        """
        logger.info("Loading Bespoke-Stratos dataset")
        dataset = load_dataset("bespokelabs/Bespoke-Stratos-17k", split="train")

        if max_samples is not None and len(dataset) > max_samples:
            dataset = dataset.select(range(max_samples))

        return dataset

    @staticmethod
    def load_msmarco(
        split: str = "train",
        max_samples: int | None = None,
    ) -> Dataset:
        """
        Load MS MARCO passage ranking dataset.

        Args:
            split: Dataset split.
            max_samples: Maximum number of samples.

        Returns:
            Dataset with queries and passages.
        """
        logger.info(f"Loading MS MARCO dataset ({split})")
        dataset = load_dataset("ms_marco", "v2.1", split=split)

        if max_samples is not None and len(dataset) > max_samples:
            dataset = dataset.select(range(max_samples))

        return dataset
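
A minimal usage sketch for these loaders (not part of the package diff; it assumes the datasets dependency is installed, the Hugging Face Hub is reachable, and that DatasetLoaders is imported directly from the module above):

from distil_trainer.data.loaders import DatasetLoaders

# Sentence pairs for embedding distillation
allnli = DatasetLoaders.load_allnli(split="train", config="pair-score")

# Cap the sample count to keep experiments small
wiki = DatasetLoaders.load_wikipedia_sentences(max_samples=100_000)

# Parallel en-de sentences; the loader falls back to the de-en subset if needed
parallel = DatasetLoaders.load_parallel_sentences(
    source_lang="en", target_lang="de", dataset="talks", max_samples=50_000
)

print(len(allnli), len(wiki), len(parallel))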

distil_trainer/distillation/__init__.py

@@ -0,0 +1,21 @@
"""Distillation module for knowledge transfer."""

from distil_trainer.distillation.losses import (
    DistillationLosses,
    CombinedDistillationLoss,
)
from distil_trainer.distillation.strategies import (
    EmbeddingDistillationStrategy,
    LogitDistillationStrategy,
)
from distil_trainer.distillation.multilingual import (
    MultilingualDistillationStrategy,
)

__all__ = [
    "DistillationLosses",
    "CombinedDistillationLoss",
    "EmbeddingDistillationStrategy",
    "LogitDistillationStrategy",
    "MultilingualDistillationStrategy",
]
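
With these re-exports, callers can pull the distillation API from the subpackage root instead of the individual modules, for example:

from distil_trainer.distillation import CombinedDistillationLoss, DistillationLosses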

distil_trainer/distillation/losses.py

@@ -0,0 +1,345 @@
"""Distillation loss functions."""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable


class DistillationLosses:
    """Collection of distillation loss functions."""

    @staticmethod
    def mse_loss(
        student_output: torch.Tensor,
        teacher_output: torch.Tensor,
    ) -> torch.Tensor:
        """
        Mean Squared Error loss between embeddings.

        Args:
            student_output: Student model embeddings [batch_size, dim].
            teacher_output: Teacher model embeddings [batch_size, dim].

        Returns:
            MSE loss value.
        """
        return F.mse_loss(student_output, teacher_output)

    @staticmethod
    def cosine_loss(
        student_output: torch.Tensor,
        teacher_output: torch.Tensor,
    ) -> torch.Tensor:
        """
        Cosine similarity loss (1 - cosine_similarity).

        Args:
            student_output: Student model embeddings [batch_size, dim].
            teacher_output: Teacher model embeddings [batch_size, dim].

        Returns:
            Cosine loss value.
        """
        similarity = F.cosine_similarity(student_output, teacher_output, dim=-1)
        return (1 - similarity).mean()

    @staticmethod
    def kl_divergence_loss(
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """
        KL divergence loss on softmax distributions.

        Args:
            student_logits: Student model logits.
            teacher_logits: Teacher model logits.
            temperature: Temperature for softmax (higher = softer).

        Returns:
            KL divergence loss value.
        """
        student_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        loss = F.kl_div(student_probs, teacher_probs, reduction="batchmean")
        return loss * (temperature ** 2)

    @staticmethod
    def ranking_loss(
        query_embeddings: torch.Tensor,
        positive_embeddings: torch.Tensor,
        negative_embeddings: torch.Tensor | None = None,
        margin: float = 0.5,
        in_batch_negatives: bool = True,
    ) -> torch.Tensor:
        """
        Ranking loss for retrieval models.

        Args:
            query_embeddings: Query embeddings [batch_size, dim].
            positive_embeddings: Positive document embeddings [batch_size, dim].
            negative_embeddings: Negative document embeddings [batch_size, num_neg, dim].
            margin: Margin for triplet loss.
            in_batch_negatives: Whether to use in-batch negatives.

        Returns:
            Ranking loss value.
        """
        batch_size = query_embeddings.size(0)

        # Positive scores
        pos_scores = F.cosine_similarity(query_embeddings, positive_embeddings, dim=-1)

        if in_batch_negatives:
            # Use other positives in batch as negatives
            # Similarity matrix: [batch_size, batch_size]
            sim_matrix = torch.mm(
                F.normalize(query_embeddings, p=2, dim=-1),
                F.normalize(positive_embeddings, p=2, dim=-1).t()
            )

            # Diagonal contains positive similarities
            # Off-diagonal contains negative similarities
            labels = torch.arange(batch_size, device=query_embeddings.device)

            # Cross-entropy loss with in-batch negatives
            loss = F.cross_entropy(sim_matrix / 0.05, labels)  # Temperature-scaled

        elif negative_embeddings is not None:
            # Use provided negatives
            num_negatives = negative_embeddings.size(1)

            # Compute negative scores
            neg_scores = torch.bmm(
                query_embeddings.unsqueeze(1),
                negative_embeddings.transpose(1, 2)
            ).squeeze(1)  # [batch_size, num_neg]

            # Triplet margin loss
            pos_scores_expanded = pos_scores.unsqueeze(1).expand(-1, num_negatives)
            loss = F.relu(margin - pos_scores_expanded + neg_scores).mean()

        else:
            raise ValueError("Either in_batch_negatives or negative_embeddings required")

        return loss

    @staticmethod
    def contrastive_loss(
        embeddings1: torch.Tensor,
        embeddings2: torch.Tensor,
        labels: torch.Tensor,
        margin: float = 0.5,
    ) -> torch.Tensor:
        """
        Contrastive loss for pairs.

        Args:
            embeddings1: First set of embeddings [batch_size, dim].
            embeddings2: Second set of embeddings [batch_size, dim].
            labels: Binary labels (1 = similar, 0 = dissimilar) [batch_size].
            margin: Margin for dissimilar pairs.

        Returns:
            Contrastive loss value.
        """
        distances = F.pairwise_distance(embeddings1, embeddings2)

        # Similar pairs: minimize distance
        # Dissimilar pairs: maximize distance (up to margin)
        loss = labels * distances.pow(2) + (1 - labels) * F.relu(margin - distances).pow(2)

        return loss.mean()

    @staticmethod
    def intermediate_layer_loss(
        student_hidden_states: tuple[torch.Tensor, ...],
        teacher_hidden_states: tuple[torch.Tensor, ...],
        layer_mapping: dict[int, int],
    ) -> torch.Tensor:
        """
        Loss on intermediate layer representations.

        Args:
            student_hidden_states: Tuple of student hidden states per layer.
            teacher_hidden_states: Tuple of teacher hidden states per layer.
            layer_mapping: Maps student layer indices to teacher layer indices.

        Returns:
            Intermediate layer loss value.
        """
        total_loss = 0.0
        num_layers = 0

        for student_idx, teacher_idx in layer_mapping.items():
            if student_idx < len(student_hidden_states) and teacher_idx < len(teacher_hidden_states):
                student_hidden = student_hidden_states[student_idx]
                teacher_hidden = teacher_hidden_states[teacher_idx]

                # Handle dimension mismatch with projection
                if student_hidden.size(-1) != teacher_hidden.size(-1):
                    # Simple mean pooling to match dimensions
                    if student_hidden.size(-1) < teacher_hidden.size(-1):
                        teacher_hidden = teacher_hidden[..., :student_hidden.size(-1)]
                    else:
                        student_hidden = student_hidden[..., :teacher_hidden.size(-1)]

                total_loss += F.mse_loss(student_hidden, teacher_hidden)
                num_layers += 1

        return total_loss / max(num_layers, 1)

    @staticmethod
    def attention_transfer_loss(
        student_attentions: tuple[torch.Tensor, ...],
        teacher_attentions: tuple[torch.Tensor, ...],
        layer_mapping: dict[int, int] | None = None,
    ) -> torch.Tensor:
        """
        Loss on attention patterns.

        Args:
            student_attentions: Tuple of student attention weights per layer.
            teacher_attentions: Tuple of teacher attention weights per layer.
            layer_mapping: Maps student layer indices to teacher layer indices.

        Returns:
            Attention transfer loss value.
        """
        if layer_mapping is None:
            # Default: map corresponding layers
            min_layers = min(len(student_attentions), len(teacher_attentions))
            layer_mapping = {i: i for i in range(min_layers)}

        total_loss = 0.0
        num_layers = 0

        for student_idx, teacher_idx in layer_mapping.items():
            if student_idx < len(student_attentions) and teacher_idx < len(teacher_attentions):
                student_attn = student_attentions[student_idx]
                teacher_attn = teacher_attentions[teacher_idx]

                # Attention matrices: [batch, heads, seq, seq]
                # Average over heads if different number of heads
                if student_attn.size(1) != teacher_attn.size(1):
                    student_attn = student_attn.mean(dim=1, keepdim=True)
                    teacher_attn = teacher_attn.mean(dim=1, keepdim=True)

                total_loss += F.mse_loss(student_attn, teacher_attn)
                num_layers += 1

        return total_loss / max(num_layers, 1)


class CombinedDistillationLoss(nn.Module):
    """Combine multiple distillation losses with weights."""

    def __init__(
        self,
        logit_weight: float = 1.0,
        embedding_weight: float = 1.0,
        intermediate_weight: float = 0.0,
        attention_weight: float = 0.0,
        temperature: float = 1.0,
        layer_mapping: dict[int, int] | None = None,
    ):
        """
        Initialize combined loss.

        Args:
            logit_weight: Weight for logit distillation loss.
            embedding_weight: Weight for embedding MSE loss.
            intermediate_weight: Weight for intermediate layer loss.
            attention_weight: Weight for attention transfer loss.
            temperature: Temperature for KL divergence.
            layer_mapping: Layer mapping for intermediate/attention losses.
        """
        super().__init__()
        self.logit_weight = logit_weight
        self.embedding_weight = embedding_weight
        self.intermediate_weight = intermediate_weight
        self.attention_weight = attention_weight
        self.temperature = temperature
        self.layer_mapping = layer_mapping or {}

    def forward(
        self,
        student_output: torch.Tensor | dict,
        teacher_output: torch.Tensor | dict,
    ) -> torch.Tensor:
        """
        Compute weighted combination of losses.

        Args:
            student_output: Student model output (tensor or dict with multiple outputs).
            teacher_output: Teacher model output (tensor or dict with multiple outputs).

        Returns:
            Combined loss value.
        """
        total_loss = torch.tensor(0.0, device=self._get_device(student_output))

        # Handle simple tensor outputs (embedding distillation)
        if isinstance(student_output, torch.Tensor) and isinstance(teacher_output, torch.Tensor):
            if self.embedding_weight > 0:
                total_loss = total_loss + self.embedding_weight * DistillationLosses.mse_loss(
                    student_output, teacher_output
                )
            return total_loss

        # Handle dict outputs with multiple components
        if isinstance(student_output, dict) and isinstance(teacher_output, dict):
            # Logit loss
            if self.logit_weight > 0 and "logits" in student_output and "logits" in teacher_output:
                total_loss = total_loss + self.logit_weight * DistillationLosses.kl_divergence_loss(
                    student_output["logits"],
                    teacher_output["logits"],
                    self.temperature,
                )

            # Embedding loss
            if self.embedding_weight > 0:
                if "last_hidden_state" in student_output and "last_hidden_state" in teacher_output:
                    total_loss = total_loss + self.embedding_weight * DistillationLosses.mse_loss(
                        student_output["last_hidden_state"],
                        teacher_output["last_hidden_state"],
                    )
                elif "embeddings" in student_output and "embeddings" in teacher_output:
                    total_loss = total_loss + self.embedding_weight * DistillationLosses.mse_loss(
                        student_output["embeddings"],
                        teacher_output["embeddings"],
                    )

            # Intermediate layer loss
            if self.intermediate_weight > 0 and self.layer_mapping:
                if "hidden_states" in student_output and "hidden_states" in teacher_output:
                    total_loss = total_loss + self.intermediate_weight * DistillationLosses.intermediate_layer_loss(
                        student_output["hidden_states"],
                        teacher_output["hidden_states"],
                        self.layer_mapping,
                    )

            # Attention loss
            if self.attention_weight > 0:
                if "attentions" in student_output and "attentions" in teacher_output:
                    total_loss = total_loss + self.attention_weight * DistillationLosses.attention_transfer_loss(
                        student_output["attentions"],
                        teacher_output["attentions"],
                        self.layer_mapping,
                    )

        return total_loss

    def _get_device(self, output) -> torch.device:
        """Get device from output."""
        if isinstance(output, torch.Tensor):
            return output.device
        if isinstance(output, dict):
            for v in output.values():
                if isinstance(v, torch.Tensor):
                    return v.device
        return torch.device("cpu")
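
A minimal sketch of wiring up the combined loss (not part of the package diff; shapes and weights are illustrative, and the import path follows the module layout above). With dict outputs it sums the temperature-scaled KL term on the logits, multiplied by temperature squared as in standard logit distillation, and the MSE term on the hidden states:

import torch

from distil_trainer.distillation.losses import CombinedDistillationLoss

criterion = CombinedDistillationLoss(
    logit_weight=1.0,      # KL divergence on "logits"
    embedding_weight=0.5,  # MSE on "last_hidden_state"
    temperature=2.0,
)

# Illustrative shapes: [batch, seq, vocab] logits and [batch, seq, hidden] states
batch, seq, hidden, vocab = 4, 16, 256, 1000
student_out = {
    "logits": torch.randn(batch, seq, vocab, requires_grad=True),
    "last_hidden_state": torch.randn(batch, seq, hidden, requires_grad=True),
}
teacher_out = {
    "logits": torch.randn(batch, seq, vocab),
    "last_hidden_state": torch.randn(batch, seq, hidden),
}

loss = criterion(student_out, teacher_out)
loss.backward()  # gradients flow back to the student outputs only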