fusion-bench 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff reflects the changes between publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (77)
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  30. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  32. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  33. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  34. fusion_bench/programs/fabric_fusion_program.py +5 -0
  35. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  36. fusion_bench/utils/__init__.py +1 -0
  37. fusion_bench/utils/data.py +1 -1
  38. fusion_bench/utils/lazy_state_dict.py +268 -0
  39. fusion_bench/utils/parameters.py +33 -0
  40. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  41. fusion_bench/utils/type.py +1 -0
  42. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +6 -2
  43. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +77 -21
  44. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  45. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  46. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  53. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  54. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  55. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  56. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  57. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  58. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  60. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  61. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  63. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  64. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  68. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  72. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  73. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  74. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  75. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,313 @@
+import logging
+from typing import List, Tuple, cast
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from tqdm.auto import tqdm
+from transformers import LlamaForCausalLM, PreTrainedModel
+
+from fusion_bench import timeit_context
+
+from .data import get_loaders
+from .layerwrapper import WrappedGPT
+
+log = logging.getLogger(__name__)
+
+
+def find_layers(module, layers=[nn.Linear], name=""):
+    """
+    Recursively find the layers of a certain type in a module.
+
+    Args:
+        module (nn.Module): PyTorch module.
+        layers (list): List of layer types to find.
+        name (str): Name of the module.
+
+    Returns:
+        dict: Dictionary of layers of the given type(s) within the module.
+    """
+    if type(module) in layers:
+        return {name: module}
+    res = {}
+    for name1, child in module.named_children():
+        res.update(
+            find_layers(
+                child, layers=layers, name=name + "." + name1 if name != "" else name1
+            )
+        )
+    return res
+
+
+def check_sparsity(model):
+    """
+    Check the sparsity of the model by counting the number of zero weights.
+
+    Args:
+        model (PreTrainedModel): The model to check sparsity for.
+
+    Returns:
+        float: The sparsity ratio of the model.
+    """
+    use_cache = model.config.use_cache
+    model.config.use_cache = False
+
+    layers = model.model.layers
+    count = 0
+    total_params = 0
+    for i in range(len(layers)):
+        layer = layers[i]
+        subset = find_layers(layer)
+
+        sub_count = 0
+        sub_params = 0
+        for name in subset:
+            W = subset[name].weight.data
+            count += (W == 0).sum().item()
+            total_params += W.numel()
+
+            sub_count += (W == 0).sum().item()
+            sub_params += W.numel()
+
+        print(f"layer {i} sparsity {float(sub_count)/sub_params:.6f}")
+
+    model.config.use_cache = use_cache
+    return float(count) / total_params
+
+
+def prepare_calibration_input(
+    model: PreTrainedModel,
+    dataloader: List[Tuple[Tensor, Tensor]],
+    device: torch.device,
+):
+    """
+    Prepare the calibration input for the model by collecting input to the first layer.
+
+    Args:
+        model (PreTrainedModel): The model to prepare calibration input for.
+        dataloader (List[Tuple[Tensor, Tensor]]): The dataloader to use for calibration.
+        device (torch.device): The device to use for calibration.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor, Tensor]: The prepared input, output, attention mask, and position IDs.
+    """
+    use_cache = model.config.use_cache
+    model.config.use_cache = False
+    layers = model.model.layers
+
+    # dev = model.hf_device_map["model.embed_tokens"]
+    if hasattr(model, "hf_device_map") and "model.embed_tokens" in model.hf_device_map:
+        device = model.hf_device_map["model.embed_tokens"]
+
+    dtype = next(iter(model.parameters())).dtype
+    # ? what if n_samples > 128
+    inps = torch.zeros(
+        (128, model.seqlen, model.config.hidden_size),
+        dtype=dtype,
+        device=device,
+        requires_grad=False,
+    )
+    cache = {"i": 0, "attention_mask": None, "position_ids": None, 'position_embeddings': None}
+
+    class Catcher(nn.Module):
+        def __init__(self, module):
+            super().__init__()
+            self.module = module
+
+        def forward(self, inp, **kwargs):
+            inps[cache["i"]] = inp
+            cache["i"] += 1
+            # collect attention_mask and position_ids
+            cache["attention_mask"] = kwargs["attention_mask"]
+            cache["position_ids"] = kwargs["position_ids"]
+            if "position_embeddings" in kwargs:
+                cache["position_embeddings"] = kwargs["position_embeddings"]
+            else:
+                cache["position_embeddings"] = None
+            raise ValueError  # stop the forward pass
+
+    layers[0] = Catcher(layers[0])
+    for batch in dataloader:
+        try:
+            model(batch[0].to(device))
+        except ValueError:
+            pass
+    layers[0] = layers[0].module
+
+    outs = torch.zeros_like(inps)
+    attention_mask = cache["attention_mask"]
+    position_ids = cache["position_ids"]
+    position_embeddings = cache["position_embeddings"]
+    model.config.use_cache = use_cache
+
+    return inps, outs, attention_mask, position_ids, position_embeddings
+
+
+def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
+    """
+    Return the mask and current sparsity given an alpha value.
+
+    Args:
+        alpha (float): The alpha value.
+        sort_res (Tensor): The sorted results.
+        W_metric (Tensor): The weight metric.
+        tmp_metric (Tensor): The temporary metric.
+        sum_before (Tensor): The sum before the alpha value.
+
+    Returns:
+        Tuple[Tensor, float]: The mask and current sparsity.
+    """
+    thres_cumsum = sum_before * alpha
+    sort_mask = tmp_metric <= thres_cumsum.reshape((-1, 1))
+    thres = torch.gather(
+        sort_res[0], dim=1, index=sort_mask.sum(dim=1, keepdims=True) - 1
+    )
+    W_mask = W_metric <= thres
+    cur_sparsity = (W_mask == True).sum() / W_mask.numel()
+    return W_mask, cur_sparsity
+
+
+def llama_prune_wanda_(
+    args,
+    model: LlamaForCausalLM,
+    tokenizer,
+    device=torch.device("cuda:0"),
+    prune_n=0,
+    prune_m=0,
+):
+    """
+    Perform Wanda pruning on a Llama model.
+
+    Args:
+        args: The arguments for pruning.
+        model (LlamaForCausalLM): The model to prune.
+        tokenizer: The tokenizer to use for calibration.
+        device (torch.device, optional): The device to use for pruning. Defaults to torch.device("cuda:0").
+        prune_n (int, optional): The number of elements to prune in each block. Defaults to 0.
+        prune_m (int, optional): The size of each block. Defaults to 0.
+    """
+    use_cache = model.config.use_cache
+    model.config.use_cache = False
+
+    with timeit_context("loading calibdation data"):
+        dataloader, _ = get_loaders(
+            "c4",
+            nsamples=args.nsamples,
+            seed=args.seed,
+            seqlen=model.seqlen,
+            tokenizer=tokenizer,
+        )
+
+    with torch.no_grad():
+        # collect input to the first layer
+        inps, outs, attention_mask, position_ids = prepare_calibration_input(
+            model, dataloader, device
+        )
+
+    layers = model.model.layers
+    for i in range(len(layers)):
+        layer = layers[i]
+        subset = find_layers(layer)
+
+        if (
+            hasattr(model, "hf_device_map")
+            and f"model.layers.{i}" in model.hf_device_map
+        ):  ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
+            dev = model.hf_device_map[f"model.layers.{i}"]
+            inps, outs, attention_mask, position_ids = (
+                inps.to(dev),
+                outs.to(dev),
+                attention_mask.to(dev) if attention_mask is not None else None,
+                position_ids.to(dev) if position_ids is not None else None,
+            )
+
+        wrapped_layers = {}
+        for name in subset:
+            wrapped_layers[name] = WrappedGPT(subset[name])
+
+        def add_batch(name):
+            def tmp(_, inp, out):
+                cast(WrappedGPT, wrapped_layers[name]).add_batch(inp[0].data, out.data)
+
+            return tmp
+
+        handles = []
+        for name in wrapped_layers:
+            handles.append(subset[name].register_forward_hook(add_batch(name)))
+        for j in range(args.nsamples):
+            with torch.no_grad():
+                outs[j] = layer(
+                    inps[j].unsqueeze(0),
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                )[0]
+        for h in handles:
+            h.remove()
+
+        for name in subset:
+            print(f"pruning layer {i} name {name}")
+            W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(
+                wrapped_layers[name].scaler_row.reshape((1, -1))
+            )
+
+            W_mask = (
+                torch.zeros_like(W_metric) == 1
+            )  ## initialize a mask to be all False
+            if prune_n != 0:
+                # structured n:m sparsity
+                for ii in range(W_metric.shape[1]):
+                    if ii % prune_m == 0:
+                        tmp = W_metric[:, ii : (ii + prune_m)].float()
+                        W_mask.scatter_(
+                            1,
+                            ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1],
+                            True,
+                        )
+            else:
+                sort_res = torch.sort(W_metric, dim=-1, stable=True)
+
+                if args.use_variant:
+                    # wanda variant
+                    tmp_metric = torch.cumsum(sort_res[0], dim=1)
+                    sum_before = W_metric.sum(dim=1)
+
+                    alpha = 0.4
+                    alpha_hist = [0.0, 0.8]
+                    W_mask, cur_sparsity = return_given_alpha(
+                        alpha, sort_res, W_metric, tmp_metric, sum_before
+                    )
+                    while (torch.abs(cur_sparsity - args.sparsity_ratio) > 0.001) and (
+                        alpha_hist[1] - alpha_hist[0] >= 0.001
+                    ):
+                        if cur_sparsity > args.sparsity_ratio:
+                            alpha_new = (alpha + alpha_hist[0]) / 2.0
+                            alpha_hist[1] = alpha
+                        else:
+                            alpha_new = (alpha + alpha_hist[1]) / 2.0
+                            alpha_hist[0] = alpha
+
+                        alpha = alpha_new
+                        W_mask, cur_sparsity = return_given_alpha(
+                            alpha, sort_res, W_metric, tmp_metric, sum_before
+                        )
+                    print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}")
+                else:
+                    # unstructured pruning
+                    indices = sort_res[1][
+                        :, : int(W_metric.shape[1] * args.sparsity_ratio)
+                    ]
+                    W_mask.scatter_(1, indices, True)
+
+            subset[name].weight.data[W_mask] = 0  ## set weights to zero
+
+        for j in range(args.nsamples):
+            with torch.no_grad():
+                outs[j] = layer(
+                    inps[j].unsqueeze(0),
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                )[0]
+        inps, outs = outs, inps
+
+    model.config.use_cache = use_cache
+    torch.cuda.empty_cache()
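
The Wanda step in the hunk above boils down to scoring each weight by |W| times the calibration activation norm of its input column, then zeroing the lowest-scoring fraction of each row. A minimal, self-contained sketch of that scoring and masking (using placeholder `weight` and `act_norm` tensors, not this module's `WrappedGPT` statistics):

import torch

torch.manual_seed(0)
weight = torch.randn(8, 16)   # stands in for subset[name].weight.data
act_norm = torch.rand(16)     # stands in for wrapped_layers[name].scaler_row
sparsity_ratio = 0.5

# Wanda importance: |W| scaled by the per-input-column activation norm
W_metric = weight.abs() * torch.sqrt(act_norm.reshape(1, -1))

# unstructured pruning: mask the lowest-scoring fraction of each row
sort_res = torch.sort(W_metric, dim=-1, stable=True)
indices = sort_res[1][:, : int(W_metric.shape[1] * sparsity_ratio)]
W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
W_mask.scatter_(1, indices, True)
weight[W_mask] = 0

print(f"achieved sparsity: {(weight == 0).float().mean().item():.2f}")
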
@@ -0,0 +1,41 @@
+from typing import List, Union
+
+import numpy as np
+
+
+def layer_load_balance_score(
+    number_of_tokens_dispatched: Union[List[int], np.ndarray],
+    number_of_experts: int,
+) -> float:
+    """
+    Calculate the load balance score for one layer of the MoE model.
+
+    Args:
+        number_of_tokens_dispatched: List[int]
+        number_of_experts: int
+
+    Returns:
+        float: The load balance score
+    """
+    if len(number_of_tokens_dispatched) != number_of_experts:
+        raise ValueError(
+            f"The number of tokens dispatched ({len(number_of_tokens_dispatched)}) must match the number of experts ({number_of_experts})"
+        )
+
+    number_of_tokens_dispatched = np.array(number_of_tokens_dispatched)
+    mu = number_of_tokens_dispatched.mean()
+    sigma = np.sqrt(((number_of_tokens_dispatched - mu) ** 2).mean())
+    return sigma / mu
+
+
+def model_load_balance_score(layer_load_balance_scores: List[float]) -> float:
+    """
+    Calculate the load balance score for the whole model.
+
+    Args:
+        layer_load_balance_scores: List[float]
+
+    Returns:
+        float: The load balance score
+    """
+    return np.array(layer_load_balance_scores).mean()
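
The per-layer score above is the coefficient of variation (sigma / mu) of the number of tokens dispatched to each expert, and the model-level score is its mean over layers. A toy check, assuming these functions land in `fusion_bench.method.moe_pruner.utils.score` (the import path is inferred from the files-changed list and is not confirmed by this hunk):

from fusion_bench.method.moe_pruner.utils.score import (
    layer_load_balance_score,
    model_load_balance_score,
)

# 0.0 means perfectly balanced routing; larger values mean more skew.
balanced = layer_load_balance_score([100, 100, 100, 100], number_of_experts=4)
skewed = layer_load_balance_score([370, 10, 10, 10], number_of_experts=4)
print(balanced, skewed)  # 0.0, ~1.56

# the model-level score is simply the mean of the per-layer scores
print(model_load_balance_score([balanced, skewed]))
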
@@ -1,5 +1,6 @@
 # flake8: noqa F401
 from .llama_magnitude_prune import MagnitudePruningForLlama
 from .llama_random_prune import RandomPruningForLlama
+from .llama_sparsegpt_prune import SparseGPTPruningForLlama
 from .llama_wanda_prune import WandaPruningForLlama
 from .magnitude_diff_pruning import MagnitudeDiffPruningAlgorithm
@@ -0,0 +1,223 @@
+import logging
+from typing import Dict, Optional
+
+import torch
+from torch import Tensor, nn
+from tqdm.auto import tqdm
+from transformers import LlamaForCausalLM
+
+from fusion_bench import BaseAlgorithm
+from fusion_bench.method.pruning.prune_utils import (
+    PruningType,
+    compute_sparsity,
+    find_linear_layers,
+    semistructured_magnitude_prune_,
+    unstructured_magnitude_prune_,
+)
+from fusion_bench.method.pruning.sparsegpt_utils import SparseGPT
+from fusion_bench.method.pruning.wanda_utils.data import get_loaders
+from fusion_bench.method.pruning.wanda_utils.prune import prepare_calibration_input
+from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.modelpool import CausalLMPool
+from fusion_bench.utils import timeit_context
+from fusion_bench.utils.cache_utils import cache_to_disk
+
+log = logging.getLogger(__name__)
+
+
+class SparseGPTPruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
+    def __init__(
+        self,
+        *,
+        nsamples: int,
+        seed: int,
+        use_variant: bool,
+        prune_type: PruningType,
+        device: str,
+        dtype: str,
+        sparsity_ratio: float,
+        n: int,
+        m: int,
+        model_save_path: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Initialize the SparseGPTPruningForLlama class.
+
+        Args:
+            nsamples (int): Number of samples for calibration.
+            seed (int): Random seed.
+            use_variant (bool): Whether to use a variant of the pruning method.
+            prune_type (PruningType): Type of pruning to perform.
+            device (str): Device to use for computation.
+            dtype (str): Data type to use for computation.
+            sparsity_ratio (float): Sparsity ratio for pruning.
+            n (int): Number of elements to keep in semi-structured pruning.
+            m (int): Number of elements in a group for semi-structured pruning.
+            model_save_path (Optional[str]): Path to save the pruned model.
+            **kwargs: Additional arguments.
+        """
+        super().__init__(**kwargs)
+        self.nsamples = nsamples
+        self.seed = seed
+        self.use_variant = use_variant
+        self.prune_type = prune_type
+        self.device = device
+        self.dtype = dtype
+        self.sparsity_ratio = sparsity_ratio
+        self.n = n
+        self.m = m
+        self.model_save_path = model_save_path
+
+    def run(self, modelpool: CausalLMPool):
+        # load pre-trained model or the first model in the pool
+        with self.profile("load_model"):
+            model = modelpool.load_pretrained_or_first_model()
+            model.seqlen = model.config.max_position_embeddings
+            tokenizer = modelpool.load_tokenizer(use_fast=False)
+
+        if not isinstance(model, (LlamaForCausalLM,)):
+            log.warning(f"Model type {type(model)} may not supported.")
+
+        inps, outs, attention_mask, position_ids = self.prepare_calibration_data(
+            model, tokenizer
+        )
+
+        self.prune_using_calibration_data_(
+            model,
+            inps=inps,
+            outs=outs,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+        )
+
+        if self.model_save_path is not None:
+            with timeit_context(f"Saving pruned model to {self.model_save_path}"):
+                tokenizer.save_pretrained(self.model_save_path)
+                model.save_pretrained(self.model_save_path)
+        return model
+
+    def _prepare_calibration_data(self, model, tokenizer):
+        """
+        Prepare calibration data for pruning.
+
+        Args:
+            model (LlamaForCausalLM): Model to be pruned.
+            tokenizer: Tokenizer for the model.
+
+        Returns:
+            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
+        """
+        with timeit_context("loading calibration data"):
+            dataloader, _ = get_loaders(
+                "c4",
+                nsamples=self.nsamples,
+                seed=self.seed,
+                seqlen=model.seqlen,
+                tokenizer=tokenizer,
+            )
+
+        with torch.no_grad():
+            # collect input to the first layer
+            inps, outs, attention_mask, position_ids = prepare_calibration_input(
+                model, dataloader, self.device
+            )
+        return inps, outs, attention_mask, position_ids
+
+    def prepare_calibration_data(self, model: LlamaForCausalLM, tokenizer):
+        """
+        Prepare calibration data for pruning with caching.
+
+        Args:
+            model (LlamaForCausalLM): Model to be pruned.
+            tokenizer: Tokenizer for the model.
+
+        Returns:
+            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
+        """
+
+        @cache_to_disk(
+            f"outputs/cache/{model.config.name_or_path.split('/')[-1]}/calibration_data.pkl"
+        )
+        def _prepare_calibration_data(model, tokenizer):
+            return self._prepare_calibration_data(model, tokenizer)
+
+        return _prepare_calibration_data(model, tokenizer)
+
+    @torch.no_grad()
+    def prune_using_calibration_data_(
+        self,
+        model: LlamaForCausalLM,
+        *,
+        inps,
+        outs,
+        attention_mask,
+        position_ids,
+    ):
+        layers = model.model.layers
+        for layer_indx, layer in tqdm(
+            enumerate(layers),
+            "Pruning Layers",
+            total=len(layers),
+            dynamic_ncols=True,
+        ):
+            layer = layers[layer_indx]
+            if f"model.layers.{layer_indx}" in model.hf_device_map:
+                dev = model.hf_device_map[f"model.layers.{layer_indx}"]
+                print(f"layer {layer_indx} device {dev}")
+                inps, outs, attention_mask, position_ids = (
+                    inps.to(dev),
+                    outs.to(dev),
+                    attention_mask.to(dev),
+                    position_ids.to(dev),
+                )
+
+            subset = find_linear_layers(layer, layers=[nn.Linear])
+
+            gpts: Dict[str, SparseGPT] = {}
+            for name in subset:
+                gpts[name] = SparseGPT(subset[name])
+
+            def add_batch(name):
+                def tmp(_, inp, out):
+                    gpts[name].add_batch(inp[0].data, out.data)
+
+                return tmp
+
+            handles = []
+            for name in gpts:
+                handles.append(subset[name].register_forward_hook(add_batch(name)))
+
+            for j in range(self.nsamples):
+                outs[j] = layer(
+                    inps[j].unsqueeze(0),
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                )[0]
+            for h in handles:
+                h.remove()
+
+            for name in gpts:
+                print(layer_indx, name)
+                print("Pruning ...")
+
+                gpts[name].fasterprune(
+                    self.sparsity_ratio,
+                    prune_n=self.n,
+                    prune_m=self.m,
+                    percdamp=0.01,
+                    blocksize=128,
+                )
+                gpts[name].free()
+
+            for j in range(self.nsamples):
+                outs[j] = layer(
+                    inps[j].unsqueeze(0),
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                )[0]
+
+            layers[layer_indx] = layer
+            torch.cuda.empty_cache()
+
+        inps, outs = outs, inps
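
A construction sketch for the new `SparseGPTPruningForLlama` algorithm. The import is grounded in the updated `fusion_bench/method/pruning/__init__.py` above; the argument values are illustrative placeholders, and the model pool wiring is assumed to come from the shipped `CausalLMPool` configs rather than being shown in this diff:

from fusion_bench.method.pruning import SparseGPTPruningForLlama

# illustrative values only; defaults ship in
# fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml (not shown here)
algorithm = SparseGPTPruningForLlama(
    nsamples=128,
    seed=0,
    use_variant=False,
    prune_type="unstructured",  # a PruningType value; the enum members are not visible in this diff
    device="cuda",
    dtype="bfloat16",
    sparsity_ratio=0.5,
    n=0,
    m=0,
    model_save_path=None,
)

# `modelpool` is expected to be a CausalLMPool (e.g. built from one of the new
# modelpool/CausalLMPool/*.yaml configs); run() loads the model, collects C4
# calibration activations, prunes each layer with SparseGPT, and returns the
# pruned model.
# pruned_model = algorithm.run(modelpool)
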
@@ -0,0 +1 @@
+from .sparsegpt import SparseGPT