distil_trainer-0.1.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distil_trainer/__init__.py +31 -0
- distil_trainer/core/__init__.py +23 -0
- distil_trainer/core/callbacks.py +188 -0
- distil_trainer/core/config.py +358 -0
- distil_trainer/core/trainer.py +843 -0
- distil_trainer/data/__init__.py +19 -0
- distil_trainer/data/collators.py +240 -0
- distil_trainer/data/datamodule.py +191 -0
- distil_trainer/data/datasets.py +245 -0
- distil_trainer/data/loaders.py +163 -0
- distil_trainer/distillation/__init__.py +21 -0
- distil_trainer/distillation/losses.py +345 -0
- distil_trainer/distillation/multilingual.py +285 -0
- distil_trainer/distillation/strategies.py +211 -0
- distil_trainer/evaluation/__init__.py +19 -0
- distil_trainer/evaluation/benchmarks.py +86 -0
- distil_trainer/evaluation/evaluators.py +343 -0
- distil_trainer/evaluation/metrics.py +75 -0
- distil_trainer/models/__init__.py +5 -0
- distil_trainer/models/layers.py +115 -0
- distil_trainer/pruning/__init__.py +13 -0
- distil_trainer/pruning/combined_pruning.py +122 -0
- distil_trainer/pruning/depth_pruning.py +261 -0
- distil_trainer/pruning/importance.py +365 -0
- distil_trainer/pruning/width_pruning.py +480 -0
- distil_trainer-0.1.10.dist-info/METADATA +443 -0
- distil_trainer-0.1.10.dist-info/RECORD +29 -0
- distil_trainer-0.1.10.dist-info/WHEEL +4 -0
- distil_trainer-0.1.10.dist-info/licenses/LICENSE +21 -0
distil_trainer/pruning/depth_pruning.py
@@ -0,0 +1,261 @@
"""Depth pruning (layer reduction) for transformer models."""

from __future__ import annotations

import copy
import logging
from typing import Literal

import torch
from torch import nn
from transformers import PreTrainedModel

from sentence_transformers import SentenceTransformer

from distil_trainer.pruning.importance import ImportanceEstimator

logger = logging.getLogger(__name__)


class DepthPruner:
    """
    Handles layer removal from transformer models.

    Depth pruning removes entire transformer layers to reduce model size
    while preserving the overall architecture.

    Example:
        >>> pruner = DepthPruner(model)
        >>> importance = pruner.compute_layer_importance(calibration_data)
        >>> layers_to_keep = pruner.select_layers_to_keep(num_layers=8)
        >>> pruned_model = pruner.prune(layers_to_keep=layers_to_keep)
    """

    def __init__(
        self,
        model: SentenceTransformer | PreTrainedModel,
        calibration_data: list[str] | None = None,
    ):
        """
        Initialize the DepthPruner.

        Args:
            model: The model to prune.
            calibration_data: Optional calibration data for importance estimation.
        """
        self.model = model
        self.calibration_data = calibration_data
        self._transformer = self._get_transformer_module()
        self._original_num_layers = self._get_num_layers()

    def _get_transformer_module(self) -> nn.Module:
        """Get the underlying transformer module."""
        if isinstance(self.model, SentenceTransformer):
            # SentenceTransformer wraps a transformer model
            for module in self.model.modules():
                if hasattr(module, "encoder") and hasattr(module.encoder, "layer"):
                    return module
                if hasattr(module, "layers"):
                    return module
            # Try the first module which is typically the transformer
            return self.model._first_module()
        else:
            return self.model

    def _get_encoder_layers(self) -> nn.ModuleList:
        """Get the encoder layers from the transformer."""
        transformer = self._transformer

        # Try different attribute names used by different models
        if hasattr(transformer, "encoder") and hasattr(transformer.encoder, "layer"):
            return transformer.encoder.layer
        if hasattr(transformer, "encoder") and hasattr(transformer.encoder, "layers"):
            return transformer.encoder.layers
        if hasattr(transformer, "layers"):
            return transformer.layers
        if hasattr(transformer, "auto_model"):
            auto_model = transformer.auto_model
            if hasattr(auto_model, "encoder") and hasattr(auto_model.encoder, "layer"):
                return auto_model.encoder.layer
            if hasattr(auto_model, "layers"):
                return auto_model.layers

        raise ValueError("Could not find encoder layers in model")

    def _get_num_layers(self) -> int:
        """Get the number of layers in the model."""
        return len(self._get_encoder_layers())

    def compute_layer_importance(
        self,
        calibration_data: list[str] | None = None,
        method: Literal["activation", "gradient", "cosine_similarity", "lm_loss"] = "activation",
        num_samples: int = 1024,
    ) -> dict[int, float]:
        """
        Compute importance scores for each layer.

        Args:
            calibration_data: Sentences to use for calibration.
            method: Importance estimation method.
            num_samples: Number of samples to use.

        Returns:
            Dictionary mapping layer index to importance score.
        """
        data = calibration_data or self.calibration_data
        if data is None:
            raise ValueError("Calibration data required for importance estimation")

        estimator = ImportanceEstimator(self.model)

        if method == "activation":
            importance = estimator.activation_based_layer_importance(data[:num_samples])
        elif method == "cosine_similarity":
            importance = estimator.drop_layer_importance(data[:num_samples])
        else:
            importance = estimator.activation_based_layer_importance(data[:num_samples])

        return importance

    def select_layers_to_keep(
        self,
        num_layers: int | None = None,
        ratio: float | None = None,
        strategy: Literal["first", "last", "even", "importance"] = "importance",
        importance_scores: dict[int, float] | None = None,
    ) -> list[int]:
        """
        Select which layers to keep based on the strategy.

        Args:
            num_layers: Number of layers to keep.
            ratio: Ratio of layers to keep (alternative to num_layers).
            strategy: Layer selection strategy.
            importance_scores: Precomputed importance scores (for 'importance' strategy).

        Returns:
            List of layer indices to keep.
        """
        total_layers = self._original_num_layers

        if num_layers is None and ratio is not None:
            num_layers = int(total_layers * ratio)
        if num_layers is None:
            raise ValueError("Either num_layers or ratio must be specified")

        num_layers = min(num_layers, total_layers)

        if strategy == "first":
            return list(range(num_layers))

        elif strategy == "last":
            return list(range(total_layers - num_layers, total_layers))

        elif strategy == "even":
            # Evenly distribute layers
            if num_layers == 1:
                return [0]
            step = (total_layers - 1) / (num_layers - 1)
            return [int(round(i * step)) for i in range(num_layers)]

        elif strategy == "importance":
            if importance_scores is None:
                if self.calibration_data is None:
                    # Fall back to even distribution
                    logger.warning("No importance scores or calibration data, using even distribution")
                    return self.select_layers_to_keep(num_layers=num_layers, strategy="even")
                importance_scores = self.compute_layer_importance()

            # Sort layers by importance and keep the most important ones
            sorted_layers = sorted(importance_scores.items(), key=lambda x: x[1], reverse=True)
            layers_to_keep = sorted([layer_idx for layer_idx, _ in sorted_layers[:num_layers]])
            return layers_to_keep

        else:
            raise ValueError(f"Unknown strategy: {strategy}")

    def prune(
        self,
        layers_to_keep: list[int] | None = None,
        num_layers_to_keep: int | None = None,
        layers_to_drop: list[int] | None = None,
        layer_selection: Literal["first", "last", "even", "importance", "custom"] = "custom",
    ) -> SentenceTransformer | PreTrainedModel:
        """
        Create a new model with only the specified layers.

        Args:
            layers_to_keep: Explicit list of layer indices to keep.
            num_layers_to_keep: Number of layers to keep (alternative).
            layers_to_drop: Layers to drop (alternative).
            layer_selection: Strategy for selecting layers when using num_layers_to_keep.

        Returns:
            New model with reduced layers.
        """
        # Determine which layers to keep
        if layers_to_keep is not None:
            final_layers_to_keep = layers_to_keep
        elif layers_to_drop is not None:
            all_layers = set(range(self._original_num_layers))
            final_layers_to_keep = sorted(all_layers - set(layers_to_drop))
        elif num_layers_to_keep is not None:
            if layer_selection == "custom":
                layer_selection = "importance"
            final_layers_to_keep = self.select_layers_to_keep(
                num_layers=num_layers_to_keep,
                strategy=layer_selection,
            )
        else:
            raise ValueError("Must specify layers_to_keep, num_layers_to_keep, or layers_to_drop")

        logger.info(f"Keeping layers: {final_layers_to_keep}")
        logger.info(f"Reducing from {self._original_num_layers} to {len(final_layers_to_keep)} layers")

        # Create a deep copy of the model
        pruned_model = copy.deepcopy(self.model)

        # Get the encoder layers
        if isinstance(pruned_model, SentenceTransformer):
            transformer = pruned_model._first_module()
            if hasattr(transformer, "auto_model"):
                auto_model = transformer.auto_model
            else:
                auto_model = transformer
        else:
            auto_model = pruned_model

        # Find and replace the layers
        encoder_layers = None
        if hasattr(auto_model, "encoder") and hasattr(auto_model.encoder, "layer"):
            encoder_layers = auto_model.encoder.layer
            parent = auto_model.encoder
            attr_name = "layer"
        elif hasattr(auto_model, "encoder") and hasattr(auto_model.encoder, "layers"):
            encoder_layers = auto_model.encoder.layers
            parent = auto_model.encoder
            attr_name = "layers"
        elif hasattr(auto_model, "layers"):
            encoder_layers = auto_model.layers
            parent = auto_model
            attr_name = "layers"

        if encoder_layers is None:
            raise ValueError("Could not find encoder layers in model")

        # Create new layer list with only kept layers
        new_layers = nn.ModuleList([encoder_layers[i] for i in final_layers_to_keep])
        setattr(parent, attr_name, new_layers)

        # Update config if available
        if hasattr(auto_model, "config"):
            auto_model.config.num_hidden_layers = len(final_layers_to_keep)

        # Log statistics
        original_params = sum(p.numel() for p in self.model.parameters())
        pruned_params = sum(p.numel() for p in pruned_model.parameters())
        reduction = 1 - (pruned_params / original_params)
        logger.info(f"Parameter reduction: {original_params:,} -> {pruned_params:,} ({reduction:.1%})")

        return pruned_model
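The class above reads as a three-step workflow: score the layers on calibration sentences, choose which to keep, then materialize a smaller copy of the model. A minimal usage sketch follows; the model name and calibration sentences are placeholders (not part of the package), and the import path simply mirrors the file layout listed above.

from sentence_transformers import SentenceTransformer

from distil_trainer.pruning.depth_pruning import DepthPruner

# Placeholder model and calibration data -- substitute your own.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
calibration = ["A short example sentence.", "Another sentence for calibration."]

pruner = DepthPruner(model, calibration_data=calibration)

# Score layers by activation magnitude on the calibration sentences.
importance = pruner.compute_layer_importance(method="activation")

# Keep the four highest-scoring layers (indices come back in ascending order).
layers_to_keep = pruner.select_layers_to_keep(
    num_layers=4,
    strategy="importance",
    importance_scores=importance,
)

# Build a deep-copied model that contains only those layers.
pruned_model = pruner.prune(layers_to_keep=layers_to_keep)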
distil_trainer/pruning/importance.py
@@ -0,0 +1,365 @@
"""Importance estimation for pruning."""

from __future__ import annotations

import logging
from typing import Literal

import torch
from torch import nn
from tqdm import tqdm
from transformers import PreTrainedModel

from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class ImportanceEstimator:
    """
    Methods for estimating component importance.

    Supports importance estimation for:
    - Layers (for depth pruning)
    - Hidden dimensions (for width pruning)
    - Intermediate dimensions (for MLP pruning)
    - Attention heads (for head pruning)

    Example:
        >>> estimator = ImportanceEstimator(model)
        >>> layer_importance = estimator.activation_based_layer_importance(sentences)
        >>> hidden_importance = estimator.hidden_dimension_importance(sentences)
    """

    def __init__(self, model: SentenceTransformer | PreTrainedModel):
        """
        Initialize the ImportanceEstimator.

        Args:
            model: The model to estimate importance for.
        """
        self.model = model
        self.device = next(model.parameters()).device

    def activation_based_layer_importance(
        self,
        sentences: list[str],
        batch_size: int = 32,
    ) -> dict[int, float]:
        """
        Estimate layer importance based on activation magnitudes.

        Higher activation magnitude suggests higher importance.

        Args:
            sentences: Sentences to use for estimation.
            batch_size: Batch size for inference.

        Returns:
            Dictionary mapping layer index to importance score.
        """
        self.model.eval()

        # Get the transformer model
        if isinstance(self.model, SentenceTransformer):
            transformer = self.model._first_module()
            if hasattr(transformer, "auto_model"):
                auto_model = transformer.auto_model
            else:
                auto_model = transformer
            tokenizer = self.model.tokenizer
        else:
            auto_model = self.model
            tokenizer = None

        # Get encoder layers
        encoder_layers = self._get_encoder_layers(auto_model)
        num_layers = len(encoder_layers)

        # Track activations per layer
        layer_activations = {i: 0.0 for i in range(num_layers)}
        num_samples = 0

        # Register hooks to capture activations
        handles = []
        activation_sums = {}

        def make_hook(layer_idx):
            def hook(module, input, output):
                # Handle different output formats
                if isinstance(output, tuple):
                    hidden_states = output[0]
                else:
                    hidden_states = output
                activation_sums[layer_idx] = activation_sums.get(layer_idx, 0) + hidden_states.abs().mean().item()
            return hook

        for i, layer in enumerate(encoder_layers):
            handle = layer.register_forward_hook(make_hook(i))
            handles.append(handle)

        try:
            with torch.no_grad():
                for i in range(0, len(sentences), batch_size):
                    batch_sentences = sentences[i:i + batch_size]

                    if isinstance(self.model, SentenceTransformer):
                        self.model.encode(batch_sentences, convert_to_tensor=True)
                    else:
                        if tokenizer:
                            inputs = tokenizer(
                                batch_sentences,
                                padding=True,
                                truncation=True,
                                return_tensors="pt",
                            ).to(self.device)
                            self.model(**inputs)

                    num_samples += 1

            # Average the activation sums
            for layer_idx in range(num_layers):
                layer_activations[layer_idx] = activation_sums.get(layer_idx, 0) / max(num_samples, 1)

        finally:
            for handle in handles:
                handle.remove()

        return layer_activations

    def drop_layer_importance(
        self,
        sentences: list[str],
        batch_size: int = 32,
    ) -> dict[int, float]:
        """
        Estimate layer importance by measuring embedding similarity when layer is dropped.

        Lower similarity when dropped = higher importance.

        Args:
            sentences: Sentences to use for estimation.
            batch_size: Batch size for inference.

        Returns:
            Dictionary mapping layer index to importance score.
        """
        if not isinstance(self.model, SentenceTransformer):
            raise ValueError("drop_layer_importance only works with SentenceTransformer models")

        # Get reference embeddings
        with torch.no_grad():
            reference_embeddings = self.model.encode(
                sentences[:min(len(sentences), 500)],
                convert_to_tensor=True,
                show_progress_bar=False,
            )

        # Get transformer model
        transformer = self.model._first_module()
        if hasattr(transformer, "auto_model"):
            auto_model = transformer.auto_model
        else:
            auto_model = transformer

        encoder_layers = self._get_encoder_layers(auto_model)
        num_layers = len(encoder_layers)

        importance = {}

        for layer_idx in range(num_layers):
            # Temporarily remove the layer
            original_layer = encoder_layers[layer_idx]

            # Replace with identity
            encoder_layers[layer_idx] = nn.Identity()

            try:
                with torch.no_grad():
                    dropped_embeddings = self.model.encode(
                        sentences[:min(len(sentences), 500)],
                        convert_to_tensor=True,
                        show_progress_bar=False,
                    )

                # Compute cosine similarity
                similarity = torch.nn.functional.cosine_similarity(
                    reference_embeddings, dropped_embeddings, dim=1
                ).mean().item()

                # Lower similarity = higher importance
                importance[layer_idx] = 1 - similarity

            finally:
                # Restore the layer
                encoder_layers[layer_idx] = original_layer

        return importance

    def hidden_dimension_importance(
        self,
        sentences: list[str],
        method: Literal["activation", "magnitude", "taylor"] = "activation",
        batch_size: int = 32,
    ) -> torch.Tensor:
        """
        Estimate importance of each hidden dimension.

        Args:
            sentences: Sentences to use for estimation.
            method: Importance estimation method.
            batch_size: Batch size for inference.

        Returns:
            Tensor of importance scores for each hidden dimension.
        """
        self.model.eval()

        if isinstance(self.model, SentenceTransformer):
            transformer = self.model._first_module()
            if hasattr(transformer, "auto_model"):
                auto_model = transformer.auto_model
            else:
                auto_model = transformer
        else:
            auto_model = self.model

        hidden_size = auto_model.config.hidden_size
        importance = torch.zeros(hidden_size, device=self.device)

        if method == "activation":
            # Track output activations
            num_samples = 0

            with torch.no_grad():
                for i in range(0, len(sentences), batch_size):
                    batch_sentences = sentences[i:i + batch_size]

                    if isinstance(self.model, SentenceTransformer):
                        embeddings = self.model.encode(
                            batch_sentences,
                            convert_to_tensor=True,
                            show_progress_bar=False,
                        )
                        importance += embeddings.abs().mean(dim=0)
                    else:
                        # For regular models, would need tokenizer
                        pass

                    num_samples += 1

            importance = importance / max(num_samples, 1)

        elif method == "magnitude":
            # Use weight magnitudes from final layer
            encoder_layers = self._get_encoder_layers(auto_model)
            last_layer = encoder_layers[-1]

            # Sum up weight magnitudes
            for name, param in last_layer.named_parameters():
                if param.dim() >= 2:
                    # Aggregate across the hidden dimension
                    importance += param.abs().sum(dim=0)[:hidden_size] if param.size(-1) >= hidden_size else torch.zeros(hidden_size, device=self.device)

        return importance

    def intermediate_dimension_importance(
        self,
        sentences: list[str],
        method: Literal["activation", "magnitude"] = "activation",
        batch_size: int = 32,
    ) -> torch.Tensor:
        """
        Estimate importance of each intermediate (MLP) dimension.

        Args:
            sentences: Sentences to use for estimation.
            method: Importance estimation method.
            batch_size: Batch size for inference.

        Returns:
            Tensor of importance scores for each intermediate dimension.
        """
        self.model.eval()

        if isinstance(self.model, SentenceTransformer):
            transformer = self.model._first_module()
            if hasattr(transformer, "auto_model"):
                auto_model = transformer.auto_model
            else:
                auto_model = transformer
        else:
            auto_model = self.model

        intermediate_size = auto_model.config.intermediate_size
        importance = torch.zeros(intermediate_size, device=self.device)

        if method == "magnitude":
            encoder_layers = self._get_encoder_layers(auto_model)

            for layer in encoder_layers:
                if hasattr(layer, "intermediate") and hasattr(layer.intermediate, "dense"):
                    weight = layer.intermediate.dense.weight
                    importance += weight.abs().mean(dim=1)

            importance = importance / len(encoder_layers)

        return importance

    def attention_head_importance(
        self,
        sentences: list[str],
        method: Literal["activation", "attention_entropy"] = "activation",
        batch_size: int = 32,
    ) -> torch.Tensor:
        """
        Estimate importance of each attention head.

        Args:
            sentences: Sentences to use for estimation.
            method: Importance estimation method.
            batch_size: Batch size for inference.

        Returns:
            Tensor of importance scores for each attention head.
        """
        self.model.eval()

        if isinstance(self.model, SentenceTransformer):
            transformer = self.model._first_module()
            if hasattr(transformer, "auto_model"):
                auto_model = transformer.auto_model
            else:
                auto_model = transformer
        else:
            auto_model = self.model

        num_heads = auto_model.config.num_attention_heads
        importance = torch.zeros(num_heads, device=self.device)

        if method == "activation":
            encoder_layers = self._get_encoder_layers(auto_model)
            head_dim = auto_model.config.hidden_size // num_heads

            for layer in encoder_layers:
                if hasattr(layer, "attention") and hasattr(layer.attention, "self"):
                    self_attn = layer.attention.self
                    if hasattr(self_attn, "query"):
                        weight = self_attn.query.weight
                        # Reshape to [num_heads, head_dim, hidden_size]
                        weight_per_head = weight.view(num_heads, head_dim, -1)
                        importance += weight_per_head.abs().mean(dim=(1, 2))

            importance = importance / len(encoder_layers)

        return importance

    def _get_encoder_layers(self, model) -> nn.ModuleList:
        """Get encoder layers from the model."""
        if hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
            return model.encoder.layer
        if hasattr(model, "encoder") and hasattr(model.encoder, "layers"):
            return model.encoder.layers
        if hasattr(model, "layers"):
            return model.layers
        raise ValueError("Could not find encoder layers")
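The estimator can also be exercised on its own, independent of DepthPruner, to inspect layer-, dimension-, and head-level scores. A minimal sketch follows; the model name and probe sentences are again placeholders, and only methods defined in the file above are used.

from sentence_transformers import SentenceTransformer

from distil_trainer.pruning.importance import ImportanceEstimator

# Placeholder model and probe sentences -- substitute your own.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
sentences = ["A short example sentence.", "Another probe sentence."]

estimator = ImportanceEstimator(model)

# Per-layer scores: dict[int, float], higher means more important.
layer_scores = estimator.activation_based_layer_importance(sentences)

# Per-hidden-dimension scores: tensor of shape [hidden_size].
hidden_scores = estimator.hidden_dimension_importance(sentences, method="activation")

# Per-attention-head scores: tensor of shape [num_attention_heads].
head_scores = estimator.attention_head_importance(sentences, method="activation")

print(sorted(layer_scores, key=layer_scores.get))  # layer indices, least to most important
print(hidden_scores.shape, head_scores.shape)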