fusion-bench 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
- fusion_bench/method/base_algorithm.py +1 -0
- fusion_bench/method/dawe/dawe_for_clip.py +1 -1
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/pwe_moe/module.py +2 -7
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/simple_average.py +3 -2
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/method/task_singular_vector/TSVM.py +238 -25
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
- fusion_bench/mixins/hydra_config.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +25 -1
- fusion_bench/mixins/serialization.py +18 -2
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/models/parameter_dict.py +6 -1
- fusion_bench/programs/fabric_fusion_program.py +14 -5
- fusion_bench/taskpool/base_pool.py +1 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/taskpool/dummy.py +6 -4
- fusion_bench/utils/__init__.py +2 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/pylogger.py +28 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/METADATA +8 -2
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/RECORD +104 -44
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/fabric_model_fusion.yaml +2 -2
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import List, Tuple, cast
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
from tqdm.auto import tqdm
|
|
8
|
+
from transformers import LlamaForCausalLM, PreTrainedModel
|
|
9
|
+
|
|
10
|
+
from fusion_bench import timeit_context
|
|
11
|
+
|
|
12
|
+
from .data import get_loaders
|
|
13
|
+
from .layerwrapper import WrappedGPT
|
|
14
|
+
|
|
15
|
+
log = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def find_layers(module, layers=None, name=""):
    """
    Recursively find the layers of a certain type in a module.

    Matching is by exact type (``type(module) in layers``), so subclasses of the
    requested layer types are not matched. Once a module matches, its children are
    not searched further.

    Args:
        module (nn.Module): PyTorch module.
        layers (list, optional): Layer types to find. Defaults to ``[nn.Linear]``.
        name (str): Dotted name of ``module`` relative to the root of the search.

    Returns:
        dict: Mapping from dotted module name to module for every layer of the
        given type(s) within ``module``.
    """
    # ``None`` sentinel instead of a shared mutable default list
    if layers is None:
        layers = [nn.Linear]
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name != "" else child_name
        res.update(find_layers(child, layers=layers, name=full_name))
    return res
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def check_sparsity(model):
    """
    Check the sparsity of the model by counting the number of zero weights.

    Only the modules returned by ``find_layers`` (``nn.Linear`` by default) inside
    ``model.model.layers`` are inspected. The per-layer zero ratio is printed as
    each layer is processed.

    Args:
        model (PreTrainedModel): The model to check sparsity for. Must expose
            ``model.config.use_cache`` and ``model.model.layers``.

    Returns:
        float: Fraction of zero entries across all inspected weights, or ``0.0``
        when no linear weights were found.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        count = 0
        total_params = 0
        for i, layer in enumerate(model.model.layers):
            sub_count = 0
            sub_params = 0
            for module in find_layers(layer).values():
                W = module.weight.data
                sub_count += (W == 0).sum().item()
                sub_params += W.numel()

            count += sub_count
            total_params += sub_params

            if sub_params == 0:
                # a layer without linear sub-modules previously raised ZeroDivisionError
                print(f"layer {i} has no linear layers")
                continue
            print(f"layer {i} sparsity {float(sub_count)/sub_params:.6f}")
    finally:
        # restore the caller's setting even if inspection fails
        model.config.use_cache = use_cache

    return float(count) / total_params if total_params else 0.0
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def prepare_calibration_input(
    model: "PreTrainedModel",
    dataloader: List[Tuple[Tensor, Tensor]],
    device: torch.device,
):
    """
    Prepare the calibration input for the model by collecting input to the first layer.

    The first decoder layer is temporarily replaced by a ``Catcher`` module. It
    records the hidden states it receives, together with the ``attention_mask``,
    ``position_ids`` and ``position_embeddings`` keyword arguments, and then aborts
    the forward pass. As a result, only the part of the model before the first
    decoder layer runs for each calibration batch.

    Args:
        model (PreTrainedModel): The model to prepare calibration input for. Must
            expose ``model.model.layers``, ``model.seqlen`` and
            ``model.config.hidden_size``.
        dataloader (List[Tuple[Tensor, Tensor]]): Calibration batches. ``batch[0]``
            is fed to the model.
        device (torch.device): Device for the collected tensors. Overridden by
            ``model.hf_device_map["model.embed_tokens"]`` when that entry exists.

    Returns:
        Tuple: ``(inps, outs, attention_mask, position_ids, position_embeddings)``.

        - ``inps`` holds the captured layer-0 inputs. It has
          ``max(128, len(dataloader))`` rows; rows beyond the number of batches
          stay zero.
        - ``outs`` is a zero buffer of the same shape.
        - The last three values are the keyword arguments captured from the last
          batch, or ``None`` if the model did not pass them.
    """

    class _StopForward(Exception):
        """Raised by the catcher to abort the forward pass once inputs are recorded."""

    # iterating twice is not needed, but we need the length up front
    if not hasattr(dataloader, "__len__"):
        dataloader = list(dataloader)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    if hasattr(model, "hf_device_map") and "model.embed_tokens" in model.hf_device_map:
        device = model.hf_device_map["model.embed_tokens"]

    dtype = next(iter(model.parameters())).dtype
    # Keep the historical minimum of 128 rows for callers that index the buffer.
    # Grow it when more batches are supplied; a fixed 128 used to overflow with IndexError.
    num_rows = max(128, len(dataloader))
    inps = torch.zeros(
        (num_rows, model.seqlen, model.config.hidden_size),
        dtype=dtype,
        device=device,
        requires_grad=False,
    )
    cache = {
        "i": 0,
        "attention_mask": None,
        "position_ids": None,
        "position_embeddings": None,
    }

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            # collect the keyword arguments the model passes to its decoder layers
            cache["attention_mask"] = kwargs.get("attention_mask")
            cache["position_ids"] = kwargs.get("position_ids")
            cache["position_embeddings"] = kwargs.get("position_embeddings")
            # A dedicated exception type, so genuine ValueErrors raised by the model are not swallowed.
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        with torch.no_grad():
            for batch in dataloader:
                try:
                    model(batch[0].to(device))
                except _StopForward:
                    pass
    finally:
        # always undo the temporary patching, even if the forward pass fails
        layers[0] = layers[0].module
        model.config.use_cache = use_cache

    outs = torch.zeros_like(inps)
    return (
        inps,
        outs,
        cache["attention_mask"],
        cache["position_ids"],
        cache["position_embeddings"],
    )
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before):
    """
    Return the mask and current sparsity given an alpha value (Wanda variant).

    In each row, entries are pruned in ascending metric order for as long as their
    cumulative metric stays within ``alpha`` times the row's total metric.

    Args:
        alpha (float): Fraction of each row's total metric that may be pruned.
        sort_res: ``(values, indices)`` from sorting ``W_metric`` along the last dim.
        W_metric (Tensor): The 2-D weight importance metric.
        tmp_metric (Tensor): Row-wise cumulative sum of the sorted metric values.
        sum_before (Tensor): Row-wise total of ``W_metric``.

    Returns:
        Tuple[Tensor, Tensor]: Boolean mask (``True`` = prune) and the fraction of
        pruned entries as a 0-dim tensor.
    """
    thres_cumsum = sum_before * alpha
    sort_mask = tmp_metric <= thres_cumsum.reshape((-1, 1))
    num_pruned = sort_mask.sum(dim=1, keepdim=True)
    # A row can have num_pruned == 0 when even its smallest metric exceeds the budget.
    # Gathering at index -1 would then fail, so clamp the index and exclude those rows explicitly.
    thres = torch.gather(sort_res[0], dim=1, index=(num_pruned - 1).clamp(min=0))
    W_mask = (W_metric <= thres) & (num_pruned > 0)
    cur_sparsity = W_mask.sum() / W_mask.numel()
    return W_mask, cur_sparsity
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _to_device(value, dev):
    """Move a tensor, or a tuple/list of tensors, to ``dev``; ``None`` passes through."""
    if value is None:
        return None
    if isinstance(value, (tuple, list)):
        return type(value)(v.to(dev) for v in value)
    return value.to(dev)


def _run_layer_on_samples(layer, inps, outs, nsamples, layer_kwargs):
    """Feed the first ``nsamples`` rows of ``inps`` through ``layer``, writing into ``outs``."""
    with torch.no_grad():
        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), **layer_kwargs)[0]


def _wanda_prune_mask(W_metric, args, prune_n, prune_m):
    """Return a boolean mask (``True`` = prune) for one weight matrix from its Wanda metric."""
    W_mask = torch.zeros_like(W_metric, dtype=torch.bool)  # all False

    if prune_n != 0:
        # structured n:m sparsity: within each group of ``prune_m`` consecutive
        # columns, prune the ``prune_n`` entries with the smallest metric
        for ii in range(0, W_metric.shape[1], prune_m):
            tmp = W_metric[:, ii : (ii + prune_m)].float()
            W_mask.scatter_(
                1,
                ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1],
                True,
            )
        return W_mask

    sort_res = torch.sort(W_metric, dim=-1, stable=True)

    if args.use_variant:
        # wanda variant: binary-search the per-row cumulative-metric budget ``alpha``
        # until the resulting sparsity is within 0.001 of the target
        tmp_metric = torch.cumsum(sort_res[0], dim=1)
        sum_before = W_metric.sum(dim=1)

        alpha = 0.4
        alpha_hist = [0.0, 0.8]
        W_mask, cur_sparsity = return_given_alpha(
            alpha, sort_res, W_metric, tmp_metric, sum_before
        )
        while (torch.abs(cur_sparsity - args.sparsity_ratio) > 0.001) and (
            alpha_hist[1] - alpha_hist[0] >= 0.001
        ):
            if cur_sparsity > args.sparsity_ratio:
                alpha_new = (alpha + alpha_hist[0]) / 2.0
                alpha_hist[1] = alpha
            else:
                alpha_new = (alpha + alpha_hist[1]) / 2.0
                alpha_hist[0] = alpha

            alpha = alpha_new
            W_mask, cur_sparsity = return_given_alpha(
                alpha, sort_res, W_metric, tmp_metric, sum_before
            )
        print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}")
        return W_mask

    # unstructured pruning: drop the lowest-metric fraction of every row
    indices = sort_res[1][:, : int(W_metric.shape[1] * args.sparsity_ratio)]
    W_mask.scatter_(1, indices, True)
    return W_mask


def llama_prune_wanda_(
    args,
    model: LlamaForCausalLM,
    tokenizer,
    device=torch.device("cuda:0"),
    prune_n=0,
    prune_m=0,
):
    """
    Perform Wanda pruning on a Llama model, in place.

    For each decoder layer, in order:

    1. Forward hooks accumulate input statistics for every ``nn.Linear`` in a
       ``WrappedGPT`` wrapper.
    2. Weights are scored as ``|W| * sqrt(scaler_row)`` and the lowest-scoring
       entries are zeroed.
    3. The layer's outputs on the calibration set are recomputed to become the
       next layer's inputs.

    Args:
        args: The arguments for pruning. Reads ``nsamples``, ``seed``,
            ``use_variant`` and ``sparsity_ratio``.
        model (LlamaForCausalLM): The model to prune. Must expose ``seqlen``.
        tokenizer: The tokenizer to use for calibration.
        device (torch.device, optional): The device to use for pruning.
            Defaults to ``torch.device("cuda:0")``.
        prune_n (int, optional): Number of elements to prune in each block of
            ``prune_m`` columns; ``0`` selects unstructured pruning. Defaults to 0.
        prune_m (int, optional): The size of each block. Defaults to 0.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        with timeit_context("loading calibdation data"):
            dataloader, _ = get_loaders(
                "c4",
                nsamples=args.nsamples,
                seed=args.seed,
                seqlen=model.seqlen,
                tokenizer=tokenizer,
            )

        with torch.no_grad():
            # Collect input to the first layer. prepare_calibration_input returns five values;
            # unpacking only four used to raise ValueError.
            (
                inps,
                outs,
                attention_mask,
                position_ids,
                position_embeddings,
            ) = prepare_calibration_input(model, dataloader, device)

        device_map = getattr(model, "hf_device_map", None) or {}
        layers = model.model.layers
        for i, layer in enumerate(layers):
            subset = find_layers(layer)

            layer_key = f"model.layers.{i}"
            if layer_key in device_map:
                # When the device map spans multiple GPUs (e.g. llama-30B / llama-65B),
                # follow each layer's device.
                dev = device_map[layer_key]
                inps, outs = inps.to(dev), outs.to(dev)
                attention_mask = _to_device(attention_mask, dev)
                position_ids = _to_device(position_ids, dev)
                position_embeddings = _to_device(position_embeddings, dev)

            layer_kwargs = {
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            }
            if position_embeddings is not None:
                layer_kwargs["position_embeddings"] = position_embeddings

            wrapped_layers = {name: WrappedGPT(subset[name]) for name in subset}

            def add_batch(name):
                def tmp(_, inp, out):
                    cast(WrappedGPT, wrapped_layers[name]).add_batch(
                        inp[0].data, out.data
                    )

                return tmp

            handles = [
                subset[name].register_forward_hook(add_batch(name))
                for name in wrapped_layers
            ]
            try:
                _run_layer_on_samples(layer, inps, outs, args.nsamples, layer_kwargs)
            finally:
                for h in handles:
                    h.remove()

            for name in subset:
                print(f"pruning layer {i} name {name}")
                W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(
                    wrapped_layers[name].scaler_row.reshape((1, -1))
                )
                W_mask = _wanda_prune_mask(W_metric, args, prune_n, prune_m)
                subset[name].weight.data[W_mask] = 0  ## set weights to zero

            # outputs of the pruned layer become the inputs of the next layer
            _run_layer_on_samples(layer, inps, outs, args.nsamples, layer_kwargs)
            inps, outs = outs, inps
    finally:
        model.config.use_cache = use_cache
    torch.cuda.empty_cache()
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from typing import List, Union
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def layer_load_balance_score(
    number_of_tokens_dispatched: Union[List[int], np.ndarray],
    number_of_experts: int,
) -> float:
    """
    Calculate the load balance score for one layer of the MoE model.

    The score is the coefficient of variation of the per-expert token counts: the
    population standard deviation divided by the mean. ``0.0`` means every expert
    received the same number of tokens; larger values mean more imbalance.

    Args:
        number_of_tokens_dispatched: Number of tokens routed to each expert, one
            entry per expert.
        number_of_experts: Expected number of experts in the layer.

    Returns:
        float: The load balance score. ``0.0`` when no tokens were dispatched at all,
        since every expert then received the same count of zero.

    Raises:
        ValueError: If the number of counts does not match ``number_of_experts``, or
            if the layer has no experts.
    """
    if len(number_of_tokens_dispatched) != number_of_experts:
        raise ValueError(
            f"The number of tokens dispatched ({len(number_of_tokens_dispatched)}) must match the number of experts ({number_of_experts})"
        )
    if number_of_experts == 0:
        raise ValueError(
            "Cannot compute a load balance score for a layer with no experts"
        )

    counts = np.asarray(number_of_tokens_dispatched, dtype=np.float64)
    mu = counts.mean()
    if mu == 0:
        # avoid 0/0 -> nan: an all-zero dispatch is perfectly (trivially) balanced
        return 0.0
    sigma = np.sqrt(((counts - mu) ** 2).mean())  # population std (ddof=0)
    return float(sigma / mu)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def model_load_balance_score(layer_load_balance_scores: List[float]) -> float:
    """
    Aggregate per-layer load balance scores into one score for the whole model.

    Args:
        layer_load_balance_scores: One load balance score per layer.

    Returns:
        float: The arithmetic mean of the per-layer scores.
    """
    scores = np.asarray(layer_load_balance_scores)
    return np.mean(scores)
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
# flake8: noqa F401
|
|
2
2
|
from .llama_magnitude_prune import MagnitudePruningForLlama
|
|
3
3
|
from .llama_random_prune import RandomPruningForLlama
|
|
4
|
+
from .llama_sparsegpt_prune import SparseGPTPruningForLlama
|
|
4
5
|
from .llama_wanda_prune import WandaPruningForLlama
|
|
5
6
|
from .magnitude_diff_pruning import MagnitudeDiffPruningAlgorithm
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Dict, Optional
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor, nn
|
|
6
|
+
from tqdm.auto import tqdm
|
|
7
|
+
from transformers import LlamaForCausalLM
|
|
8
|
+
|
|
9
|
+
from fusion_bench import BaseAlgorithm
|
|
10
|
+
from fusion_bench.method.pruning.prune_utils import (
|
|
11
|
+
PruningType,
|
|
12
|
+
compute_sparsity,
|
|
13
|
+
find_linear_layers,
|
|
14
|
+
semistructured_magnitude_prune_,
|
|
15
|
+
unstructured_magnitude_prune_,
|
|
16
|
+
)
|
|
17
|
+
from fusion_bench.method.pruning.sparsegpt_utils import SparseGPT
|
|
18
|
+
from fusion_bench.method.pruning.wanda_utils.data import get_loaders
|
|
19
|
+
from fusion_bench.method.pruning.wanda_utils.prune import prepare_calibration_input
|
|
20
|
+
from fusion_bench.mixins import SimpleProfilerMixin
|
|
21
|
+
from fusion_bench.modelpool import CausalLMPool
|
|
22
|
+
from fusion_bench.utils import timeit_context
|
|
23
|
+
from fusion_bench.utils.cache_utils import cache_to_disk
|
|
24
|
+
|
|
25
|
+
log = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SparseGPTPruningForLlama(BaseAlgorithm, SimpleProfilerMixin):
    """
    Prune a Llama-style causal language model with SparseGPT.

    Calibration samples are loaded from C4 and their inputs to the first decoder
    layer are collected. Then each decoder layer is pruned in turn:

    1. Forward hooks accumulate statistics for every ``nn.Linear`` in a
       ``SparseGPT`` object.
    2. ``fasterprune`` prunes the weights.
    3. The layer's outputs on the calibration set are recomputed to feed the
       next layer.
    """

    def __init__(
        self,
        *,
        nsamples: int,
        seed: int,
        use_variant: bool,
        prune_type: PruningType,
        device: str,
        dtype: str,
        sparsity_ratio: float,
        n: int,
        m: int,
        model_save_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the SparseGPTPruningForLlama class.

        Args:
            nsamples (int): Number of samples for calibration.
            seed (int): Random seed.
            use_variant (bool): Whether to use a variant of the pruning method.
            prune_type (PruningType): Type of pruning to perform.
            device (str): Device to use for computation.
            dtype (str): Data type to use for computation.
            sparsity_ratio (float): Sparsity ratio for pruning.
            n (int): Number of elements to keep in semi-structured pruning.
            m (int): Number of elements in a group for semi-structured pruning.
            model_save_path (Optional[str]): Path to save the pruned model.
            **kwargs: Additional arguments.
        """
        super().__init__(**kwargs)
        self.nsamples = nsamples
        self.seed = seed
        self.use_variant = use_variant
        self.prune_type = prune_type
        self.device = device
        self.dtype = dtype
        self.sparsity_ratio = sparsity_ratio
        self.n = n
        self.m = m
        self.model_save_path = model_save_path

    def run(self, modelpool: CausalLMPool):
        """
        Load the model from ``modelpool``, prune it in place and optionally save it.

        Args:
            modelpool (CausalLMPool): Pool providing the model (pre-trained model, or
                the first model) and its tokenizer.

        Returns:
            The pruned model.
        """
        # load pre-trained model or the first model in the pool
        with self.profile("load_model"):
            model = modelpool.load_pretrained_or_first_model()
            model.seqlen = model.config.max_position_embeddings
            tokenizer = modelpool.load_tokenizer(use_fast=False)

        if not isinstance(model, (LlamaForCausalLM,)):
            log.warning(f"Model type {type(model)} may not supported.")

        inps, outs, attention_mask, position_ids = self.prepare_calibration_data(
            model, tokenizer
        )

        self.prune_using_calibration_data_(
            model,
            inps=inps,
            outs=outs,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        if self.model_save_path is not None:
            with timeit_context(f"Saving pruned model to {self.model_save_path}"):
                tokenizer.save_pretrained(self.model_save_path)
                model.save_pretrained(self.model_save_path)
        return model

    def _prepare_calibration_data(self, model, tokenizer):
        """
        Prepare calibration data for pruning.

        Args:
            model (LlamaForCausalLM): Model to be pruned.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
        """
        with timeit_context("loading calibration data"):
            dataloader, _ = get_loaders(
                "c4",
                nsamples=self.nsamples,
                seed=self.seed,
                seqlen=model.seqlen,
                tokenizer=tokenizer,
            )

        with torch.no_grad():
            # collect input to the first layer
            inps, outs, attention_mask, position_ids = prepare_calibration_input(
                model, dataloader, self.device
            )
        return inps, outs, attention_mask, position_ids

    def prepare_calibration_data(self, model: LlamaForCausalLM, tokenizer):
        """
        Prepare calibration data for pruning with caching.

        The result is cached on disk under ``outputs/cache/<model name>/``. The cache
        file name includes ``nsamples``, ``seed`` and ``seqlen``, so changing any of
        them no longer silently reuses stale calibration data.

        Args:
            model (LlamaForCausalLM): Model to be pruned.
            tokenizer: Tokenizer for the model.

        Returns:
            Tuple: Calibration data (inputs, outputs, attention mask, position IDs).
        """
        model_name = model.config.name_or_path.split("/")[-1]
        cache_path = (
            f"outputs/cache/{model_name}/"
            f"calibration_data_n{self.nsamples}_seed{self.seed}_seqlen{model.seqlen}.pkl"
        )

        @cache_to_disk(cache_path)
        def _prepare_calibration_data(model, tokenizer):
            return self._prepare_calibration_data(model, tokenizer)

        return _prepare_calibration_data(model, tokenizer)

    @torch.no_grad()
    def _forward_calibration_samples(
        self, layer, inps, outs, attention_mask, position_ids
    ):
        """Run ``layer`` on the first ``self.nsamples`` inputs, writing results into ``outs``."""
        for j in range(self.nsamples):
            outs[j] = layer(
                inps[j].unsqueeze(0),
                attention_mask=attention_mask,
                position_ids=position_ids,
            )[0]

    @torch.no_grad()
    def prune_using_calibration_data_(
        self,
        model: LlamaForCausalLM,
        *,
        inps,
        outs,
        attention_mask,
        position_ids,
    ):
        """
        Prune every decoder layer of ``model`` in place with SparseGPT.

        Args:
            model (LlamaForCausalLM): Model whose ``model.model.layers`` are pruned.
            inps: Calibration inputs to the first decoder layer.
            outs: Buffer of the same shape as ``inps`` used for layer outputs.
            attention_mask: Attention mask passed to each layer (may be ``None``).
            position_ids: Position IDs passed to each layer (may be ``None``).
        """
        # ``hf_device_map`` only exists when the model was loaded with a device map;
        # accessing it unconditionally used to raise AttributeError
        device_map = getattr(model, "hf_device_map", None) or {}
        layers = model.model.layers
        for layer_indx, layer in tqdm(
            enumerate(layers),
            "Pruning Layers",
            total=len(layers),
            dynamic_ncols=True,
        ):
            layer_key = f"model.layers.{layer_indx}"
            if layer_key in device_map:
                dev = device_map[layer_key]
                log.info(f"layer {layer_indx} device {dev}")
                inps = inps.to(dev)
                outs = outs.to(dev)
                # attention_mask / position_ids can be None; calling .to() on them used to crash
                if attention_mask is not None:
                    attention_mask = attention_mask.to(dev)
                if position_ids is not None:
                    position_ids = position_ids.to(dev)

            subset = find_linear_layers(layer, layers=[nn.Linear])

            gpts: Dict[str, SparseGPT] = {name: SparseGPT(subset[name]) for name in subset}

            def add_batch(name):
                def tmp(_, inp, out):
                    gpts[name].add_batch(inp[0].data, out.data)

                return tmp

            handles = [
                subset[name].register_forward_hook(add_batch(name)) for name in gpts
            ]
            try:
                self._forward_calibration_samples(
                    layer, inps, outs, attention_mask, position_ids
                )
            finally:
                # never leave statistics-collecting hooks attached to the model
                for h in handles:
                    h.remove()

            for name in gpts:
                log.info(f"Pruning layer {layer_indx} {name} ...")
                gpts[name].fasterprune(
                    self.sparsity_ratio,
                    prune_n=self.n,
                    prune_m=self.m,
                    percdamp=0.01,
                    blocksize=128,
                )
                gpts[name].free()

            # outputs of the pruned layer become the inputs of the next layer
            self._forward_calibration_samples(
                layer, inps, outs, attention_mask, position_ids
            )
            torch.cuda.empty_cache()

            inps, outs = outs, inps
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .sparsegpt import SparseGPT
|