distil-trainer 0.1.10__py3-none-any.whl

@@ -0,0 +1,261 @@
+ """Depth pruning (layer reduction) for transformer models."""
+
+ from __future__ import annotations
+
+ import copy
+ import logging
+ from typing import Literal
+
+ from torch import nn
+ from transformers import PreTrainedModel
+
+ from sentence_transformers import SentenceTransformer
+
+ from distil_trainer.pruning.importance import ImportanceEstimator
+
+ logger = logging.getLogger(__name__)
+
+
+ class DepthPruner:
+     """
+     Handles layer removal from transformer models.
+
+     Depth pruning removes entire transformer layers to reduce model size
+     while preserving the overall architecture.
+
+     Example:
+         >>> pruner = DepthPruner(model)
+         >>> importance = pruner.compute_layer_importance(calibration_data)
+         >>> layers_to_keep = pruner.select_layers_to_keep(num_layers=8)
+         >>> pruned_model = pruner.prune(layers_to_keep=layers_to_keep)
+     """
+
+     def __init__(
+         self,
+         model: SentenceTransformer | PreTrainedModel,
+         calibration_data: list[str] | None = None,
+     ):
+         """
+         Initialize the DepthPruner.
+
+         Args:
+             model: The model to prune.
+             calibration_data: Optional calibration data for importance estimation.
+         """
+         self.model = model
+         self.calibration_data = calibration_data
+         self._transformer = self._get_transformer_module()
+         self._original_num_layers = self._get_num_layers()
+
+     def _get_transformer_module(self) -> nn.Module:
+         """Get the underlying transformer module."""
+         if isinstance(self.model, SentenceTransformer):
+             # SentenceTransformer wraps a Hugging Face transformer model
+             for module in self.model.modules():
+                 if hasattr(module, "encoder") and hasattr(module.encoder, "layer"):
+                     return module
+                 if hasattr(module, "layers"):
+                     return module
+             # Fall back to the first module, which is typically the transformer
+             return self.model._first_module()
+         else:
+             return self.model
+
+     def _get_encoder_layers(self) -> nn.ModuleList:
+         """Get the encoder layers from the transformer."""
+         transformer = self._transformer
+
+         # Try the different attribute names used by different architectures
+         if hasattr(transformer, "encoder") and hasattr(transformer.encoder, "layer"):
+             return transformer.encoder.layer
+         if hasattr(transformer, "encoder") and hasattr(transformer.encoder, "layers"):
+             return transformer.encoder.layers
+         if hasattr(transformer, "layers"):
+             return transformer.layers
+         if hasattr(transformer, "auto_model"):
+             auto_model = transformer.auto_model
+             if hasattr(auto_model, "encoder") and hasattr(auto_model.encoder, "layer"):
+                 return auto_model.encoder.layer
+             if hasattr(auto_model, "layers"):
+                 return auto_model.layers
+
+         raise ValueError("Could not find encoder layers in model")
+
+     def _get_num_layers(self) -> int:
+         """Get the number of layers in the model."""
+         return len(self._get_encoder_layers())
+
+     def compute_layer_importance(
+         self,
+         calibration_data: list[str] | None = None,
+         method: Literal["activation", "gradient", "cosine_similarity", "lm_loss"] = "activation",
+         num_samples: int = 1024,
+     ) -> dict[int, float]:
+         """
+         Compute importance scores for each layer.
+
+         Args:
+             calibration_data: Sentences to use for calibration.
+             method: Importance estimation method. Only "activation" and
+                 "cosine_similarity" are currently implemented; other methods
+                 fall back to "activation" with a warning.
+             num_samples: Number of samples to use.
+
+         Returns:
+             Dictionary mapping layer index to importance score.
+         """
+         data = calibration_data or self.calibration_data
+         if data is None:
+             raise ValueError("Calibration data required for importance estimation")
+
+         estimator = ImportanceEstimator(self.model)
+
+         if method == "activation":
+             importance = estimator.activation_based_layer_importance(data[:num_samples])
+         elif method == "cosine_similarity":
+             importance = estimator.drop_layer_importance(data[:num_samples])
+         else:
+             logger.warning(f"Method '{method}' is not implemented; falling back to 'activation'")
+             importance = estimator.activation_based_layer_importance(data[:num_samples])
+
+         return importance
+
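A minimal usage sketch for the importance step. Illustrative only: the checkpoint name, the module path `distil_trainer.pruning.depth`, and the calibration sentences are assumptions, not taken from the package docs.

```python
from sentence_transformers import SentenceTransformer

from distil_trainer.pruning.depth import DepthPruner  # assumed module path

# Hypothetical calibration corpus; any list of representative sentences works.
calibration = ["The cat sat on the mat.", "Transformers encode text as vectors."] * 512

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
pruner = DepthPruner(model, calibration_data=calibration)

# Scores each encoder layer by mean activation magnitude.
importance = pruner.compute_layer_importance(method="activation", num_samples=1024)
print(sorted(importance.items(), key=lambda kv: kv[1], reverse=True))
```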
+     def select_layers_to_keep(
+         self,
+         num_layers: int | None = None,
+         ratio: float | None = None,
+         strategy: Literal["first", "last", "even", "importance"] = "importance",
+         importance_scores: dict[int, float] | None = None,
+     ) -> list[int]:
+         """
+         Select which layers to keep based on the strategy.
+
+         Args:
+             num_layers: Number of layers to keep.
+             ratio: Ratio of layers to keep (alternative to num_layers).
+             strategy: Layer selection strategy.
+             importance_scores: Precomputed importance scores (for the 'importance' strategy).
+
+         Returns:
+             Sorted list of layer indices to keep.
+         """
+         total_layers = self._original_num_layers
+
+         if num_layers is None and ratio is not None:
+             num_layers = int(total_layers * ratio)
+         if num_layers is None:
+             raise ValueError("Either num_layers or ratio must be specified")
+
+         num_layers = min(num_layers, total_layers)
+         if num_layers < 1:
+             raise ValueError("num_layers must be at least 1")
+
+         if strategy == "first":
+             return list(range(num_layers))
+
+         elif strategy == "last":
+             return list(range(total_layers - num_layers, total_layers))
+
+         elif strategy == "even":
+             # Distribute the kept layers evenly across the original depth
+             if num_layers == 1:
+                 return [0]
+             step = (total_layers - 1) / (num_layers - 1)
+             return [int(round(i * step)) for i in range(num_layers)]
+
+         elif strategy == "importance":
+             if importance_scores is None:
+                 if self.calibration_data is None:
+                     # Fall back to even distribution
+                     logger.warning("No importance scores or calibration data, using even distribution")
+                     return self.select_layers_to_keep(num_layers=num_layers, strategy="even")
+                 importance_scores = self.compute_layer_importance()
+
+             # Sort layers by importance and keep the most important ones
+             sorted_layers = sorted(importance_scores.items(), key=lambda x: x[1], reverse=True)
+             layers_to_keep = sorted([layer_idx for layer_idx, _ in sorted_layers[:num_layers]])
+             return layers_to_keep
+
+         else:
+             raise ValueError(f"Unknown strategy: {strategy}")
+
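To make the "even" strategy concrete, a worked check of the spacing arithmetic, assuming a 12-layer encoder reduced to 6 layers:

```python
total_layers, num_layers = 12, 6

# step = (12 - 1) / (6 - 1) = 2.2, so the kept indices round to [0, 2, 4, 7, 9, 11]
step = (total_layers - 1) / (num_layers - 1)
kept = [int(round(i * step)) for i in range(num_layers)]
assert kept == [0, 2, 4, 7, 9, 11]
```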
+     def prune(
+         self,
+         layers_to_keep: list[int] | None = None,
+         num_layers_to_keep: int | None = None,
+         layers_to_drop: list[int] | None = None,
+         layer_selection: Literal["first", "last", "even", "importance", "custom"] = "custom",
+     ) -> SentenceTransformer | PreTrainedModel:
+         """
+         Create a new model with only the specified layers.
+
+         Args:
+             layers_to_keep: Explicit list of layer indices to keep.
+             num_layers_to_keep: Number of layers to keep (alternative).
+             layers_to_drop: Layers to drop (alternative).
+             layer_selection: Strategy for selecting layers when using num_layers_to_keep.
+
+         Returns:
+             New model with reduced layers.
+         """
+         # Determine which layers to keep
+         if layers_to_keep is not None:
+             final_layers_to_keep = layers_to_keep
+         elif layers_to_drop is not None:
+             all_layers = set(range(self._original_num_layers))
+             final_layers_to_keep = sorted(all_layers - set(layers_to_drop))
+         elif num_layers_to_keep is not None:
+             if layer_selection == "custom":
+                 layer_selection = "importance"
+             final_layers_to_keep = self.select_layers_to_keep(
+                 num_layers=num_layers_to_keep,
+                 strategy=layer_selection,
+             )
+         else:
+             raise ValueError("Must specify layers_to_keep, num_layers_to_keep, or layers_to_drop")
+
+         logger.info(f"Keeping layers: {final_layers_to_keep}")
+         logger.info(f"Reducing from {self._original_num_layers} to {len(final_layers_to_keep)} layers")
+
+         # Create a deep copy of the model so the original is left untouched
+         pruned_model = copy.deepcopy(self.model)
+
+         # Get the encoder layers of the copy
+         if isinstance(pruned_model, SentenceTransformer):
+             transformer = pruned_model._first_module()
+             if hasattr(transformer, "auto_model"):
+                 auto_model = transformer.auto_model
+             else:
+                 auto_model = transformer
+         else:
+             auto_model = pruned_model
+
+         # Find and replace the layers
+         encoder_layers = None
+         if hasattr(auto_model, "encoder") and hasattr(auto_model.encoder, "layer"):
+             encoder_layers = auto_model.encoder.layer
+             parent = auto_model.encoder
+             attr_name = "layer"
+         elif hasattr(auto_model, "encoder") and hasattr(auto_model.encoder, "layers"):
+             encoder_layers = auto_model.encoder.layers
+             parent = auto_model.encoder
+             attr_name = "layers"
+         elif hasattr(auto_model, "layers"):
+             encoder_layers = auto_model.layers
+             parent = auto_model
+             attr_name = "layers"
+
+         if encoder_layers is None:
+             raise ValueError("Could not find encoder layers in model")
+
+         # Create a new layer list containing only the kept layers
+         new_layers = nn.ModuleList([encoder_layers[i] for i in final_layers_to_keep])
+         setattr(parent, attr_name, new_layers)
+
+         # Update the config only if it actually tracks the layer count
+         if hasattr(auto_model, "config") and hasattr(auto_model.config, "num_hidden_layers"):
+             auto_model.config.num_hidden_layers = len(final_layers_to_keep)
+
+         # Log statistics
+         original_params = sum(p.numel() for p in self.model.parameters())
+         pruned_params = sum(p.numel() for p in pruned_model.parameters())
+         reduction = 1 - (pruned_params / original_params)
+         logger.info(f"Parameter reduction: {original_params:,} -> {pruned_params:,} ({reduction:.1%})")
+
+         return pruned_model
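Putting the pieces together, a hedged end-to-end sketch, again assuming the `distil_trainer.pruning.depth` module path and a 6-layer MiniLM checkpoint:

```python
from sentence_transformers import SentenceTransformer

from distil_trainer.pruning.depth import DepthPruner  # assumed module path

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # 6 encoder layers
pruner = DepthPruner(model)

# Drop the two middle layers explicitly; no calibration data is needed
# when the layers are chosen by hand.
pruned = pruner.prune(layers_to_drop=[2, 3])
print(pruned.encode(["hello world"]).shape)
```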
@@ -0,0 +1,365 @@
+ """Importance estimation for pruning."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Literal
+
+ import torch
+ from torch import nn
+ from transformers import PreTrainedModel
+
+ from sentence_transformers import SentenceTransformer
+
+ logger = logging.getLogger(__name__)
+
+
+ class ImportanceEstimator:
+     """
+     Methods for estimating component importance.
+
+     Supports importance estimation for:
+     - Layers (for depth pruning)
+     - Hidden dimensions (for width pruning)
+     - Intermediate dimensions (for MLP pruning)
+     - Attention heads (for head pruning)
+
+     Example:
+         >>> estimator = ImportanceEstimator(model)
+         >>> layer_importance = estimator.activation_based_layer_importance(sentences)
+         >>> hidden_importance = estimator.hidden_dimension_importance(sentences)
+     """
+
+     def __init__(self, model: SentenceTransformer | PreTrainedModel):
+         """
+         Initialize the ImportanceEstimator.
+
+         Args:
+             model: The model to estimate importance for.
+         """
+         self.model = model
+         self.device = next(model.parameters()).device
+
+     def activation_based_layer_importance(
+         self,
+         sentences: list[str],
+         batch_size: int = 32,
+     ) -> dict[int, float]:
+         """
+         Estimate layer importance based on activation magnitudes.
+
+         Higher activation magnitude suggests higher importance.
+
+         Args:
+             sentences: Sentences to use for estimation.
+             batch_size: Batch size for inference.
+
+         Returns:
+             Dictionary mapping layer index to importance score.
+         """
+         self.model.eval()
+
+         # Get the underlying transformer model and tokenizer
+         if isinstance(self.model, SentenceTransformer):
+             transformer = self.model._first_module()
+             if hasattr(transformer, "auto_model"):
+                 auto_model = transformer.auto_model
+             else:
+                 auto_model = transformer
+             tokenizer = self.model.tokenizer
+         else:
+             auto_model = self.model
+             tokenizer = None
+
+         # Get encoder layers
+         encoder_layers = self._get_encoder_layers(auto_model)
+         num_layers = len(encoder_layers)
+
+         # Track per-layer activation magnitudes
+         layer_activations = {i: 0.0 for i in range(num_layers)}
+         num_batches = 0
+
+         # Register hooks to capture activations
+         handles = []
+         activation_sums = {}
+
+         def make_hook(layer_idx):
+             def hook(module, input, output):
+                 # Layers may return a tuple (hidden_states, ...) or a bare tensor
+                 if isinstance(output, tuple):
+                     hidden_states = output[0]
+                 else:
+                     hidden_states = output
+                 activation_sums[layer_idx] = activation_sums.get(layer_idx, 0) + hidden_states.abs().mean().item()
+             return hook
+
+         for i, layer in enumerate(encoder_layers):
+             handle = layer.register_forward_hook(make_hook(i))
+             handles.append(handle)
+
+         try:
+             with torch.no_grad():
+                 for i in range(0, len(sentences), batch_size):
+                     batch_sentences = sentences[i:i + batch_size]
+
+                     if isinstance(self.model, SentenceTransformer):
+                         self.model.encode(batch_sentences, convert_to_tensor=True, show_progress_bar=False)
+                     elif tokenizer is not None:
+                         inputs = tokenizer(
+                             batch_sentences,
+                             padding=True,
+                             truncation=True,
+                             return_tensors="pt",
+                         ).to(self.device)
+                         self.model(**inputs)
+                     else:
+                         raise ValueError("A tokenizer is required to run a plain PreTrainedModel")
+
+                     num_batches += 1
+
+             # Average the per-batch activation magnitudes
+             for layer_idx in range(num_layers):
+                 layer_activations[layer_idx] = activation_sums.get(layer_idx, 0) / max(num_batches, 1)
+
+         finally:
+             for handle in handles:
+                 handle.remove()
+
+         return layer_activations
+
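The hook bookkeeping above is plain PyTorch; the same pattern on a toy module, self-contained and independent of this package:

```python
import torch
from torch import nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
sums = {}

def make_hook(idx):
    def hook(module, inputs, output):
        # Accumulate the mean absolute activation per layer, as above.
        sums[idx] = sums.get(idx, 0.0) + output.abs().mean().item()
    return hook

handles = [layer.register_forward_hook(make_hook(i)) for i, layer in enumerate(layers)]
try:
    x = torch.randn(4, 8)
    for layer in layers:
        x = layer(x)
finally:
    for h in handles:
        h.remove()
print(sums)
```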
+     def drop_layer_importance(
+         self,
+         sentences: list[str],
+         batch_size: int = 32,
+     ) -> dict[int, float]:
+         """
+         Estimate layer importance by measuring embedding similarity when a layer is dropped.
+
+         Lower similarity when dropped = higher importance.
+
+         Args:
+             sentences: Sentences to use for estimation.
+             batch_size: Batch size for inference.
+
+         Returns:
+             Dictionary mapping layer index to importance score.
+         """
+         if not isinstance(self.model, SentenceTransformer):
+             raise ValueError("drop_layer_importance only works with SentenceTransformer models")
+
+         # Cap the probe set; slicing clamps automatically for shorter lists
+         eval_sentences = sentences[:500]
+
+         # Get reference embeddings
+         with torch.no_grad():
+             reference_embeddings = self.model.encode(
+                 eval_sentences,
+                 batch_size=batch_size,
+                 convert_to_tensor=True,
+                 show_progress_bar=False,
+             )
+
+         # Get transformer model
+         transformer = self.model._first_module()
+         if hasattr(transformer, "auto_model"):
+             auto_model = transformer.auto_model
+         else:
+             auto_model = transformer
+
+         encoder_layers = self._get_encoder_layers(auto_model)
+         num_layers = len(encoder_layers)
+
+         class _SkipLayer(nn.Module):
+             """Skips the layer while preserving the tuple output format HF encoders expect."""
+
+             def forward(self, hidden_states, *args, **kwargs):
+                 return (hidden_states,)
+
+         importance = {}
+
+         for layer_idx in range(num_layers):
+             # Temporarily replace the layer with a pass-through module.
+             # nn.Identity would break here: encoder layers are called with extra
+             # arguments (e.g. attention masks) and are expected to return tuples.
+             original_layer = encoder_layers[layer_idx]
+             encoder_layers[layer_idx] = _SkipLayer()
+
+             try:
+                 with torch.no_grad():
+                     dropped_embeddings = self.model.encode(
+                         eval_sentences,
+                         batch_size=batch_size,
+                         convert_to_tensor=True,
+                         show_progress_bar=False,
+                     )
+
+                 # Compute cosine similarity against the reference embeddings
+                 similarity = torch.nn.functional.cosine_similarity(
+                     reference_embeddings, dropped_embeddings, dim=1
+                 ).mean().item()
+
+                 # Lower similarity = higher importance
+                 importance[layer_idx] = 1 - similarity
+
+             finally:
+                 # Restore the original layer
+                 encoder_layers[layer_idx] = original_layer
+
+         return importance
+
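A usage sketch for the drop-layer probe. The model name is a placeholder; note that the probe re-encodes up to 500 sentences once per layer, so it is slower than the activation estimator but more direct:

```python
from sentence_transformers import SentenceTransformer

from distil_trainer.pruning.importance import ImportanceEstimator

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
estimator = ImportanceEstimator(model)

sentences = ["Example calibration sentence."] * 64  # placeholder corpus
scores = estimator.drop_layer_importance(sentences)

# Layers whose removal changes the embeddings least are pruning candidates.
for idx, score in sorted(scores.items(), key=lambda kv: kv[1]):
    print(f"layer {idx}: importance {score:.4f}")
```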
+     def hidden_dimension_importance(
+         self,
+         sentences: list[str],
+         method: Literal["activation", "magnitude", "taylor"] = "activation",
+         batch_size: int = 32,
+     ) -> torch.Tensor:
+         """
+         Estimate importance of each hidden dimension.
+
+         Args:
+             sentences: Sentences to use for estimation.
+             method: Importance estimation method ("taylor" is not yet implemented).
+             batch_size: Batch size for inference.
+
+         Returns:
+             Tensor of importance scores for each hidden dimension.
+         """
+         self.model.eval()
+
+         if isinstance(self.model, SentenceTransformer):
+             transformer = self.model._first_module()
+             if hasattr(transformer, "auto_model"):
+                 auto_model = transformer.auto_model
+             else:
+                 auto_model = transformer
+         else:
+             auto_model = self.model
+
+         hidden_size = auto_model.config.hidden_size
+         importance = torch.zeros(hidden_size, device=self.device)
+
+         if method == "activation":
+             if not isinstance(self.model, SentenceTransformer):
+                 raise ValueError("method='activation' requires a SentenceTransformer model")
+
+             # Track output activations per dimension
+             num_batches = 0
+
+             with torch.no_grad():
+                 for i in range(0, len(sentences), batch_size):
+                     batch_sentences = sentences[i:i + batch_size]
+                     embeddings = self.model.encode(
+                         batch_sentences,
+                         convert_to_tensor=True,
+                         show_progress_bar=False,
+                     )
+                     importance += embeddings.abs().mean(dim=0)
+                     num_batches += 1
+
+             importance = importance / max(num_batches, 1)
+
+         elif method == "magnitude":
+             # Use weight magnitudes from the final encoder layer
+             encoder_layers = self._get_encoder_layers(auto_model)
+             last_layer = encoder_layers[-1]
+
+             # Sum magnitudes over the output dimension for every weight matrix
+             # whose input dimension is the hidden size
+             for param in last_layer.parameters():
+                 if param.dim() >= 2 and param.size(-1) == hidden_size:
+                     importance += param.abs().sum(dim=0)
+
+         else:
+             raise NotImplementedError(f"method='{method}' is not implemented")
+
+         return importance
+
+     def intermediate_dimension_importance(
+         self,
+         sentences: list[str],
+         method: Literal["activation", "magnitude"] = "activation",
+         batch_size: int = 32,
+     ) -> torch.Tensor:
+         """
+         Estimate importance of each intermediate (MLP) dimension.
+
+         Args:
+             sentences: Sentences to use for estimation.
+             method: Importance estimation method ("activation" currently falls
+                 back to "magnitude" with a warning).
+             batch_size: Batch size for inference.
+
+         Returns:
+             Tensor of importance scores for each intermediate dimension.
+         """
+         self.model.eval()
+
+         if isinstance(self.model, SentenceTransformer):
+             transformer = self.model._first_module()
+             if hasattr(transformer, "auto_model"):
+                 auto_model = transformer.auto_model
+             else:
+                 auto_model = transformer
+         else:
+             auto_model = self.model
+
+         intermediate_size = auto_model.config.intermediate_size
+         importance = torch.zeros(intermediate_size, device=self.device)
+
+         if method == "activation":
+             logger.warning(
+                 "method='activation' is not implemented for intermediate dimensions; "
+                 "falling back to 'magnitude'"
+             )
+             method = "magnitude"
+
+         if method == "magnitude":
+             encoder_layers = self._get_encoder_layers(auto_model)
+
+             for layer in encoder_layers:
+                 # BERT-style blocks expose the MLP up-projection as intermediate.dense
+                 if hasattr(layer, "intermediate") and hasattr(layer.intermediate, "dense"):
+                     # weight: [intermediate_size, hidden_size]; mean magnitude per unit
+                     weight = layer.intermediate.dense.weight
+                     importance += weight.abs().mean(dim=1)
+
+             importance = importance / len(encoder_layers)
+
+         return importance
+
+     def attention_head_importance(
+         self,
+         sentences: list[str],
+         method: Literal["activation", "attention_entropy"] = "activation",
+         batch_size: int = 32,
+     ) -> torch.Tensor:
+         """
+         Estimate importance of each attention head.
+
+         Args:
+             sentences: Sentences to use for estimation.
+             method: Importance estimation method ("attention_entropy" is not
+                 yet implemented).
+             batch_size: Batch size for inference.
+
+         Returns:
+             Tensor of importance scores for each attention head.
+         """
+         self.model.eval()
+
+         if isinstance(self.model, SentenceTransformer):
+             transformer = self.model._first_module()
+             if hasattr(transformer, "auto_model"):
+                 auto_model = transformer.auto_model
+             else:
+                 auto_model = transformer
+         else:
+             auto_model = self.model
+
+         num_heads = auto_model.config.num_attention_heads
+         importance = torch.zeros(num_heads, device=self.device)
+
+         if method == "activation":
+             # Scores heads by the magnitude of their query-projection weights,
+             # a cheap proxy that needs no forward passes
+             encoder_layers = self._get_encoder_layers(auto_model)
+             head_dim = auto_model.config.hidden_size // num_heads
+
+             for layer in encoder_layers:
+                 if hasattr(layer, "attention") and hasattr(layer.attention, "self"):
+                     self_attn = layer.attention.self
+                     if hasattr(self_attn, "query"):
+                         weight = self_attn.query.weight
+                         # Reshape to [num_heads, head_dim, hidden_size]
+                         weight_per_head = weight.view(num_heads, head_dim, -1)
+                         importance += weight_per_head.abs().mean(dim=(1, 2))
+
+             importance = importance / len(encoder_layers)
+
+         else:
+             raise NotImplementedError(f"method='{method}' is not implemented")
+
+         return importance
+
+     def _get_encoder_layers(self, model) -> nn.ModuleList:
+         """Get encoder layers from the model."""
+         if hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
+             return model.encoder.layer
+         if hasattr(model, "encoder") and hasattr(model.encoder, "layers"):
+             return model.encoder.layers
+         if hasattr(model, "layers"):
+             return model.layers
+         raise ValueError("Could not find encoder layers")
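The duck-typing in `_get_encoder_layers` can be checked against a stock checkpoint (assuming `bert-base-uncased`, which stores its blocks at `encoder.layer`):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")

# BERT exposes its blocks at encoder.layer, the first case checked above.
print(type(model.encoder.layer))  # torch.nn.modules.container.ModuleList
print(len(model.encoder.layer))   # 12
```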