fusion-bench 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. fusion_bench/method/__init__.py +4 -0
  2. fusion_bench/method/fw_merging/__init__.py +2 -0
  3. fusion_bench/method/fw_merging/fw_hard.py +448 -0
  4. fusion_bench/method/fw_merging/fw_soft.py +519 -0
  5. fusion_bench/method/fw_merging/utils.py +331 -0
  6. fusion_bench/method/moe_pruner/__init__.py +7 -0
  7. fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
  8. fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
  9. fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
  10. fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
  11. fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
  12. fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
  13. fusion_bench/method/moe_pruner/utils/data.py +154 -0
  14. fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
  15. fusion_bench/method/moe_pruner/utils/prune.py +313 -0
  16. fusion_bench/method/moe_pruner/utils/score.py +41 -0
  17. fusion_bench/method/pruning/__init__.py +1 -0
  18. fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
  19. fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
  20. fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
  21. fusion_bench/method/pruning/wanda_utils/data.py +33 -14
  22. fusion_bench/method/randes/__init__.py +15 -0
  23. fusion_bench/method/randes/base_algorithm.py +1013 -0
  24. fusion_bench/method/randes/modelsoup.py +126 -0
  25. fusion_bench/method/randes/task_arithmetic.py +318 -0
  26. fusion_bench/method/sparselo/sparselo.py +20 -2
  27. fusion_bench/method/tall_mask/__init__.py +1 -0
  28. fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
  29. fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
  30. fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
  31. fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
  32. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
  33. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
  34. fusion_bench/programs/fabric_fusion_program.py +5 -0
  35. fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
  36. fusion_bench/utils/__init__.py +1 -0
  37. fusion_bench/utils/data.py +1 -1
  38. fusion_bench/utils/lazy_state_dict.py +268 -0
  39. fusion_bench/utils/parameters.py +33 -0
  40. fusion_bench/utils/state_dict_arithmetic.py +74 -2
  41. fusion_bench/utils/type.py +1 -0
  42. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +6 -2
  43. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +77 -21
  44. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
  45. fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
  46. fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
  47. fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
  48. fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
  49. fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
  50. fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
  51. fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
  52. fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
  53. fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
  54. fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
  55. fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
  56. fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
  57. fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
  58. fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
  59. fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
  60. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
  61. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  62. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  63. fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
  64. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
  65. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
  66. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
  67. fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
  68. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
  69. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
  70. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
  71. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
  72. fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
  73. fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
  74. fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
  75. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,126 @@
