distil_trainer-0.1.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,480 @@
+ """Width pruning for transformer models."""
+
+ from __future__ import annotations
+
+ import copy
+ import logging
+ from typing import Literal
+
+ import torch
+ from torch import nn
+ from transformers import PreTrainedModel
+
+ from sentence_transformers import SentenceTransformer
+
+ from distil_trainer.core.config import WidthPruningConfig
+ from distil_trainer.pruning.importance import ImportanceEstimator
+
+ logger = logging.getLogger(__name__)
+
+
+ class WidthPruner:
+     """
+     Handles width reduction in transformer models.
+
+     Width pruning reduces:
+     - Hidden size (embedding dimension)
+     - Intermediate size (MLP hidden dimension)
+     - Number of attention heads
+
+     Example:
+         >>> config = WidthPruningConfig(
+         ...     target_hidden_size=3072,
+         ...     target_intermediate_size=9216,
+         ... )
+         >>> pruner = WidthPruner(model)
+         >>> pruned_model = pruner.prune(config)
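+
+     Ratios can be given instead of absolute sizes (illustrative; the exact
+     fields and defaults are defined by WidthPruningConfig in
+     distil_trainer.core.config):
+         >>> config = WidthPruningConfig(hidden_size_ratio=0.75, intermediate_size_ratio=0.75)
+         >>> pruned_model = WidthPruner(model, calibration_data=sentences).prune(config)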
+     """
+
+     def __init__(
+         self,
+         model: SentenceTransformer | PreTrainedModel,
+         calibration_data: list[str] | None = None,
+     ):
+         """
+         Initialize the WidthPruner.
+
+         Args:
+             model: The model to prune.
+             calibration_data: Optional calibration data for importance estimation.
+         """
+         self.model = model
+         self.calibration_data = calibration_data
+         self._config = self._get_model_config()
+
+     def _get_model_config(self):
+         """Get the model configuration."""
+         if isinstance(self.model, SentenceTransformer):
+             transformer = self.model._first_module()
+             if hasattr(transformer, "auto_model"):
+                 return transformer.auto_model.config
+             return transformer.config
+         return self.model.config
+
+     def compute_importance(
+         self,
+         dimension: Literal["hidden", "intermediate", "heads"],
+         calibration_data: list[str] | None = None,
+         method: Literal["activation", "magnitude", "taylor", "wanda"] = "activation",
+         num_samples: int = 1024,
+     ) -> torch.Tensor:
+         """
+         Compute importance scores for the specified dimension.
+
+         Args:
+             dimension: Which dimension to compute importance for.
+             calibration_data: Sentences to use for calibration.
+             method: Importance estimation method.
+             num_samples: Number of samples to use.
+
+         Returns:
+             Tensor of importance scores.
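+
+         Example (illustrative; assumes a WidthPruner instance and a list of
+         calibration sentences):
+             >>> scores = pruner.compute_importance("intermediate", calibration_data=sentences)
+             >>> keep = torch.topk(scores, k=1536).indices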
+         """
+         data = calibration_data or self.calibration_data
+         if data is None:
+             raise ValueError("Calibration data is required to compute importance scores.")
+
+         estimator = ImportanceEstimator(self.model)
+
+         if dimension == "hidden":
+             return estimator.hidden_dimension_importance(data[:num_samples], method=method)
+         elif dimension == "intermediate":
+             return estimator.intermediate_dimension_importance(data[:num_samples], method=method)
+         elif dimension == "heads":
+             return estimator.attention_head_importance(data[:num_samples], method=method)
+         else:
+             raise ValueError(f"Unknown dimension: {dimension}")
+
+     def prune(
+         self,
+         config: WidthPruningConfig,
+         calibration_data: list[str] | None = None,
+     ) -> SentenceTransformer | PreTrainedModel:
+         """
+         Apply width pruning based on configuration.
+
+         Args:
+             config: Width pruning configuration.
+             calibration_data: Optional calibration data.
+
+         Returns:
+             New model with reduced width.
+         """
+         data = calibration_data or self.calibration_data
+
+         # Determine target dimensions
+         current_hidden = getattr(self._config, "hidden_size", None)
+         current_intermediate = getattr(self._config, "intermediate_size", None)
+         current_heads = getattr(self._config, "num_attention_heads", None)
+
+         target_hidden = config.target_hidden_size
+         target_intermediate = config.target_intermediate_size
+         target_heads = config.target_num_attention_heads
+
+         # Use ratios if absolute values not specified
+         if target_hidden is None and config.hidden_size_ratio is not None:
+             target_hidden = int(current_hidden * config.hidden_size_ratio)
+         if target_intermediate is None and config.intermediate_size_ratio is not None:
+             target_intermediate = int(current_intermediate * config.intermediate_size_ratio)
+         if target_heads is None and config.attention_head_ratio is not None:
+             target_heads = int(current_heads * config.attention_head_ratio)
+
+         logger.info("Width pruning targets:")
+         if target_hidden:
+             logger.info(f"  Hidden size: {current_hidden} -> {target_hidden}")
+         if target_intermediate:
+             logger.info(f"  Intermediate size: {current_intermediate} -> {target_intermediate}")
+         if target_heads:
+             logger.info(f"  Attention heads: {current_heads} -> {target_heads}")
+
+         # Create a deep copy of the model
+         pruned_model = copy.deepcopy(self.model)
+
+         # Compute importance scores if we have calibration data
+         hidden_importance = None
+         intermediate_importance = None
+         head_importance = None
+
+         if data is not None:
+             estimator = ImportanceEstimator(self.model)
+             if target_hidden and target_hidden < current_hidden:
+                 hidden_importance = estimator.hidden_dimension_importance(
+                     data[:config.calibration_samples],
+                     method=config.importance_method,
+                 )
+             if target_intermediate and target_intermediate < current_intermediate:
+                 intermediate_importance = estimator.intermediate_dimension_importance(
+                     data[:config.calibration_samples],
+                     method=config.importance_method,
+                 )
+             if target_heads and target_heads < current_heads:
+                 head_importance = estimator.attention_head_importance(
+                     data[:config.calibration_samples],
+                     method=config.importance_method,
+                 )
+
+         # Apply pruning
+         if target_hidden and target_hidden < current_hidden:
+             pruned_model = self._prune_hidden_size(
+                 pruned_model, target_hidden, hidden_importance
+             )
+
+         if target_intermediate and target_intermediate < current_intermediate:
+             pruned_model = self._prune_intermediate_size(
+                 pruned_model, target_intermediate, intermediate_importance
+             )
+
+         if target_heads and target_heads < current_heads:
+             pruned_model = self._prune_attention_heads(
+                 pruned_model, target_heads, head_importance
+             )
+
+         # Log statistics
+         original_params = sum(p.numel() for p in self.model.parameters())
+         pruned_params = sum(p.numel() for p in pruned_model.parameters())
+         reduction = 1 - (pruned_params / original_params)
+         logger.info(f"Parameter reduction: {original_params:,} -> {pruned_params:,} ({reduction:.1%})")
+
+         return pruned_model
+
+     def _prune_hidden_size(
+         self,
+         model: SentenceTransformer | PreTrainedModel,
+         target_size: int,
+         importance: torch.Tensor | None = None,
+     ) -> SentenceTransformer | PreTrainedModel:
+         """Prune the hidden dimension size."""
+         logger.info(f"Pruning hidden size to {target_size}")
+
+         # Get auto_model
+         if isinstance(model, SentenceTransformer):
+             transformer = model._first_module()
+             if hasattr(transformer, "auto_model"):
+                 auto_model = transformer.auto_model
+             else:
+                 auto_model = transformer
+         else:
+             auto_model = model
+
+         config = auto_model.config
+         current_size = config.hidden_size
+
+         # Determine which dimensions to keep
+         if importance is not None:
+             # Keep highest importance dimensions
+             _, indices = torch.topk(importance, target_size)
+             keep_indices = indices.sort().values
+         else:
+             # Keep first N dimensions
+             keep_indices = torch.arange(target_size)
+
+         # Update embedding layer
+         if hasattr(auto_model, "embeddings"):
+             embeddings = auto_model.embeddings
+             if hasattr(embeddings, "word_embeddings"):
+                 old_weight = embeddings.word_embeddings.weight.data
+                 new_weight = old_weight[:, keep_indices]
+                 embeddings.word_embeddings = nn.Embedding(
+                     old_weight.size(0), target_size
+                 )
+                 embeddings.word_embeddings.weight.data = new_weight
+
+             if hasattr(embeddings, "position_embeddings"):
+                 old_weight = embeddings.position_embeddings.weight.data
+                 new_weight = old_weight[:, keep_indices]
+                 embeddings.position_embeddings = nn.Embedding(
+                     old_weight.size(0), target_size
+                 )
+                 embeddings.position_embeddings.weight.data = new_weight
+
+             if hasattr(embeddings, "LayerNorm"):
+                 old_ln = embeddings.LayerNorm
+                 embeddings.LayerNorm = nn.LayerNorm(target_size, eps=old_ln.eps)
+                 embeddings.LayerNorm.weight.data = old_ln.weight.data[keep_indices]
+                 embeddings.LayerNorm.bias.data = old_ln.bias.data[keep_indices]
+
+         # Update encoder layers
+         encoder_layers = self._get_encoder_layers(auto_model)
+         for layer in encoder_layers:
+             self._prune_layer_hidden(layer, keep_indices, target_size)
+
+         # Update config
+         config.hidden_size = target_size
+
+         return model
+
+     def _prune_intermediate_size(
+         self,
+         model: SentenceTransformer | PreTrainedModel,
+         target_size: int,
+         importance: torch.Tensor | None = None,
+     ) -> SentenceTransformer | PreTrainedModel:
+         """Prune the MLP intermediate dimension."""
+         logger.info(f"Pruning intermediate size to {target_size}")
+
+         # Get auto_model
+         if isinstance(model, SentenceTransformer):
+             transformer = model._first_module()
+             if hasattr(transformer, "auto_model"):
+                 auto_model = transformer.auto_model
+             else:
+                 auto_model = transformer
+         else:
+             auto_model = model
+
+         config = auto_model.config
+         current_size = config.intermediate_size
+
+         # Determine which dimensions to keep
+         if importance is not None:
+             _, indices = torch.topk(importance, target_size)
+             keep_indices = indices.sort().values
+         else:
+             keep_indices = torch.arange(target_size)
+
+         # Update encoder layers
+         encoder_layers = self._get_encoder_layers(auto_model)
+         for layer in encoder_layers:
+             self._prune_layer_intermediate(layer, keep_indices, target_size)
+
+         # Update config
+         config.intermediate_size = target_size
+
+         return model
+
+     def _prune_attention_heads(
+         self,
+         model: SentenceTransformer | PreTrainedModel,
+         target_heads: int,
+         importance: torch.Tensor | None = None,
+     ) -> SentenceTransformer | PreTrainedModel:
+         """Prune the number of attention heads."""
+         logger.info(f"Pruning attention heads to {target_heads}")
+
+         # Get auto_model
+         if isinstance(model, SentenceTransformer):
+             transformer = model._first_module()
+             if hasattr(transformer, "auto_model"):
+                 auto_model = transformer.auto_model
+             else:
+                 auto_model = transformer
+         else:
+             auto_model = model
+
+         config = auto_model.config
+         current_heads = config.num_attention_heads
+         head_dim = config.hidden_size // current_heads
+
+         # Determine which heads to keep
+         if importance is not None:
+             _, indices = torch.topk(importance, target_heads)
+             keep_head_indices = indices.sort().values
+         else:
+             keep_head_indices = torch.arange(target_heads)
+
+         # Update encoder layers
+         encoder_layers = self._get_encoder_layers(auto_model)
+         for layer in encoder_layers:
+             self._prune_layer_heads(layer, keep_head_indices, head_dim)
+
+         # Update config
+         config.num_attention_heads = target_heads
+
+         return model
+
+     def _get_encoder_layers(self, model) -> nn.ModuleList:
+         """Get encoder layers from the model."""
+         if hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
+             return model.encoder.layer
+         if hasattr(model, "encoder") and hasattr(model.encoder, "layers"):
+             return model.encoder.layers
+         if hasattr(model, "layers"):
+             return model.layers
+         raise ValueError("Could not find encoder layers")
+
+     def _prune_layer_hidden(
+         self,
+         layer: nn.Module,
+         keep_indices: torch.Tensor,
+         target_size: int,
+     ) -> None:
+         """Prune hidden dimension in a single layer."""
+         # This is model-specific and may need customization for different architectures
+         # Here we provide a generic implementation for BERT-like models
+
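+         # Convention for the BERT-style blocks below: when a Linear's *input*
+         # features shrink, slice weight columns (weight[:, keep_indices]); when its
+         # *output* features shrink, slice weight rows and the bias
+         # (weight[keep_indices], bias[keep_indices]). LayerNorm parameters are
+         # sliced the same way as output features.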
+         # Prune attention
+         if hasattr(layer, "attention"):
+             attention = layer.attention
+             if hasattr(attention, "self"):
+                 self_attn = attention.self
+                 # Query, Key, Value projections
+                 for proj_name in ["query", "key", "value"]:
+                     if hasattr(self_attn, proj_name):
+                         proj = getattr(self_attn, proj_name)
+                         new_weight = proj.weight.data[:, keep_indices]
+                         new_proj = nn.Linear(target_size, proj.out_features, bias=proj.bias is not None)
+                         new_proj.weight.data = new_weight
+                         if proj.bias is not None:
+                             new_proj.bias.data = proj.bias.data
+                         setattr(self_attn, proj_name, new_proj)
+
+             if hasattr(attention, "output"):
+                 output = attention.output
+                 if hasattr(output, "dense"):
+                     old_dense = output.dense
+                     new_dense = nn.Linear(old_dense.in_features, target_size, bias=old_dense.bias is not None)
+                     new_dense.weight.data = old_dense.weight.data[keep_indices]
+                     if old_dense.bias is not None:
+                         new_dense.bias.data = old_dense.bias.data[keep_indices]
+                     output.dense = new_dense
+
+                 if hasattr(output, "LayerNorm"):
+                     old_ln = output.LayerNorm
+                     output.LayerNorm = nn.LayerNorm(target_size, eps=old_ln.eps)
+                     output.LayerNorm.weight.data = old_ln.weight.data[keep_indices]
+                     output.LayerNorm.bias.data = old_ln.bias.data[keep_indices]
+
+         # Prune intermediate/MLP
+         if hasattr(layer, "intermediate"):
+             intermediate = layer.intermediate
+             if hasattr(intermediate, "dense"):
+                 old_dense = intermediate.dense
+                 new_dense = nn.Linear(target_size, old_dense.out_features, bias=old_dense.bias is not None)
+                 new_dense.weight.data = old_dense.weight.data[:, keep_indices]
+                 if old_dense.bias is not None:
+                     new_dense.bias.data = old_dense.bias.data
+                 intermediate.dense = new_dense
+
+         if hasattr(layer, "output"):
+             output = layer.output
+             if hasattr(output, "dense"):
+                 old_dense = output.dense
+                 new_dense = nn.Linear(old_dense.in_features, target_size, bias=old_dense.bias is not None)
+                 new_dense.weight.data = old_dense.weight.data[keep_indices]
+                 if old_dense.bias is not None:
+                     new_dense.bias.data = old_dense.bias.data[keep_indices]
+                 output.dense = new_dense
+
+             if hasattr(output, "LayerNorm"):
+                 old_ln = output.LayerNorm
+                 output.LayerNorm = nn.LayerNorm(target_size, eps=old_ln.eps)
+                 output.LayerNorm.weight.data = old_ln.weight.data[keep_indices]
+                 output.LayerNorm.bias.data = old_ln.bias.data[keep_indices]
+
+     def _prune_layer_intermediate(
+         self,
+         layer: nn.Module,
+         keep_indices: torch.Tensor,
+         target_size: int,
+     ) -> None:
+         """Prune intermediate dimension in a single layer."""
+         if hasattr(layer, "intermediate"):
+             intermediate = layer.intermediate
+             if hasattr(intermediate, "dense"):
+                 old_dense = intermediate.dense
+                 new_dense = nn.Linear(old_dense.in_features, target_size, bias=old_dense.bias is not None)
+                 new_dense.weight.data = old_dense.weight.data[keep_indices]
+                 if old_dense.bias is not None:
+                     new_dense.bias.data = old_dense.bias.data[keep_indices]
+                 intermediate.dense = new_dense
+
+         if hasattr(layer, "output"):
+             output = layer.output
+             if hasattr(output, "dense"):
+                 old_dense = output.dense
+                 new_dense = nn.Linear(target_size, old_dense.out_features, bias=old_dense.bias is not None)
+                 new_dense.weight.data = old_dense.weight.data[:, keep_indices]
+                 if old_dense.bias is not None:
+                     new_dense.bias.data = old_dense.bias.data
+                 output.dense = new_dense
+
+     def _prune_layer_heads(
+         self,
+         layer: nn.Module,
+         keep_head_indices: torch.Tensor,
+         head_dim: int,
+     ) -> None:
+         """Prune attention heads in a single layer."""
+         if not hasattr(layer, "attention"):
+             return
+
+         attention = layer.attention
+         if not hasattr(attention, "self"):
+             return
+
+         self_attn = attention.self
+         target_heads = len(keep_head_indices)
+         new_head_size = target_heads * head_dim
+
+         # Compute dimension indices for keeping specific heads
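+         # e.g. with head_dim=64, keeping heads [0, 2] keeps dimensions 0-63 and 128-191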
+         dim_indices = []
+         for head_idx in keep_head_indices:
+             start = int(head_idx) * head_dim
+             end = start + head_dim
+             dim_indices.extend(range(start, end))
+         dim_indices = torch.tensor(dim_indices)
+
+         # Prune Query, Key, Value projections
+         for proj_name in ["query", "key", "value"]:
+             if hasattr(self_attn, proj_name):
+                 proj = getattr(self_attn, proj_name)
+                 new_weight = proj.weight.data[dim_indices]
+                 new_proj = nn.Linear(proj.in_features, new_head_size, bias=proj.bias is not None)
+                 new_proj.weight.data = new_weight
+                 if proj.bias is not None:
+                     new_proj.bias.data = proj.bias.data[dim_indices]
+                 setattr(self_attn, proj_name, new_proj)
+
+         # Prune the matching input columns of the attention output projection so it
+         # accepts the reduced concatenated head size (BERT-like attention.output.dense)
+         if hasattr(attention, "output") and hasattr(attention.output, "dense"):
+             old_dense = attention.output.dense
+             new_dense = nn.Linear(new_head_size, old_dense.out_features, bias=old_dense.bias is not None)
+             new_dense.weight.data = old_dense.weight.data[:, dim_indices]
+             if old_dense.bias is not None:
+                 new_dense.bias.data = old_dense.bias.data
+             attention.output.dense = new_dense
+
+         # Update num_attention_heads attribute
+         if hasattr(self_attn, "num_attention_heads"):
+             self_attn.num_attention_heads = target_heads
+         if hasattr(self_attn, "all_head_size"):
+             self_attn.all_head_size = new_head_size
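
For reference, a minimal usage sketch of the module above, assuming a BERT-style SentenceTransformer checkpoint; the WidthPruner import path and the exact WidthPruningConfig constructor signature are assumptions, not confirmed by this diff:

from sentence_transformers import SentenceTransformer

from distil_trainer.core.config import WidthPruningConfig
from distil_trainer.pruning.width import WidthPruner  # import path assumed, not shown in this diff

# Small calibration set; in practice use sentences from the target domain.
sentences = ["an example calibration sentence", "another short calibration sentence"]

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
config = WidthPruningConfig(
    intermediate_size_ratio=0.75,    # ratio fields read by WidthPruner.prune()
    attention_head_ratio=0.75,
    calibration_samples=256,         # fields accessed by prune() when calibration data is given
    importance_method="activation",
)
pruned = WidthPruner(model, calibration_data=sentences).prune(config)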