distil-trainer 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- distil_trainer/__init__.py +31 -0
- distil_trainer/core/__init__.py +23 -0
- distil_trainer/core/callbacks.py +188 -0
- distil_trainer/core/config.py +358 -0
- distil_trainer/core/trainer.py +843 -0
- distil_trainer/data/__init__.py +19 -0
- distil_trainer/data/collators.py +240 -0
- distil_trainer/data/datamodule.py +191 -0
- distil_trainer/data/datasets.py +245 -0
- distil_trainer/data/loaders.py +163 -0
- distil_trainer/distillation/__init__.py +21 -0
- distil_trainer/distillation/losses.py +345 -0
- distil_trainer/distillation/multilingual.py +285 -0
- distil_trainer/distillation/strategies.py +211 -0
- distil_trainer/evaluation/__init__.py +19 -0
- distil_trainer/evaluation/benchmarks.py +86 -0
- distil_trainer/evaluation/evaluators.py +343 -0
- distil_trainer/evaluation/metrics.py +75 -0
- distil_trainer/models/__init__.py +5 -0
- distil_trainer/models/layers.py +115 -0
- distil_trainer/pruning/__init__.py +13 -0
- distil_trainer/pruning/combined_pruning.py +122 -0
- distil_trainer/pruning/depth_pruning.py +261 -0
- distil_trainer/pruning/importance.py +365 -0
- distil_trainer/pruning/width_pruning.py +480 -0
- distil_trainer-0.1.10.dist-info/METADATA +443 -0
- distil_trainer-0.1.10.dist-info/RECORD +29 -0
- distil_trainer-0.1.10.dist-info/WHEEL +4 -0
- distil_trainer-0.1.10.dist-info/licenses/LICENSE +21 -0
distil_trainer/pruning/width_pruning.py
@@ -0,0 +1,480 @@
"""Width pruning for transformer models."""

from __future__ import annotations

import copy
import logging
from typing import Literal

import torch
from torch import nn
from transformers import PreTrainedModel

from sentence_transformers import SentenceTransformer

from distil_trainer.core.config import WidthPruningConfig
from distil_trainer.pruning.importance import ImportanceEstimator

logger = logging.getLogger(__name__)


class WidthPruner:
    """
    Handles width reduction in transformer models.

    Width pruning reduces:
    - Hidden size (embedding dimension)
    - Intermediate size (MLP hidden dimension)
    - Number of attention heads

    Example:
        >>> config = WidthPruningConfig(
        ...     target_hidden_size=3072,
        ...     target_intermediate_size=9216,
        ... )
        >>> pruner = WidthPruner(model)
        >>> pruned_model = pruner.prune(config)
    """

    def __init__(
        self,
        model: SentenceTransformer | PreTrainedModel,
        calibration_data: list[str] | None = None,
    ):
        """
        Initialize the WidthPruner.

        Args:
            model: The model to prune.
            calibration_data: Optional calibration data for importance estimation.
        """
        self.model = model
        self.calibration_data = calibration_data
        self._config = self._get_model_config()

    def _get_model_config(self):
        """Get the model configuration."""
        if isinstance(self.model, SentenceTransformer):
            transformer = self.model._first_module()
            if hasattr(transformer, "auto_model"):
                return transformer.auto_model.config
            return transformer.config
        return self.model.config

    def compute_importance(
        self,
        dimension: Literal["hidden", "intermediate", "heads"],
        calibration_data: list[str] | None = None,
        method: Literal["activation", "magnitude", "taylor", "wanda"] = "activation",
        num_samples: int = 1024,
    ) -> torch.Tensor:
        """
        Compute importance scores for the specified dimension.

        Args:
            dimension: Which dimension to compute importance for.
            calibration_data: Sentences to use for calibration.
            method: Importance estimation method.
            num_samples: Number of samples to use.

        Returns:
            Tensor of importance scores.
        """
        data = calibration_data or self.calibration_data

        estimator = ImportanceEstimator(self.model)

        if dimension == "hidden":
            return estimator.hidden_dimension_importance(data[:num_samples], method=method)
        elif dimension == "intermediate":
            return estimator.intermediate_dimension_importance(data[:num_samples], method=method)
        elif dimension == "heads":
            return estimator.attention_head_importance(data[:num_samples], method=method)
        else:
            raise ValueError(f"Unknown dimension: {dimension}")

    def prune(
        self,
        config: WidthPruningConfig,
        calibration_data: list[str] | None = None,
    ) -> SentenceTransformer | PreTrainedModel:
        """
        Apply width pruning based on configuration.

        Args:
            config: Width pruning configuration.
            calibration_data: Optional calibration data.

        Returns:
            New model with reduced width.
        """
        data = calibration_data or self.calibration_data

        # Determine target dimensions
        current_hidden = getattr(self._config, "hidden_size", None)
        current_intermediate = getattr(self._config, "intermediate_size", None)
        current_heads = getattr(self._config, "num_attention_heads", None)

        target_hidden = config.target_hidden_size
        target_intermediate = config.target_intermediate_size
        target_heads = config.target_num_attention_heads

        # Use ratios if absolute values not specified
        if target_hidden is None and config.hidden_size_ratio is not None:
            target_hidden = int(current_hidden * config.hidden_size_ratio)
        if target_intermediate is None and config.intermediate_size_ratio is not None:
            target_intermediate = int(current_intermediate * config.intermediate_size_ratio)
        if target_heads is None and config.attention_head_ratio is not None:
            target_heads = int(current_heads * config.attention_head_ratio)

        logger.info("Width pruning targets:")
        if target_hidden:
            logger.info(f"  Hidden size: {current_hidden} -> {target_hidden}")
        if target_intermediate:
            logger.info(f"  Intermediate size: {current_intermediate} -> {target_intermediate}")
        if target_heads:
            logger.info(f"  Attention heads: {current_heads} -> {target_heads}")

        # Create a deep copy of the model
        pruned_model = copy.deepcopy(self.model)

        # Compute importance scores if we have calibration data
        hidden_importance = None
        intermediate_importance = None
        head_importance = None

        if data is not None:
            estimator = ImportanceEstimator(self.model)
            if target_hidden and target_hidden < current_hidden:
                hidden_importance = estimator.hidden_dimension_importance(
                    data[:config.calibration_samples],
                    method=config.importance_method,
                )
            if target_intermediate and target_intermediate < current_intermediate:
                intermediate_importance = estimator.intermediate_dimension_importance(
                    data[:config.calibration_samples],
                    method=config.importance_method,
                )
            if target_heads and target_heads < current_heads:
                head_importance = estimator.attention_head_importance(
                    data[:config.calibration_samples],
                    method=config.importance_method,
                )

        # Apply pruning
        if target_hidden and target_hidden < current_hidden:
            pruned_model = self._prune_hidden_size(
                pruned_model, target_hidden, hidden_importance
            )

        if target_intermediate and target_intermediate < current_intermediate:
            pruned_model = self._prune_intermediate_size(
                pruned_model, target_intermediate, intermediate_importance
            )

        if target_heads and target_heads < current_heads:
            pruned_model = self._prune_attention_heads(
                pruned_model, target_heads, head_importance
            )

        # Log statistics
        original_params = sum(p.numel() for p in self.model.parameters())
        pruned_params = sum(p.numel() for p in pruned_model.parameters())
        reduction = 1 - (pruned_params / original_params)
        logger.info(f"Parameter reduction: {original_params:,} -> {pruned_params:,} ({reduction:.1%})")

        return pruned_model

    def _prune_hidden_size(
        self,
        model: SentenceTransformer | PreTrainedModel,
        target_size: int,
        importance: torch.Tensor | None = None,
    ) -> SentenceTransformer | PreTrainedModel:
        """Prune the hidden dimension size."""
        logger.info(f"Pruning hidden size to {target_size}")

        # Get auto_model
        if isinstance(model, SentenceTransformer):
            transformer = model._first_module()
            if hasattr(transformer, "auto_model"):
                auto_model = transformer.auto_model
            else:
                auto_model = transformer
        else:
            auto_model = model

        config = auto_model.config
        current_size = config.hidden_size

        # Determine which dimensions to keep
        if importance is not None:
            # Keep highest importance dimensions
            _, indices = torch.topk(importance, target_size)
            keep_indices = indices.sort().values
        else:
            # Keep first N dimensions
            keep_indices = torch.arange(target_size)

        # Update embedding layer
        if hasattr(auto_model, "embeddings"):
            embeddings = auto_model.embeddings
            if hasattr(embeddings, "word_embeddings"):
                old_weight = embeddings.word_embeddings.weight.data
                new_weight = old_weight[:, keep_indices]
                embeddings.word_embeddings = nn.Embedding(
                    old_weight.size(0), target_size
                )
                embeddings.word_embeddings.weight.data = new_weight

            if hasattr(embeddings, "position_embeddings"):
                old_weight = embeddings.position_embeddings.weight.data
                new_weight = old_weight[:, keep_indices]
                embeddings.position_embeddings = nn.Embedding(
                    old_weight.size(0), target_size
                )
                embeddings.position_embeddings.weight.data = new_weight

            if hasattr(embeddings, "LayerNorm"):
                old_ln = embeddings.LayerNorm
                embeddings.LayerNorm = nn.LayerNorm(target_size, eps=old_ln.eps)
                embeddings.LayerNorm.weight.data = old_ln.weight.data[keep_indices]
                embeddings.LayerNorm.bias.data = old_ln.bias.data[keep_indices]

        # Update encoder layers
        encoder_layers = self._get_encoder_layers(auto_model)
        for layer in encoder_layers:
            self._prune_layer_hidden(layer, keep_indices, target_size)

        # Update config
        config.hidden_size = target_size

        return model

    def _prune_intermediate_size(
        self,
        model: SentenceTransformer | PreTrainedModel,
        target_size: int,
        importance: torch.Tensor | None = None,
    ) -> SentenceTransformer | PreTrainedModel:
        """Prune the MLP intermediate dimension."""
        logger.info(f"Pruning intermediate size to {target_size}")

        # Get auto_model
        if isinstance(model, SentenceTransformer):
            transformer = model._first_module()
            if hasattr(transformer, "auto_model"):
                auto_model = transformer.auto_model
            else:
                auto_model = transformer
        else:
            auto_model = model

        config = auto_model.config
        current_size = config.intermediate_size

        # Determine which dimensions to keep
        if importance is not None:
            _, indices = torch.topk(importance, target_size)
            keep_indices = indices.sort().values
        else:
            keep_indices = torch.arange(target_size)

        # Update encoder layers
        encoder_layers = self._get_encoder_layers(auto_model)
        for layer in encoder_layers:
            self._prune_layer_intermediate(layer, keep_indices, target_size)

        # Update config
        config.intermediate_size = target_size

        return model

    def _prune_attention_heads(
        self,
        model: SentenceTransformer | PreTrainedModel,
        target_heads: int,
        importance: torch.Tensor | None = None,
    ) -> SentenceTransformer | PreTrainedModel:
        """Prune the number of attention heads."""
        logger.info(f"Pruning attention heads to {target_heads}")

        # Get auto_model
        if isinstance(model, SentenceTransformer):
            transformer = model._first_module()
            if hasattr(transformer, "auto_model"):
                auto_model = transformer.auto_model
            else:
                auto_model = transformer
        else:
            auto_model = model

        config = auto_model.config
        current_heads = config.num_attention_heads
        head_dim = config.hidden_size // current_heads

        # Determine which heads to keep
        if importance is not None:
            _, indices = torch.topk(importance, target_heads)
            keep_head_indices = indices.sort().values
        else:
            keep_head_indices = torch.arange(target_heads)

        # Update encoder layers
        encoder_layers = self._get_encoder_layers(auto_model)
        for layer in encoder_layers:
            self._prune_layer_heads(layer, keep_head_indices, head_dim)

        # Update config
        config.num_attention_heads = target_heads

        return model

    def _get_encoder_layers(self, model) -> nn.ModuleList:
        """Get encoder layers from the model."""
        if hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
            return model.encoder.layer
        if hasattr(model, "encoder") and hasattr(model.encoder, "layers"):
            return model.encoder.layers
        if hasattr(model, "layers"):
            return model.layers
        raise ValueError("Could not find encoder layers")

    def _prune_layer_hidden(
        self,
        layer: nn.Module,
        keep_indices: torch.Tensor,
        target_size: int,
    ) -> None:
        """Prune hidden dimension in a single layer."""
        # This is model-specific and may need customization for different architectures
        # Here we provide a generic implementation for BERT-like models

        # Prune attention
        if hasattr(layer, "attention"):
            attention = layer.attention
            if hasattr(attention, "self"):
                self_attn = attention.self
                # Query, Key, Value projections
                for proj_name in ["query", "key", "value"]:
                    if hasattr(self_attn, proj_name):
                        proj = getattr(self_attn, proj_name)
                        new_weight = proj.weight.data[:, keep_indices]
                        new_proj = nn.Linear(target_size, proj.out_features, bias=proj.bias is not None)
                        new_proj.weight.data = new_weight
                        if proj.bias is not None:
                            new_proj.bias.data = proj.bias.data
                        setattr(self_attn, proj_name, new_proj)

            if hasattr(attention, "output"):
                output = attention.output
                if hasattr(output, "dense"):
                    old_dense = output.dense
                    new_dense = nn.Linear(old_dense.in_features, target_size, bias=old_dense.bias is not None)
                    new_dense.weight.data = old_dense.weight.data[keep_indices]
                    if old_dense.bias is not None:
                        new_dense.bias.data = old_dense.bias.data[keep_indices]
                    output.dense = new_dense

                if hasattr(output, "LayerNorm"):
                    old_ln = output.LayerNorm
                    output.LayerNorm = nn.LayerNorm(target_size, eps=old_ln.eps)
                    output.LayerNorm.weight.data = old_ln.weight.data[keep_indices]
                    output.LayerNorm.bias.data = old_ln.bias.data[keep_indices]

        # Prune intermediate/MLP
        if hasattr(layer, "intermediate"):
            intermediate = layer.intermediate
            if hasattr(intermediate, "dense"):
                old_dense = intermediate.dense
                new_dense = nn.Linear(target_size, old_dense.out_features, bias=old_dense.bias is not None)
                new_dense.weight.data = old_dense.weight.data[:, keep_indices]
                if old_dense.bias is not None:
                    new_dense.bias.data = old_dense.bias.data
                intermediate.dense = new_dense

        if hasattr(layer, "output"):
            output = layer.output
            if hasattr(output, "dense"):
                old_dense = output.dense
                new_dense = nn.Linear(old_dense.in_features, target_size, bias=old_dense.bias is not None)
                new_dense.weight.data = old_dense.weight.data[keep_indices]
                if old_dense.bias is not None:
                    new_dense.bias.data = old_dense.bias.data[keep_indices]
                output.dense = new_dense

            if hasattr(output, "LayerNorm"):
                old_ln = output.LayerNorm
                output.LayerNorm = nn.LayerNorm(target_size, eps=old_ln.eps)
                output.LayerNorm.weight.data = old_ln.weight.data[keep_indices]
                output.LayerNorm.bias.data = old_ln.bias.data[keep_indices]

    def _prune_layer_intermediate(
        self,
        layer: nn.Module,
        keep_indices: torch.Tensor,
        target_size: int,
    ) -> None:
        """Prune intermediate dimension in a single layer."""
        if hasattr(layer, "intermediate"):
            intermediate = layer.intermediate
            if hasattr(intermediate, "dense"):
                old_dense = intermediate.dense
                new_dense = nn.Linear(old_dense.in_features, target_size, bias=old_dense.bias is not None)
                new_dense.weight.data = old_dense.weight.data[keep_indices]
                if old_dense.bias is not None:
                    new_dense.bias.data = old_dense.bias.data[keep_indices]
                intermediate.dense = new_dense

        if hasattr(layer, "output"):
            output = layer.output
            if hasattr(output, "dense"):
                old_dense = output.dense
                new_dense = nn.Linear(target_size, old_dense.out_features, bias=old_dense.bias is not None)
                new_dense.weight.data = old_dense.weight.data[:, keep_indices]
                if old_dense.bias is not None:
                    new_dense.bias.data = old_dense.bias.data
                output.dense = new_dense

    def _prune_layer_heads(
        self,
        layer: nn.Module,
        keep_head_indices: torch.Tensor,
        head_dim: int,
    ) -> None:
        """Prune attention heads in a single layer."""
        if not hasattr(layer, "attention"):
            return

        attention = layer.attention
        if not hasattr(attention, "self"):
            return

        self_attn = attention.self
        target_heads = len(keep_head_indices)
        new_head_size = target_heads * head_dim

        # Compute dimension indices for keeping specific heads
        dim_indices = []
        for head_idx in keep_head_indices:
            start = head_idx * head_dim
            end = start + head_dim
            dim_indices.extend(range(start, end))
        dim_indices = torch.tensor(dim_indices)

        # Prune Query, Key, Value projections
        for proj_name in ["query", "key", "value"]:
            if hasattr(self_attn, proj_name):
                proj = getattr(self_attn, proj_name)
                new_weight = proj.weight.data[dim_indices]
                new_proj = nn.Linear(proj.in_features, new_head_size, bias=proj.bias is not None)
                new_proj.weight.data = new_weight
                if proj.bias is not None:
                    new_proj.bias.data = proj.bias.data[dim_indices]
                setattr(self_attn, proj_name, new_proj)

        # Update num_attention_heads attribute
        if hasattr(self_attn, "num_attention_heads"):
            self_attn.num_attention_heads = target_heads
        if hasattr(self_attn, "all_head_size"):
            self_attn.all_head_size = new_head_size
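
For orientation, a minimal usage sketch of the WidthPruner API shown in this hunk (which matches the +480 lines of distil_trainer/pruning/width_pruning.py in the file list above). It is illustrative, not part of the packaged file: the model name, the calibration sentences, and the ratio-style WidthPruningConfig keyword arguments (hidden_size_ratio, intermediate_size_ratio, attention_head_ratio) are assumptions inferred from the fields prune() reads; only target_hidden_size and target_intermediate_size are confirmed by the class docstring.

# Illustrative sketch only; see the note above for which names are assumed.
from sentence_transformers import SentenceTransformer

from distil_trainer.core.config import WidthPruningConfig
from distil_trainer.pruning.width_pruning import WidthPruner

# Hypothetical base model and calibration sentences.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
calibration = ["A short calibration sentence.", "Another calibration sentence."]

pruner = WidthPruner(model, calibration_data=calibration)

# Importance scores for a single dimension can be inspected directly.
head_scores = pruner.compute_importance("heads", method="activation", num_samples=256)

# Ratio-based targets (assumed config fields); prune() converts them to absolute
# sizes from the current hidden_size / intermediate_size / num_attention_heads.
config = WidthPruningConfig(
    hidden_size_ratio=0.75,
    intermediate_size_ratio=0.75,
    attention_head_ratio=0.75,
)

# Returns a deep-copied, width-reduced model; the original model is left untouched.
pruned_model = pruner.prune(config)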