1
+ import logging
2
+ from copy import deepcopy
3
+
4
+ import torch
5
+
6
+ from fusion_bench.modelpool import BaseModelPool
7
+ from fusion_bench.utils.parameters import count_parameters
8
+ from fusion_bench.utils.state_dict_arithmetic import (
9
+ state_dict_mul,
10
+ )
11
+
12
+ from .base_algorithm import SuperposedAlgorithmBase, compare_models
13
+
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
class SuperposedModelSoupAlgorithm(
    SuperposedAlgorithmBase,
):
    """Compress a pool of fine-tuned models via superposition ("superposed
    model soup"), then retrieve each model and optionally rescale it.

    ``self.ms_mode`` controls how each retrieved state dict is rescaled:

    * ``"average"``  — scale by ``1 / num_models`` (soup-style averaging);
    * ``"original"`` — use the retrieved weights unscaled.
    """

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """Superpose, retrieve and rebuild every model in ``modelpool``.

        Returns a dict with the retrieved ``models`` and the compression
        ``metadata`` produced by ``_compress_and_retrieve``.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)

        log.info(
            f"Compressing models using superposed model soup.\n"
            f"Models: {modelpool.model_names}"
        )
        restored = {}

        # Gather the fine-tuned state dicts and the pretrained reference.
        state_dicts = self._load_state_dicts(modelpool)
        with self.profile("load model"):
            pretrained_model = modelpool.load_model("_pretrained_")
        absorber = self._compute_absorber(state_dicts, pretrained_model)
        if absorber is not None:
            state_dicts["absorber"] = absorber

        with self.profile("compress and retrieve"):
            retrieved_state_dicts, metadata = self._compress_and_retrieve(
                deepcopy(state_dicts), mode="superposed_model_soup"
            )

        num_models = len(modelpool.model_names)
        with self.profile("retrieve models"):
            for model_name in modelpool.model_names:
                if self.ms_mode == "original":
                    retrieved_state_dict = retrieved_state_dicts[model_name]
                elif self.ms_mode == "average":
                    # Soup-style average: every model contributes equally.
                    retrieved_state_dict = state_dict_mul(
                        retrieved_state_dicts[model_name], 1 / num_models
                    )
                else:
                    raise ValueError(f"Unsupported ms_mode: {self.ms_mode}")

                retrieved_model = modelpool.load_model(
                    model_name
                )  # TODO: avoid repeated loading
                # FIXME: for 'all' mode — retrieval may leave a spurious
                # leading unit dimension; drop it before loading.
                for key, tensor in retrieved_state_dict.items():
                    if tensor.shape[0] == 1:
                        retrieved_state_dict[key] = tensor.squeeze(0)
                retrieved_model.load_state_dict(retrieved_state_dict)
                restored[model_name] = retrieved_model

                if self.debug >= 1:
                    with self.profile("metadata"):
                        if torch.cuda.is_available():
                            # Move everything to the GPU to speed up the
                            # similarity computations below.
                            to_cuda = lambda sd: {
                                k: v.cuda() for k, v in sd.items()
                            }
                            retrieved_state_dicts[model_name] = to_cuda(
                                retrieved_state_dicts[model_name]
                            )
                            state_dicts[model_name] = to_cuda(
                                state_dicts[model_name]
                            )
                            retrieved_state_dict = to_cuda(retrieved_state_dict)

                        target_layers = metadata["target_layers"]
                        # Retrieval quality restricted to the superposed
                        # target layers.
                        metadata["superposed_model_retrieval_similarity"][
                            model_name
                        ] = compare_models(
                            retrieved_state_dicts[model_name],
                            state_dicts[model_name],
                            target_layers,
                        )
                        metadata["superposed_model_svd_subspace_similarities"][
                            model_name
                        ] = self._compute_svd_subspace_similarities(
                            state_dicts[model_name],
                            retrieved_state_dicts[model_name],
                            target_layers,
                        )
                        # Overall retrieval quality on the full state dict.
                        metadata["model_retrieval_similarity"][
                            model_name
                        ] = compare_models(
                            retrieved_state_dict, state_dicts[model_name]
                        )
                        metadata["model_svd_subspace_similarities"][
                            model_name
                        ] = self._compute_svd_subspace_similarities(
                            state_dicts[model_name], retrieved_state_dict
                        )
                        # Release the (possibly CUDA) copies.
                        del retrieved_state_dicts[model_name]
                        del state_dicts[model_name]
                        del retrieved_state_dict

        with self.profile("metadata"):
            if self.debug >= 0:
                (
                    metadata["trainable_param_count_pretrained_model"],
                    metadata["active_param_count_pretrained_model"],
                ) = count_parameters(pretrained_model)
                (
                    metadata["trainable_param_count_retrieved_model"],
                    metadata["active_param_count_retrieved_model"],
                ) = count_parameters(restored[modelpool.model_names[0]])
                print(
                    f"Total storage (Gbs) for retrieval and original: {metadata['total_gb_retrieved']} | {metadata['total_gb_original']}"
                )
        self.print_profile_summary()
        return {"models": restored, "metadata": metadata}
@@ -0,0 +1,318 @@
1
+ import logging
2
+ import os
3
+ from collections import OrderedDict
4
+ from copy import deepcopy
5
+ from typing import Optional
6
+
7
+ import torch
8
+
9
+ from fusion_bench.modelpool import BaseModelPool
10
+ from fusion_bench.utils.parameters import count_parameters
11
+ from fusion_bench.utils.state_dict_arithmetic import (
12
+ state_dict_add,
13
+ state_dict_mul,
14
+ state_dict_sub,
15
+ )
16
+
17
+ from .base_algorithm import SuperposedAlgorithmBase, compare_models
18
+
19
+ log = logging.getLogger(__name__)
20
+
21
+
22
class SuperposedTaskArithmeticAlgorithm(
    SuperposedAlgorithmBase,
):
    """Superposed task arithmetic: superpose the task vectors (fine-tuned
    weights minus pretrained weights), retrieve each one, scale it by
    ``scaling_factor`` and add it back onto the pretrained weights.
    """

    _config_mapping = SuperposedAlgorithmBase._config_mapping | {
        "scaling_factor": "scaling_factor",
        "model_path": "model_path",
    }

    def __init__(
        self,
        scaling_factor: float,
        model_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            scaling_factor: coefficient applied to each retrieved task vector
                before it is added back to the pretrained weights.
            model_path: optional file path; when set, the dict of retrieved
                models is saved there with ``torch.save``.
        """
        super().__init__(**kwargs)
        self.scaling_factor = scaling_factor
        self.model_path = model_path

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """Compress, retrieve and rebuild every model in ``modelpool``.

        Returns a dict with the retrieved ``models`` and the compression
        ``metadata``.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)

        log.info("Compressing models using superposed task arithmetic.")
        with self.profile("load model"):
            pretrained_model = modelpool.load_model("_pretrained_")

        # Build one task vector (fine-tuned minus pretrained) per model.
        task_vectors = {}
        models = {}
        for model_name in modelpool.model_names:
            with self.profile("load model"):
                model = modelpool.load_model(model_name)
            if self.verbose >= 1:
                for layer_name, layer in model.state_dict(keep_vars=True).items():
                    log.info(f"{layer_name} | {layer.shape}")
            task_vectors[model_name] = state_dict_sub(
                model.state_dict(keep_vars=True),
                pretrained_model.state_dict(keep_vars=True),
            )

        with self.profile("compress and retrieve"):
            retrieved_task_vectors, metadata = self._compress_and_retrieve(
                deepcopy(task_vectors), mode="superposed_task_arithmetic"
            )

        with self.profile("retrieve models"):
            for model_name in modelpool.model_names:
                retrieved_task_vector = state_dict_mul(
                    retrieved_task_vectors[model_name], self.scaling_factor
                )
                retrieved_state_dict = state_dict_add(
                    pretrained_model.state_dict(keep_vars=True),
                    retrieved_task_vector,
                )
                retrieved_model = deepcopy(pretrained_model)
                # FIXME: for 'all' mode — retrieval may leave a spurious
                # leading unit dimension; drop it before loading.
                for k, v in retrieved_state_dict.items():
                    if v.shape[0] == 1:
                        retrieved_state_dict[k] = v.squeeze(0)
                retrieved_model.load_state_dict(retrieved_state_dict)
                models[model_name] = retrieved_model

                if self.debug >= 1:
                    with self.profile("metadata"):
                        model = modelpool.load_model(model_name)
                        # BUGFIX: ``model_state_dict`` was previously assigned
                        # only inside the CUDA branch below, raising NameError
                        # on CPU-only machines when debug >= 1.
                        model_state_dict = dict(model.state_dict(keep_vars=True))
                        if torch.cuda.is_available():
                            retrieved_state_dict = {
                                k: v.cuda() for k, v in retrieved_state_dict.items()
                            }
                            retrieved_task_vectors[model_name] = {
                                k: v.cuda()
                                for k, v in retrieved_task_vectors[
                                    model_name
                                ].items()
                            }
                            task_vectors[model_name] = {
                                k: v.cuda()
                                for k, v in task_vectors[model_name].items()
                            }
                            model_state_dict = {
                                k: v.cuda() for k, v in model_state_dict.items()
                            }
                        # Superposition retrieval quality on the task vectors.
                        metadata["task_vector_retrieval_similarity"][
                            model_name
                        ] = compare_models(
                            retrieved_task_vectors[model_name],
                            task_vectors[model_name],
                        )
                        metadata["task_vector_svd_subspace_similarities"][
                            model_name
                        ] = self._compute_svd_subspace_similarities(
                            task_vectors[model_name],
                            retrieved_task_vectors[model_name],
                        )
                        # Overall retrieval performance on the full model.
                        metadata["model_retrieval_similarity"][
                            model_name
                        ] = compare_models(retrieved_state_dict, model_state_dict)
                        metadata["model_svd_subspace_similarities"][
                            model_name
                        ] = self._compute_svd_subspace_similarities(
                            model_state_dict, retrieved_state_dict
                        )
                        # Delete the (possibly CUDA) tensors.
                        del (
                            retrieved_state_dict,
                            retrieved_task_vectors[model_name],
                            task_vectors[model_name],
                            model_state_dict,
                        )

        with self.profile("metadata"):
            if self.debug >= 0:
                (
                    metadata["trainable_param_count_pretrained_model"],
                    metadata["active_param_count_pretrained_model"],
                ) = count_parameters(pretrained_model)
                (
                    metadata["trainable_param_count_retrieved_model"],
                    metadata["active_param_count_retrieved_model"],
                ) = count_parameters(models[modelpool.model_names[0]])
                # Presumably the pretrained backbone must be stored alongside
                # the compressed task vectors, hence counted in the totals —
                # TODO confirm against the storage accounting.
                metadata["nonzero_parameter_count"] += metadata[
                    "active_param_count_pretrained_model"
                ]
                metadata["total_gb_retrieved"] += metadata["total_gb_original"]
                print(
                    f"Total storage (Gbs) for retrieval and original: {metadata['total_gb_retrieved']} | {metadata['total_gb_original']}"
                )

        if self.model_path is not None:
            os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
            torch.save(models, self.model_path)

        self.print_profile_summary()
        return {"models": models, "metadata": metadata}
156
+
157
+
158
class SuperposedTaskArithmeticLoRAAlgorithm(
    SuperposedAlgorithmBase,
):
    """Superposed task arithmetic over LoRA adapters: the per-model LoRA
    factor matrices are superposed and retrieved, then each retrieved adapter
    is multiplied out (``B @ A``), scaled, and merged into a copy of the
    pretrained model.
    """

    _config_mapping = SuperposedAlgorithmBase._config_mapping | {
        "scaling_factor": "scaling_factor",
        "model_path": "model_path",
    }

    def __init__(
        self,
        scaling_factor: float,
        model_path: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            scaling_factor: coefficient applied to each merged LoRA update
                before it is added to the pretrained weights.
            model_path: optional file path; when set, the dict of retrieved
                models is saved there with ``torch.save``.
        """
        super().__init__(**kwargs)
        self.scaling_factor = scaling_factor
        self.model_path = model_path

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """Compress, retrieve and merge every LoRA adapter in ``modelpool``."""
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)

        log.info("Compressing models using superposed task arithmetic.")
        with self.profile("load model"):
            pretrained_model = modelpool.load_model("_pretrained_")

        # Collect the LoRA state dicts (load_model returns a mapping of
        # LoRA parameter name -> tensor here, iterated via .items()).
        loras = {}
        models = {}
        for model_name in modelpool.model_names:
            with self.profile("load model"):
                model = modelpool.load_model(model_name)
            if self.verbose >= 1:
                for layer_name, layer in model.items():
                    log.info(f"{layer_name} | {layer.shape}")
            loras[model_name] = model

        with self.profile("compress and retrieve"):
            retrieved_loras, metadata = self._compress_and_retrieve(
                deepcopy(loras), mode="superposed_task_arithmetic"
            )

        with self.profile("retrieve models"):
            for model_name in modelpool.model_names:
                retrieved_lora = retrieved_loras[model_name]
                retrieved_model = deepcopy(pretrained_model)
                sd = retrieved_model.state_dict(keep_vars=True)

                # Group the retrieved LoRA factors by the base parameter they
                # target: strip the adapter prefix (first two name components)
                # and the lora_A/lora_B marker (second-to-last component).
                lora_factors = {}
                for layer_name, layer in retrieved_lora.items():
                    parts = layer_name.split(".")
                    base_name = ".".join(parts[2:-2] + [parts[-1]])
                    lora_factors.setdefault(base_name, []).append(layer)

                # Multiply each factor pair out to a full-rank update
                # (``B @ A``; assumes the A factor is collected first —
                # TODO confirm the key order for every adapter layout).
                lora_updates = OrderedDict()
                for base_name, factors in lora_factors.items():
                    lora_updates[base_name] = factors[-1] @ factors[0]

                # BUGFIX: was ``self.config.scaling_factor`` although
                # __init__ stores the value as ``self.scaling_factor`` (the
                # sibling SuperposedTaskArithmeticAlgorithm uses the
                # attribute as well).
                scaled_updates = state_dict_mul(lora_updates, self.scaling_factor)
                for layer_name, layer in scaled_updates.items():
                    sd[layer_name] += layer
                retrieved_model.load_state_dict(sd)
                models[model_name] = retrieved_model
                # BUGFIX: the debug block below referenced
                # ``retrieved_state_dict`` without it ever being defined
                # (its assignment was commented out), raising NameError when
                # debug >= 1; the merged state dict is that retrieval.
                retrieved_state_dict = sd

                if self.debug >= 1:
                    with self.profile("metadata"):
                        model = modelpool.load_model(model_name)
                        # NOTE(review): earlier in this method load_model()
                        # returned a plain mapping (iterated via .items());
                        # calling .state_dict() here follows the original
                        # code — confirm which interface the pool provides.
                        model_state_dict = dict(model.state_dict(keep_vars=True))
                        if torch.cuda.is_available():
                            retrieved_state_dict = {
                                k: v.cuda() for k, v in retrieved_state_dict.items()
                            }
                            retrieved_loras[model_name] = {
                                k: v.cuda()
                                for k, v in retrieved_loras[model_name].items()
                            }
                            loras[model_name] = {
                                k: v.cuda() for k, v in loras[model_name].items()
                            }
                            model_state_dict = {
                                k: v.cuda() for k, v in model_state_dict.items()
                            }
                        # Superposition retrieval quality on the LoRA factors
                        # (restricted to the superposed target layers).
                        target_layers = metadata["target_layers"]
                        metadata["lora_retrieval_similarity"][
                            model_name
                        ] = compare_models(
                            retrieved_loras[model_name],
                            loras[model_name],
                            target_layers,
                        )
                        metadata["lora_svd_subspace_similarities"][
                            model_name
                        ] = self._compute_svd_subspace_similarities(
                            loras[model_name],
                            retrieved_loras[model_name],
                            target_layers,
                        )
                        # Overall retrieval performance on the merged model.
                        metadata["model_retrieval_similarity"][
                            model_name
                        ] = compare_models(retrieved_state_dict, model_state_dict)
                        metadata["model_svd_subspace_similarities"][
                            model_name
                        ] = self._compute_svd_subspace_similarities(
                            model_state_dict, retrieved_state_dict
                        )
                        # Delete the (possibly CUDA) tensors.
                        del (
                            retrieved_state_dict,
                            retrieved_loras[model_name],
                            loras[model_name],
                            model_state_dict,
                        )

        with self.profile("metadata"):
            if self.debug >= 0:
                (
                    metadata["trainable_param_count_pretrained_model"],
                    metadata["active_param_count_pretrained_model"],
                ) = count_parameters(pretrained_model)
                (
                    metadata["trainable_param_count_retrieved_model"],
                    metadata["active_param_count_retrieved_model"],
                ) = count_parameters(models[modelpool.model_names[0]])
                # Presumably the pretrained backbone must be stored alongside
                # the compressed adapters, hence counted in the totals —
                # TODO confirm against the storage accounting.
                metadata["nonzero_parameter_count"] += metadata[
                    "active_param_count_pretrained_model"
                ]
                metadata["total_gb_retrieved"] += metadata["total_gb_original"]
                print(
                    f"Total storage (Gbs) for retrieval and original: {metadata['total_gb_retrieved']} | {metadata['total_gb_original']}"
                )

        if self.model_path is not None:
            os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
            torch.save(models, self.model_path)

        self.print_profile_summary()
        return {"models": models, "metadata": metadata}
