fusion-bench 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
- fusion_bench/method/base_algorithm.py +1 -0
- fusion_bench/method/dawe/dawe_for_clip.py +1 -1
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/pwe_moe/module.py +2 -7
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/simple_average.py +3 -2
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/method/task_singular_vector/TSVM.py +238 -25
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
- fusion_bench/mixins/hydra_config.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +25 -1
- fusion_bench/mixins/serialization.py +18 -2
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/models/parameter_dict.py +6 -1
- fusion_bench/programs/fabric_fusion_program.py +14 -5
- fusion_bench/taskpool/base_pool.py +1 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/taskpool/dummy.py +6 -4
- fusion_bench/utils/__init__.py +2 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/pylogger.py +28 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/METADATA +8 -2
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/RECORD +104 -44
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/fabric_model_fusion.yaml +2 -2
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/fusion_bench/method/randes/modelsoup.py
@@ -0,0 +1,126 @@
+import logging
+from copy import deepcopy
+
+import torch
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils.parameters import count_parameters
+from fusion_bench.utils.state_dict_arithmetic import (
+    state_dict_mul,
+)
+
+from .base_algorithm import SuperposedAlgorithmBase, compare_models
+
+log = logging.getLogger(__name__)
+
+
+class SuperposedModelSoupAlgorithm(
+    SuperposedAlgorithmBase,
+):
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool):
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(models=modelpool)
+
+        log.info(
+            f"Compressing models using superposed model soup.\n"
+            f"Models: {modelpool.model_names}"
+        )
+        models = {}
+
+        # load state dicts
+        state_dicts = self._load_state_dicts(modelpool)
+        with self.profile("load model"):
+            pretrained_model = modelpool.load_model("_pretrained_")
+        absorber_state_dict = self._compute_absorber(state_dicts, pretrained_model)
+        if absorber_state_dict is not None:
+            state_dicts["absorber"] = absorber_state_dict
+
+        with self.profile("compress and retrieve"):
+            retrieved_state_dicts, metadata = self._compress_and_retrieve(
+                deepcopy(state_dicts), mode="superposed_model_soup"
+            )
+
+        with self.profile("retrieve models"):
+            for model_idx, model_name in enumerate(modelpool.model_names):
+                if self.ms_mode == "average":
+                    coefficient = 1 / len(modelpool.model_names)
+                    retrieved_state_dict = state_dict_mul(
+                        retrieved_state_dicts[model_name], coefficient
+                    )
+                elif self.ms_mode == "original":
+                    retrieved_state_dict = retrieved_state_dicts[model_name]
+                else:
+                    raise ValueError(f"Unsupported ms_mode: {self.ms_mode}")
+                retrieved_model = modelpool.load_model(
+                    model_name
+                )  # TODO: avoid repeated loading
+                # FIXME: for 'all' mode
+                for k, v in retrieved_state_dict.items():
+                    if v.shape[0] == 1:
+                        retrieved_state_dict[k] = v.squeeze(0)
+                retrieved_model.load_state_dict(retrieved_state_dict)
+                models[model_name] = retrieved_model
+                if self.debug >= 1:
+                    with self.profile("metadata"):
+                        if torch.cuda.is_available():
+                            retrieved_state_dicts[model_name] = {
+                                k: v.cuda()
+                                for k, v in retrieved_state_dicts[model_name].items()
+                            }
+                            state_dicts[model_name] = {
+                                k: v.cuda() for k, v in state_dicts[model_name].items()
+                            }
+                            retrieved_state_dict = {
+                                k: v.cuda() for k, v in retrieved_state_dict.items()
+                            }
+
+                        target_layers = metadata["target_layers"]
+                        # focus on the superposition retrieval performance on the target layers
+                        metadata["superposed_model_retrieval_similarity"][
+                            model_name
+                        ] = compare_models(
+                            retrieved_state_dicts[model_name],
+                            state_dicts[model_name],
+                            target_layers,
+                        )
+                        metadata["superposed_model_svd_subspace_similarities"][
+                            model_name
+                        ] = self._compute_svd_subspace_similarities(
+                            state_dicts[model_name],
+                            retrieved_state_dicts[model_name],
+                            target_layers,
+                        )
+                        # overall retrieval performance
+                        metadata["model_retrieval_similarity"][model_name] = (
+                            compare_models(
+                                retrieved_state_dict, state_dicts[model_name]
+                            )
+                        )
+                        metadata["model_svd_subspace_similarities"][model_name] = (
+                            self._compute_svd_subspace_similarities(
+                                state_dicts[model_name], retrieved_state_dict
+                            )
+                        )
+                        # delete the cuda tensors
+                        del (
+                            retrieved_state_dicts[model_name],
+                            state_dicts[model_name],
+                            retrieved_state_dict,
+                        )
+        with self.profile("metadata"):
+            if self.debug >= 0:
+                (
+                    metadata["trainable_param_count_pretrained_model"],
+                    metadata["active_param_count_pretrained_model"],
+                ) = count_parameters(pretrained_model)
+                (
+                    metadata["trainable_param_count_retrieved_model"],
+                    metadata["active_param_count_retrieved_model"],
+                ) = count_parameters(models[modelpool.model_names[0]])
+                print(
+                    f"Total storage (Gbs) for retrieval and original: {metadata['total_gb_retrieved']} | {metadata['total_gb_original']}"
+                )
+        self.print_profile_summary()
+        return {"models": models, "metadata": metadata}
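For context: in the `ms_mode="average"` branch above, each state dict retrieved from the superposition is rescaled by `1 / len(modelpool.model_names)`, which amounts to a uniform model soup. A minimal sketch of that average on plain tensor state dicts (the helper name and toy values are illustrative, not part of the package):

```python
import torch

def uniform_soup(state_dicts):
    """Elementwise mean of N state dicts: accumulate each tensor scaled by 1/N."""
    n = len(state_dicts)
    soup = {k: torch.zeros_like(v) for k, v in state_dicts[0].items()}
    for sd in state_dicts:
        for k, v in sd.items():
            soup[k] += v / n
    return soup

# Two toy "models" with a single parameter tensor each.
sd1 = {"weight": torch.tensor([1.0, 2.0])}
sd2 = {"weight": torch.tensor([3.0, 4.0])}
print(uniform_soup([sd1, sd2])["weight"])  # tensor([2., 3.])
```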
--- /dev/null
+++ b/fusion_bench/method/randes/task_arithmetic.py
@@ -0,0 +1,318 @@
+import logging
+import os
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Optional
+
+import torch
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils.parameters import count_parameters
+from fusion_bench.utils.state_dict_arithmetic import (
+    state_dict_add,
+    state_dict_mul,
+    state_dict_sub,
+)
+
+from .base_algorithm import SuperposedAlgorithmBase, compare_models
+
+log = logging.getLogger(__name__)
+
+
+class SuperposedTaskArithmeticAlgorithm(
+    SuperposedAlgorithmBase,
+):
+    _config_mapping = SuperposedAlgorithmBase._config_mapping | {
+        "scaling_factor": "scaling_factor",
+        "model_path": "model_path",
+    }
+
+    def __init__(
+        self,
+        scaling_factor: float,
+        model_path: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.scaling_factor = scaling_factor
+        self.model_path = model_path
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool):
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(models=modelpool)
+
+        log.info("Compressing models using superposed task arithmetic.")
+        task_vector = None
+        with self.profile("load model"):
+            pretrained_model = modelpool.load_model("_pretrained_")
+
+        # Calculate the task vector superposition
+        task_vectors = {}
+        models = {}
+        for model_name in modelpool.model_names:
+            with self.profile("load model"):
+                model = modelpool.load_model(model_name)
+            for layer_name, layer in model.state_dict(keep_vars=True).items():
+                if self.verbose >= 1:
+                    log.info(f"{layer_name} | {layer.shape}")
+            task_vector = state_dict_sub(
+                model.state_dict(keep_vars=True),
+                pretrained_model.state_dict(keep_vars=True),
+            )
+            task_vectors[model_name] = task_vector
+
+        with self.profile("compress and retrieve"):
+            retrieved_task_vectors, metadata = self._compress_and_retrieve(
+                deepcopy(task_vectors), mode="superposed_task_arithmetic"
+            )
+        with self.profile("retrieve models"):
+            for model_name in modelpool.model_names:
+                retrieved_task_vector = state_dict_mul(
+                    retrieved_task_vectors[model_name], self.scaling_factor
+                )
+                retrieved_state_dict = state_dict_add(
+                    pretrained_model.state_dict(keep_vars=True), retrieved_task_vector
+                )
+                retrieved_model = deepcopy(pretrained_model)
+                # FIXME: for 'all' mode
+                for k, v in retrieved_state_dict.items():
+                    if v.shape[0] == 1:
+                        retrieved_state_dict[k] = v.squeeze(0)
+                retrieved_model.load_state_dict(retrieved_state_dict)
+                models[model_name] = retrieved_model
+
+                if self.debug >= 1:
+                    with self.profile("metadata"):
+                        model = modelpool.load_model(model_name)
+                        if torch.cuda.is_available():
+                            retrieved_state_dict = {
+                                k: v.cuda() for k, v in retrieved_state_dict.items()
+                            }
+                            retrieved_task_vectors[model_name] = {
+                                k: v.cuda()
+                                for k, v in retrieved_task_vectors[model_name].items()
+                            }
+                            task_vectors[model_name] = {
+                                k: v.cuda() for k, v in task_vectors[model_name].items()
+                            }
+                            model_state_dict = {
+                                k: v.cuda()
+                                for k, v in model.state_dict(keep_vars=True).items()
+                            }
+                        # target_layers = metadata['target_layers']
+                        metadata["task_vector_retrieval_similarity"][model_name] = (
+                            compare_models(
+                                retrieved_task_vectors[model_name],
+                                task_vectors[model_name],
+                            )
+                        )
+                        metadata["task_vector_svd_subspace_similarities"][
+                            model_name
+                        ] = self._compute_svd_subspace_similarities(
+                            task_vectors[model_name], retrieved_task_vectors[model_name]
+                        )
+                        # overall retrieval performance
+                        metadata["model_retrieval_similarity"][model_name] = (
+                            compare_models(retrieved_state_dict, model_state_dict)
+                        )
+                        metadata["model_svd_subspace_similarities"][model_name] = (
+                            self._compute_svd_subspace_similarities(
+                                model_state_dict, retrieved_state_dict
+                            )
+                        )
+                        # delete the cuda tensors
+                        del (
+                            retrieved_state_dict,
+                            retrieved_task_vectors[model_name],
+                            task_vectors[model_name],
+                            model_state_dict,
+                        )
+
+        with self.profile("metadata"):
+            if self.debug >= 0:
+                (
+                    metadata["trainable_param_count_pretrained_model"],
+                    metadata["active_param_count_pretrained_model"],
+                ) = count_parameters(pretrained_model)
+                (
+                    metadata["trainable_param_count_retrieved_model"],
+                    metadata["active_param_count_retrieved_model"],
+                ) = count_parameters(models[modelpool.model_names[0]])
+                metadata["nonzero_parameter_count"] += metadata[
+                    "active_param_count_pretrained_model"
+                ]
+                metadata["total_gb_retrieved"] += metadata["total_gb_original"]
+                print(
+                    f"Total storage (Gbs) for retrieval and original: {metadata['total_gb_retrieved']} | {metadata['total_gb_original']}"
+                )
+
+        if self.model_path is not None:
+            os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
+            torch.save(models, self.model_path)
+
+        self.print_profile_summary()
+        return {"models": models, "metadata": metadata}
+
+
+class SuperposedTaskArithmeticLoRAAlgorithm(
+    SuperposedAlgorithmBase,
+):
+    _config_mapping = SuperposedAlgorithmBase._config_mapping | {
+        "scaling_factor": "scaling_factor",
+        "model_path": "model_path",
+    }
+
+    def __init__(
+        self,
+        scaling_factor: float,
+        model_path: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.scaling_factor = scaling_factor
+        self.model_path = model_path
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool):
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(models=modelpool)
+
+        log.info("Compressing models using superposed task arithmetic.")
+        task_vector = None
+        with self.profile("load model"):
+            pretrained_model = modelpool.load_model("_pretrained_")
+
+        # Calculate the task vector superposition
+        loras = {}
+        models = {}
+        for model_name in modelpool.model_names:
+            with self.profile("load model"):
+                model = modelpool.load_model(model_name)
+            for layer_name, layer in model.items():
+                if self.verbose >= 1:
+                    log.info(f"{layer_name} | {layer.shape}")
+            # task_vector = state_dict_sub(
+            #     model.state_dict(keep_vars=True),
+            #     pretrained_model.state_dict(keep_vars=True),
+            # )
+            loras[model_name] = model
+
+        with self.profile("compress and retrieve"):
+            retrieved_loras, metadata = self._compress_and_retrieve(
+                deepcopy(loras), mode="superposed_task_arithmetic"
+            )
+        with self.profile("retrieve models"):
+            for model_name in modelpool.model_names:
+                retrieved_lora = retrieved_loras[model_name]
+                # retrieved_lora = state_dict_mul(retrieved_loras[model_name], self.config.scaling_factor)
+                # retrieved_state_dict = state_dict_add(pretrained_model.state_dict(keep_vars=True), retrieved_lora)
+                retrieved_model = deepcopy(pretrained_model)
+                sd = retrieved_model.state_dict(keep_vars=True)
+                # for layer_name, layer in sd.items():
+                #     print(layer_name)
+                # manually merge the lora back
+                lora_weights = {}
+                lora_weights_ready_to_merge = OrderedDict()
+                for layer_name, layer in retrieved_lora.items():
+                    parts = layer_name.split(".")
+                    # print(parts)
+                    base_name = ".".join(parts[2:-2] + [parts[-1]])
+                    if base_name not in lora_weights:
+                        lora_weights[base_name] = []
+                    lora_weights[base_name].append(layer)
+                for base_name, layers in lora_weights.items():
+                    lora_weight = layers[-1] @ layers[0]
+                    # sd[base_name] += lora_weight
+                    lora_weights_ready_to_merge[base_name] = lora_weight
+
+                retrieved_lora_ready = state_dict_mul(
+                    lora_weights_ready_to_merge, self.config.scaling_factor
+                )
+                for layer_name, layer in retrieved_lora_ready.items():
+                    sd[layer_name] += layer
+                retrieved_model.load_state_dict(sd)
+                models[model_name] = retrieved_model
+
+                # # FIXME: for 'all' mode
+                # for k, v in retrieved_state_dict.items():
+                #     if v.shape[0] == 1:
+                #         retrieved_state_dict[k] = v.squeeze(0)
+                # retrieved_model.load_state_dict(sd)
+                # models[model_name] = retrieved_model
+
+                if self.debug >= 1:
+                    with self.profile("metadata"):
+                        model = modelpool.load_model(model_name)
+                        if torch.cuda.is_available():
+                            retrieved_state_dict = {
+                                k: v.cuda() for k, v in retrieved_state_dict.items()
+                            }
+                            retrieved_loras[model_name] = {
+                                k: v.cuda()
+                                for k, v in retrieved_loras[model_name].items()
+                            }
+                            loras[model_name] = {
+                                k: v.cuda() for k, v in loras[model_name].items()
+                            }
+                            model_state_dict = {
+                                k: v.cuda()
+                                for k, v in model.state_dict(keep_vars=True).items()
+                            }
+                        # focus on the superposition retrieval performance on the target layers
+                        target_layers = metadata["target_layers"]
+                        metadata["lora_retrieval_similarity"][model_name] = (
+                            compare_models(
+                                retrieved_loras[model_name],
+                                loras[model_name],
+                                target_layers,
+                            )
+                        )
+                        metadata["lora_svd_subspace_similarities"][model_name] = (
+                            self._compute_svd_subspace_similarities(
+                                loras[model_name],
+                                retrieved_loras[model_name],
+                                target_layers,
+                            )
+                        )
+                        # overall retrieval performance
+                        metadata["model_retrieval_similarity"][model_name] = (
+                            compare_models(retrieved_state_dict, model_state_dict)
+                        )
+                        metadata["model_svd_subspace_similarities"][model_name] = (
+                            self._compute_svd_subspace_similarities(
+                                model_state_dict, retrieved_state_dict
+                            )
+                        )
+                        # delete the cuda tensors
+                        del (
+                            retrieved_state_dict,
+                            retrieved_loras[model_name],
+                            loras[model_name],
+                            model_state_dict,
+                        )
+
+        with self.profile("metadata"):
+            if self.debug >= 0:
+                (
+                    metadata["trainable_param_count_pretrained_model"],
+                    metadata["active_param_count_pretrained_model"],
+                ) = count_parameters(pretrained_model)
+                (
+                    metadata["trainable_param_count_retrieved_model"],
+                    metadata["active_param_count_retrieved_model"],
+                ) = count_parameters(models[modelpool.model_names[0]])
+                metadata["nonzero_parameter_count"] += metadata[
+                    "active_param_count_pretrained_model"
+                ]
+                metadata["total_gb_retrieved"] += metadata["total_gb_original"]
+                print(
+                    f"Total storage (Gbs) for retrieval and original: {metadata['total_gb_retrieved']} | {metadata['total_gb_original']}"
+                )
+
+        if self.model_path is not None:
+            os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
+            torch.save(models, self.model_path)
+
+        self.print_profile_summary()
+        return {"models": models, "metadata": metadata}
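The merge loop in `SuperposedTaskArithmeticLoRAAlgorithm` reconstructs each weight delta as `layers[-1] @ layers[0]`, i.e. delta_W = B @ A for a rank-r LoRA pair, scales it by the configured factor, and adds it to the base weight. A standalone sketch of that algebra, assuming PEFT-style shapes where A is (r, in_features) and B is (out_features, r); the variable names and shapes here are illustrative:

```python
import torch

out_features, in_features, rank = 16, 32, 4
base_W = torch.randn(out_features, in_features)  # pretrained linear weight
lora_A = torch.randn(rank, in_features)          # low-rank "A" factor
lora_B = torch.randn(out_features, rank)         # low-rank "B" factor
scaling_factor = 0.5

# Same computation as the merge loop above: delta_W = B @ A, scaled, then merged.
merged_W = base_W + scaling_factor * (lora_B @ lora_A)
assert merged_W.shape == base_W.shape
```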
--- a/fusion_bench/method/simple_average.py
+++ b/fusion_bench/method/simple_average.py
@@ -39,12 +39,13 @@ def simple_average(
     >>> import torch.nn as nn
     >>> model1 = nn.Linear(10, 10)
     >>> model2 = nn.Linear(10, 10)
-    >>> averaged_model =
+    >>> averaged_model = simple_average([model1, model2])

     >>> state_dict1 = model1.state_dict()
     >>> state_dict2 = model2.state_dict()
-    >>> averaged_state_dict =
+    >>> averaged_state_dict = simple_average([state_dict1, state_dict2])
     """
+    assert len(modules) > 0, "modules must be a non-empty list"
     if isinstance(modules[0], nn.Module):
         if base_module is None:
             new_module = deepcopy(modules[0])
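The repaired doctest now shows both call forms: `simple_average` accepts a list of `nn.Module`s or a list of state dicts. A quick sanity-check sketch of the averaging semantics (the import path follows the file list above; the assertion assumes the helper computes an elementwise mean, as its docstring describes):

```python
import torch
import torch.nn as nn

from fusion_bench.method.simple_average import simple_average

model1, model2 = nn.Linear(10, 10), nn.Linear(10, 10)
merged = simple_average([model1, model2])

# Each merged parameter should be the elementwise mean of the inputs.
expected = (model1.weight.data + model2.weight.data) / 2
assert torch.allclose(merged.weight.data, expected)
```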
--- a/fusion_bench/method/sparselo/sparselo.py
+++ b/fusion_bench/method/sparselo/sparselo.py
@@ -32,6 +32,7 @@ from fusion_bench.models.modeling_losparse_llama.losparse_linear import LoSparse
 from fusion_bench.models.modeling_losparse_llama.utils import convert_to_losparse_llama
 from fusion_bench.utils import cache_to_disk, print_parameters, timeit_context
 from fusion_bench.utils.devices import get_device
+from fusion_bench.utils.dtype import get_dtype

 log = logging.getLogger(__name__)

@@ -141,6 +142,7 @@ class SparseLoForLlama(BaseAlgorithm, SimpleProfilerMixin):

     @override
     def run(self, modelpool: CausalLMPool):
+        self.modelpool = modelpool
         if self.seed is not None:
             L.seed_everything(self.seed)

@@ -691,12 +693,16 @@ class IterativeSparseLoForLlama(SparseLoForLlama):
         "num_iterations": "num_iterations",
     }

-    def __init__(self, num_iterations: int, **kwargs):
+    def __init__(
+        self, num_iterations: int, use_reference_model: bool = False, **kwargs
+    ):
         super().__init__(**kwargs)
         self.num_iterations = num_iterations
+        self.use_reference_model = use_reference_model

     @override
     def run(self, modelpool):
+        self.modelpool = modelpool
         if self.seed is not None:
             L.seed_everything(self.seed)

@@ -802,13 +808,25 @@ class IterativeSparseLoForLlama(SparseLoForLlama):
     @torch.no_grad()
     def iterative_magnitude_prune_(self, model):
         layers: nn.ModuleList = model.model.layers
+        if self.use_reference_model:
+            reference_model = self.modelpool.load_model(
+                "reference_model", torch_dtype="float16"
+            )
+            reference_layers: nn.ModuleList = reference_model.model.layers
         for layer_idx, layer in tqdm(
             enumerate(layers), "Pruning Layers", total=len(layers), dynamic_ncols=True
         ):
             for name, linear in layer.named_modules():
                 if isinstance(linear, LoSparseLinear):
                     log.info(f"Magnitude Pruning {name}")
-                    W = linear.weight.data.clone()
+                    W = (
+                        linear.weight.data.clone()
+                        if not self.use_reference_model
+                        else reference_layers[layer_idx]
+                        .get_submodule(name)
+                        .weight.data.clone()
+                        .to(linear.weight.data.device)
+                    )
                     if self.prune_type == PruningType.UNSTRUCTURED:
                         unstructured_magnitude_prune_(
                             linear.weight.data,
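With `use_reference_model=True`, the magnitude scores (`W`) are taken from a separately loaded reference checkpoint while the zeros are still written into the live layer's weights. A self-contained sketch of that scoring/pruning split; `prune_by_score_` is a simplified stand-in, not fusion_bench's actual `unstructured_magnitude_prune_` signature:

```python
import torch

def prune_by_score_(weight, score, sparsity=0.5):
    """Zero the entries of `weight` whose corresponding |score| is smallest."""
    k = int(weight.numel() * sparsity)
    threshold = score.abs().flatten().kthvalue(k).values
    weight[score.abs() <= threshold] = 0.0

weight = torch.randn(8, 8)     # weights of the layer being pruned
reference = torch.randn(8, 8)  # the same layer taken from a reference model

# Scores come from the reference weights; zeros are applied to the live weights.
prune_by_score_(weight, reference)
print((weight == 0).float().mean())  # ~0.5
```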
--- /dev/null
+++ b/fusion_bench/method/tall_mask/__init__.py
@@ -0,0 +1 @@
+from .task_arithmetic import TallMaskTaskArithmeticAlgorithm
--- /dev/null
+++ b/fusion_bench/method/tall_mask/task_arithmetic.py
@@ -0,0 +1,133 @@
+"""
+Modified from https://github.com/Zhou-Hangyu/randes/tree/main/benchmark/fusion_bench
+"""
+
+import logging
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+
+from fusion_bench import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils.state_dict_arithmetic import (
+    state_dict_add,
+    state_dict_binary_mask,
+    state_dict_diff_abs,
+    state_dict_hadmard_product,
+    state_dict_mul,
+    state_dict_sub,
+    state_dict_sum,
+)
+
+log = logging.getLogger(__name__)
+
+
+def generate_task_masks(
+    multi_task_vector: OrderedDict,
+    ft_task_vector: OrderedDict,
+    pretrained_task_vector: OrderedDict,
+    tall_mask_lambda: float = 1.0,
+) -> OrderedDict:
+    """Adopted from https://github.com/nik-dim/tall_masks/tree/master.
+    Generate task-specific TALL masks
+    TALL masks are generated as: mask_t = |theta_0 - theta_t| > |theta_mt - theta_t| * lambda
+
+    Args:
+        multi_task_vector: multi-task vector
+        ft_task_vector: individual theta_t (fine-tuned weights)
+        pretrained_task_vector: theta_0 (pre-trained weight)
+        tall_mask_lambda: hyper-parameter lambda for generating TALL masks
+    Returns:
+        final_mask: generated TALL masks with the given lambda
+    """
+
+    print(f"Generating TALL masks.")
+
+    # generate masks by comparing the l1 distance between |theta_0 - theta_t| and |theta_mt - theta_t|
+    diff_pt_ft = state_dict_diff_abs(pretrained_task_vector, ft_task_vector)
+    diff_multi_ft = state_dict_diff_abs(multi_task_vector, ft_task_vector)
+    # compare the l1 distance, scaled with hyper-parameter lambda
+    final_mask = state_dict_binary_mask(
+        diff_pt_ft,
+        state_dict_mul(diff_multi_ft, tall_mask_lambda),
+    )
+    for key, value in final_mask.items():
+        final_mask[key] = value.float()
+    return final_mask
+
+
+class TallMaskTaskArithmeticAlgorithm(
+    BaseAlgorithm,
+    SimpleProfilerMixin,
+):
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "tall_mask_lambda": "tall_mask_lambda",
+        "debug": "debug",
+        "verbose": "verbose",
+    }
+
+    def __init__(
+        self,
+        tall_mask_lambda: float,
+        debug: int = 0,
+        verbose: int = 0,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tall_mask_lambda = tall_mask_lambda
+        self.debug = debug
+        self.verbose = verbose
+
+    @torch.no_grad()
+    def run(self, modelpool: BaseModelPool):
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(models=modelpool)
+
+        log.info("Compressing models using tall mask task arithmetic.")
+        task_vector = None
+        with self.profile("load model"):
+            pretrained_model = modelpool.load_model("_pretrained_")
+
+        task_vectors = {}
+        models = {}
+        for model_name in modelpool.model_names:
+            with self.profile("load model"):
+                model = modelpool.load_model(model_name)
+            for layer_name, layer in model.state_dict(keep_vars=True).items():
+                if self.verbose >= 1:
+                    log.info(f"{layer_name} | {layer.shape}")
+            task_vector = state_dict_sub(
+                model.state_dict(keep_vars=True),
+                pretrained_model.state_dict(keep_vars=True),
+            )
+            task_vectors[model_name] = task_vector
+
+        multi_task_vector = state_dict_sum(list(task_vectors.values()))
+
+        tall_masks = {model: {} for model in modelpool.model_names}
+
+        for model_name in modelpool.model_names:
+            tall_mask = generate_task_masks(
+                multi_task_vector,
+                task_vectors[model_name],
+                pretrained_model.state_dict(keep_vars=True),
+                tall_mask_lambda=self.tall_mask_lambda,
+            )
+            tall_masks[model_name] = tall_mask
+
+        with self.profile("compress and retrieve"):
+            for model_name in modelpool.model_names:
+                retrieved_task_vector = state_dict_hadmard_product(
+                    tall_masks[model_name], multi_task_vector
+                )
+                retrieved_state_dict = state_dict_add(
+                    pretrained_model.state_dict(keep_vars=True), retrieved_task_vector
+                )
+                retrieved_model = deepcopy(pretrained_model)
+                retrieved_model.load_state_dict(retrieved_state_dict)
+                models[model_name] = retrieved_model
+
+        self.print_profile_summary()
+        return {"models": models, "metadata": None}
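A worked elementwise example of the mask rule from the docstring, `mask_t = |theta_0 - theta_t| > |theta_mt - theta_t| * lambda`, on toy tensors (values chosen purely for illustration):

```python
import torch

theta_0 = torch.tensor([0.0, 0.0, 0.0])    # pre-trained weights
theta_t = torch.tensor([1.0, 0.1, -0.5])   # fine-tuned weights for task t
theta_mt = torch.tensor([0.9, 0.8, 0.2])   # merged multi-task weights
lam = 1.0

# Keep a coordinate only where fine-tuning moved it further from theta_0
# than it now sits from the merged model (scaled by lambda).
mask = (theta_0 - theta_t).abs() > (theta_mt - theta_t).abs() * lam
print(mask.float())  # tensor([1., 0., 0.])
```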