@@ -32,6 +32,7 @@ from fusion_bench.models.modeling_losparse_llama.losparse_linear import LoSparse
32
32
  from fusion_bench.models.modeling_losparse_llama.utils import convert_to_losparse_llama
33
33
  from fusion_bench.utils import cache_to_disk, print_parameters, timeit_context
34
34
  from fusion_bench.utils.devices import get_device
35
+ from fusion_bench.utils.dtype import get_dtype
35
36
 
36
37
  log = logging.getLogger(__name__)
37
38
 
@@ -141,6 +142,7 @@ class SparseLoForLlama(BaseAlgorithm, SimpleProfilerMixin):
141
142
 
142
143
  @override
143
144
  def run(self, modelpool: CausalLMPool):
145
+ self.modelpool = modelpool
144
146
  if self.seed is not None:
145
147
  L.seed_everything(self.seed)
146
148
 
@@ -691,12 +693,16 @@ class IterativeSparseLoForLlama(SparseLoForLlama):
691
693
  "num_iterations": "num_iterations",
692
694
  }
693
695
 
694
- def __init__(self, num_iterations: int, **kwargs):
696
+ def __init__(
697
+ self, num_iterations: int, use_reference_model: bool = False, **kwargs
698
+ ):
695
699
  super().__init__(**kwargs)
696
700
  self.num_iterations = num_iterations
701
+ self.use_reference_model = use_reference_model
697
702
 
698
703
  @override
699
704
  def run(self, modelpool):
705
+ self.modelpool = modelpool
700
706
  if self.seed is not None:
701
707
  L.seed_everything(self.seed)
702
708
 
@@ -802,13 +808,25 @@ class IterativeSparseLoForLlama(SparseLoForLlama):
802
808
  @torch.no_grad()
803
809
  def iterative_magnitude_prune_(self, model):
804
810
  layers: nn.ModuleList = model.model.layers
811
+ if self.use_reference_model:
812
+ reference_model = self.modelpool.load_model(
813
+ "reference_model", torch_dtype="float16"
814
+ )
815
+ reference_layers: nn.ModuleList = reference_model.model.layers
805
816
  for layer_idx, layer in tqdm(
806
817
  enumerate(layers), "Pruning Layers", total=len(layers), dynamic_ncols=True
807
818
  ):
808
819
  for name, linear in layer.named_modules():
809
820
  if isinstance(linear, LoSparseLinear):
810
821
  log.info(f"Magnitude Pruning {name}")
811
- W = linear.weight.data.clone()
822
+ W = (
823
+ linear.weight.data.clone()
824
+ if not self.use_reference_model
825
+ else reference_layers[layer_idx]
826
+ .get_submodule(name)
827
+ .weight.data.clone()
828
+ .to(linear.weight.data.device)
829
+ )
812
830
  if self.prune_type == PruningType.UNSTRUCTURED:
813
831
  unstructured_magnitude_prune_(
814
832
  linear.weight.data,
@@ -0,0 +1 @@
1
+ from .task_arithmetic import TallMaskTaskArithmeticAlgorithm
@@ -0,0 +1,133 @@
1
+ """
2
+ Modified from https://github.com/Zhou-Hangyu/randes/tree/main/benchmark/fusion_bench
3
+ """
4
+
5
+ import logging
6
+ from collections import OrderedDict
7
+ from copy import deepcopy
8
+
9
+ import torch
10
+
11
+ from fusion_bench import BaseAlgorithm
12
+ from fusion_bench.mixins import SimpleProfilerMixin
13
+ from fusion_bench.modelpool import BaseModelPool
14
+ from fusion_bench.utils.state_dict_arithmetic import (
15
+ state_dict_add,
16
+ state_dict_binary_mask,
17
+ state_dict_diff_abs,
18
+ state_dict_hadmard_product,
19
+ state_dict_mul,
20
+ state_dict_sub,
21
+ state_dict_sum,
22
+ )
23
+
24
+ log = logging.getLogger(__name__)
25
+
26
+
27
def generate_task_masks(
    multi_task_vector: OrderedDict,
    ft_task_vector: OrderedDict,
    pretrained_task_vector: OrderedDict,
    tall_mask_lambda: float = 1.0,
) -> OrderedDict:
    """Adopted from https://github.com/nik-dim/tall_masks/tree/master.
    Generate task-specific TALL masks.
    TALL masks are generated as: mask_t = |theta_0 - theta_t| > |theta_mt - theta_t| * lambda

    Args:
        multi_task_vector: multi-task vector
        ft_task_vector: individual theta_t (fine-tuned weights)
        pretrained_task_vector: theta_0 (pre-trained weight)
        tall_mask_lambda: hyper-parameter lambda for generating TALL masks
    Returns:
        final_mask: generated TALL masks with the given lambda
    """
    # BUGFIX (idiom): was ``print(f"Generating TALL masks.")`` — an f-string
    # with no placeholders (ruff F541) that bypassed the module logger.
    log.info("Generating TALL masks.")

    # NOTE(review): the docstring implies all three inputs live in the same
    # (raw-weight) space, but the caller in TallMaskTaskArithmeticAlgorithm
    # passes task vectors for the first two arguments and raw pretrained
    # weights for the third — confirm the intended semantics.

    # Generate masks by comparing the l1 distance between
    # |theta_0 - theta_t| and |theta_mt - theta_t|.
    diff_pt_ft = state_dict_diff_abs(pretrained_task_vector, ft_task_vector)
    diff_multi_ft = state_dict_diff_abs(multi_task_vector, ft_task_vector)
    # Compare the l1 distances, scaled with hyper-parameter lambda.
    final_mask = state_dict_binary_mask(
        diff_pt_ft,
        state_dict_mul(diff_multi_ft, tall_mask_lambda),
    )
    # Cast the boolean masks to float so they can be multiplied with weights.
    for key, value in final_mask.items():
        final_mask[key] = value.float()
    return final_mask
59
+
60
+
61
class TallMaskTaskArithmeticAlgorithm(
    BaseAlgorithm,
    SimpleProfilerMixin,
):
    """Task arithmetic with TALL masks: build per-model task vectors, sum
    them into a multi-task vector, derive a binary TALL mask per model, and
    reconstruct each model from the masked multi-task vector.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "tall_mask_lambda": "tall_mask_lambda",
        "debug": "debug",
        "verbose": "verbose",
    }

    def __init__(
        self,
        tall_mask_lambda: float,
        debug: int = 0,
        verbose: int = 0,
        **kwargs,
    ):
        """
        Args:
            tall_mask_lambda: hyper-parameter lambda for generating TALL masks.
            debug: debug level (stored but not read in this class).
            verbose: when >= 1, log every layer name and shape while loading.
        """
        super().__init__(**kwargs)
        self.tall_mask_lambda = tall_mask_lambda
        self.debug = debug
        self.verbose = verbose

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        """Reconstruct every model in ``modelpool`` via TALL-masked task
        arithmetic; returns the models and a ``None`` metadata slot."""
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(models=modelpool)

        log.info("Compressing models using tall mask task arithmetic.")
        with self.profile("load model"):
            pretrained_model = modelpool.load_model("_pretrained_")

        # Per-model task vectors: fine-tuned minus pretrained weights.
        # (Cleanup: removed a dead ``task_vector = None`` initialization.)
        task_vectors = {}
        models = {}
        for model_name in modelpool.model_names:
            with self.profile("load model"):
                model = modelpool.load_model(model_name)
            if self.verbose >= 1:
                for layer_name, layer in model.state_dict(keep_vars=True).items():
                    log.info(f"{layer_name} | {layer.shape}")
            task_vectors[model_name] = state_dict_sub(
                model.state_dict(keep_vars=True),
                pretrained_model.state_dict(keep_vars=True),
            )

        multi_task_vector = state_dict_sum(list(task_vectors.values()))

        # One binary mask per model.  (Cleanup: the dict was previously
        # pre-seeded with empty dicts that were immediately overwritten.)
        tall_masks = {}
        for model_name in modelpool.model_names:
            tall_masks[model_name] = generate_task_masks(
                multi_task_vector,
                task_vectors[model_name],
                pretrained_model.state_dict(keep_vars=True),
                tall_mask_lambda=self.tall_mask_lambda,
            )

        with self.profile("compress and retrieve"):
            for model_name in modelpool.model_names:
                # Mask the shared multi-task vector and add it back onto the
                # pretrained weights to reconstruct this model.
                retrieved_task_vector = state_dict_hadmard_product(
                    tall_masks[model_name], multi_task_vector
                )
                retrieved_state_dict = state_dict_add(
                    pretrained_model.state_dict(keep_vars=True),
                    retrieved_task_vector,
                )
                retrieved_model = deepcopy(pretrained_model)
                retrieved_model.load_state_dict(retrieved_state_dict)
                models[model_name] = retrieved_model

        self.print_profile_summary()
        return {"models": models, "metadata": None}
@@ -0,0 +1,15 @@
1
+ from fusion_bench import BaseModelPool
2
+ from fusion_bench.utils import instantiate
3
+ from fusion_bench.utils.lazy_state_dict import LazyStateDict
4
+
5
+
6
class LazyStateDictPool(BaseModelPool):
    """A model pool whose members are returned as lazily-loaded state dicts
    rather than fully materialized models."""

    def load_model(self, model_name_or_config: str, *args, **kwargs) -> LazyStateDict:
        """Resolve a registered model name (or an inline config/path) and
        return a ``LazyStateDict`` for string paths, or the instantiated
        object for structured configs."""
        # A registered name maps to its stored checkpoint config; anything
        # unknown is treated as the config itself.
        checkpoint = (
            self._models[model_name_or_config]
            if model_name_or_config in self._models
            else model_name_or_config
        )
        if isinstance(checkpoint, str):
            # A plain string is a checkpoint path — wrap it lazily.
            return LazyStateDict(checkpoint, *args, **kwargs)
        return instantiate(checkpoint, *args, **kwargs)
@@ -0,0 +1,15 @@
1
+ """
2
+ This is a direct copy of the DeepSeek-V2-Lite model from HuggingFace https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/tree/main
3
+ """
4
+
5
+ from .configuration_deepseek import DeepseekV2Config
6
+ from .modeling_deepseek import (
7
+ DeepseekV2ForCausalLM,
8
+ DeepseekV2ForSequenceClassification,
9
+ DeepseekV2MLP,
10
+ DeepseekV2Model,
11
+ DeepseekV2MoE,
12
+ DeepseekV2DecoderLayer,
13
+ )
14
+ from .modeling_deepseek import MoEGate as DeepseekV2MoEGate
15
+ from .tokenization_deepseek_fast import DeepseekTokenizerFast