fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/clip_classification.py +26 -6
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/serialization.py +40 -83
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +10 -4
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +36 -11
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/lazy_state_dict.py +85 -19
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/rich_utils.py +7 -3
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
--- a/fusion_bench/method/regmean/regmean.py
+++ b/fusion_bench/method/regmean/regmean.py
@@ -16,49 +16,9 @@ from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 
-
-
-
-def get_param_names_to_merge(
-    input_param_names: List[str], exclude_param_names_regex: list
-):
-    """
-    get the names of parameters that need to be merged
-    :param input_param_names: list, names of input parameters
-    :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
-    :return:
-    """
-    param_names_to_merge = []
-    for param_name in input_param_names:
-        exclude = any(
-            [
-                re.match(exclude_pattern, param_name)
-                for exclude_pattern in exclude_param_names_regex
-            ]
-        )
-        if not exclude:
-            param_names_to_merge.append(param_name)
-    return param_names_to_merge
-
+from .utils import get_modules_to_merge, get_param_names_to_merge
 
-def get_modules_to_merge(model: nn.Module, include_module_types: list):
-    """
-    get the model modules that need to be merged, whose type is in include_module_types
-    :param model: nn.Module, input model
-    :param include_module_types: list, module types that want to include
-    :return:
-    """
-    modules_to_merge: Dict[str, nn.Module] = {}
-    for module_name, module in model.named_modules():
-        is_valid_type = not include_module_types or any(
-            [
-                isinstance(module, include_module_type)
-                for include_module_type in include_module_types
-            ]
-        )
-        if is_valid_type:
-            modules_to_merge[module_name] = module
-    return modules_to_merge
+log = logging.getLogger(__name__)
 
 
 def reduce_non_diagonal_elements(
@@ -88,12 +48,16 @@ def merging_with_regmean_weights(
 ):
     """
     merge parameters of different models with computed regmean weights
-    :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
-    value is a list of the corresponding parameters of all the models that need to be merged
-    :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
-    each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
-    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
-    :return:
+
+    Args:
+        models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+            value is a list of the corresponding parameters of all the models that need to be merged
+        models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+            each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+        reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+
+    Returns:
+        dict: merged model parameters
     """
     # dict, dictionary of model parameters
     merged_params = {}
@@ -164,13 +128,17 @@ def regmean_merging(
     reduce_non_diagonal_ratio: float = 1.0,
 ):
     """
-    regmean merging method
-    :param models_to_merge: list, individual models that need to be merged
-    :param trainers: list, trainers of individual models
-    :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
-    :param nums_regmean_examples: list, numbers of examples to compute regmean weights
-    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
-    :return:
+    regmean merging method.
+
+    Args:
+        models_to_merge: list, individual models that need to be merged
+        trainers: list, trainers of individual models
+        exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
+        nums_regmean_examples: list, numbers of examples to compute regmean weights
+        reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+
+    Returns:
+        dict: merged model parameters
     """
 
     def compute_regmean_weights(module_name: str):
@@ -281,7 +249,10 @@ def regmean_merging(
 
 
 @auto_register_config
-class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
+class RegMeanAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     _include_module_type = [nn.Linear]
 
     def __init__(
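Note: for orientation, the update these RegMean functions implement solves a per-layer least-squares problem in closed form. Each model i contributes a Gram matrix G_i = X_i^T X_i computed on its own inputs, and the merged linear weight is W* = (sum_i G_i)^(-1) sum_i G_i W_i, with the off-diagonal entries of each G_i optionally shrunk by reduce_non_diagonal_ratio. A minimal self-contained sketch, not the library code (function names here are illustrative):

    import torch

    def shrink_non_diagonal(gram: torch.Tensor, ratio: float) -> torch.Tensor:
        # keep the diagonal, scale every off-diagonal entry by `ratio`
        diag = torch.diag(torch.diagonal(gram))
        return ratio * (gram - diag) + diag

    def regmean_merge(weights, grams, ratio=1.0):
        # weights: list of (in_features, out_features) matrices W_i
        # grams:   list of (in_features, in_features) Gram matrices G_i
        grams = [shrink_non_diagonal(g, ratio) for g in grams]
        numerator = sum(g @ w for g, w in zip(grams, weights))
        return torch.linalg.solve(sum(grams), numerator)

With ratio below 1 the solution is pulled toward a diagonal (per-feature) regression, and with identical Gram matrices across models it degenerates to simple averaging.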
--- /dev/null
+++ b/fusion_bench/method/regmean/utils.py
@@ -0,0 +1,56 @@
+import re
+from typing import Dict, List
+
+from torch import nn
+
+
+def get_param_names_to_merge(
+    input_param_names: List[str], exclude_param_names_regex: list
+) -> List[str]:
+    """
+    get the names of parameters that need to be merged
+
+    Args:
+        input_param_names: list, names of input parameters
+        exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
+
+    Returns:
+        list: names of parameters that need to be merged
+    """
+    param_names_to_merge = []
+    for param_name in input_param_names:
+        exclude = any(
+            [
+                re.match(exclude_pattern, param_name)
+                for exclude_pattern in exclude_param_names_regex
+            ]
+        )
+        if not exclude:
+            param_names_to_merge.append(param_name)
+    return param_names_to_merge
+
+
+def get_modules_to_merge(
+    model: nn.Module, include_module_types: list
+) -> Dict[str, nn.Module]:
+    """
+    get the model modules that need to be merged, whose type is in include_module_types
+
+    Args:
+        model: nn.Module, input model
+        include_module_types: list, module types that want to include
+
+    Returns:
+        Dict[str, nn.Module]: a dictionary of modules to merge
+    """
+    modules_to_merge: Dict[str, nn.Module] = {}
+    for module_name, module in model.named_modules():
+        is_valid_type = not include_module_types or any(
+            [
+                isinstance(module, include_module_type)
+                for include_module_type in include_module_types
+            ]
+        )
+        if is_valid_type:
+            modules_to_merge[module_name] = module
+    return modules_to_merge
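Note: both helpers are pure functions over parameter names and module types, so the extracted module is easy to exercise on its own. A quick check (the toy model below is made up for illustration):

    from torch import nn

    from fusion_bench.method.regmean.utils import (
        get_modules_to_merge,
        get_param_names_to_merge,
    )

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    # only the two Linear submodules qualify: keys '0' and '2'
    print(get_modules_to_merge(model, [nn.Linear]).keys())
    # drop biases via regex, keeping '0.weight' and '2.weight'
    names = [name for name, _ in model.named_parameters()]
    print(get_param_names_to_merge(names, exclude_param_names_regex=[".*bias"]))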
--- a/fusion_bench/method/regmean_plusplus/regmean_plusplus.py
+++ b/fusion_bench/method/regmean_plusplus/regmean_plusplus.py
@@ -7,55 +7,14 @@ import torch
 from torch import Tensor, nn
 from tqdm.autonotebook import tqdm
 
-
+import fusion_bench.method.regmean.utils as regmean_utils
+from fusion_bench import BaseAlgorithm, auto_register_config
 from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 
 log = logging.getLogger(__name__)
 
 
-def get_param_names_to_merge(
-    input_param_names: List[str], exclude_param_names_regex: list
-):
-    """
-    get the names of parameters that need to be merged
-    :param input_param_names: list, names of input parameters
-    :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
-    :return:
-    """
-    param_names_to_merge = []
-    for param_name in input_param_names:
-        exclude = any(
-            [
-                re.match(exclude_pattern, param_name)
-                for exclude_pattern in exclude_param_names_regex
-            ]
-        )
-        if not exclude:
-            param_names_to_merge.append(param_name)
-    return param_names_to_merge
-
-
-def get_modules_to_merge(model: nn.Module, include_module_types: list):
-    """
-    get the model modules that need to be merged, whose type is in include_module_types
-    :param model: nn.Module, input model
-    :param include_module_types: list, module types that want to include
-    :return:
-    """
-    modules_to_merge: Dict[str, nn.Module] = {}
-    for module_name, module in model.named_modules():
-        is_valid_type = not include_module_types or any(
-            [
-                isinstance(module, include_module_type)
-                for include_module_type in include_module_types
-            ]
-        )
-        if is_valid_type:
-            modules_to_merge[module_name] = module
-    return modules_to_merge
-
-
 def reduce_non_diagonal_elements(
     regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
 ):
@@ -130,12 +89,16 @@ def merging_with_regmean_weights(
 ):
     """
     merge parameters of different models with computed regmean weights
-    :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
-    value is a list of the corresponding parameters of all the models that need to be merged
-    :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
-    each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
-    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
-    :return:
+
+    Args:
+        models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+            value is a list of the corresponding parameters of all the models that need to be merged
+        models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+            each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+        reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+
+    Returns:
+        dict: merged model parameters
     """
     # dict, dictionary of model parameters
     merged_params = {}
@@ -176,14 +139,12 @@ def merging_with_regmean_weights(
     return merged_params
 
 
-class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
+@auto_register_config
+class RegMeanAlgorithmPlusPlus(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     _include_module_type = [nn.Linear]
-    _config_mapping = {
-        "num_regmean_examples": "num_regmean_examples",
-        "exclude_param_names_regex": "exclude_param_names_regex",
-        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
-        "weight_transpose": "weight_transpose",
-    }
 
     def __init__(
         self,
@@ -194,11 +155,11 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
         weight_transpose: bool,
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.num_regmean_examples = num_regmean_examples
         self.exclude_param_names_regex = exclude_param_names_regex
         self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
         self.weight_transpose = weight_transpose
-        super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool, **kwargs):
         if not isinstance(modelpool, BaseModelPool):
@@ -262,7 +223,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
 
         # exclude parameter whose name matches element in exclude_param_names_regex
         if param_names_to_merge is None:
-            param_names_to_merge = get_param_names_to_merge(
+            param_names_to_merge = regmean_utils.get_param_names_to_merge(
                 input_param_names=list(param_dict.keys()),
                 exclude_param_names_regex=self.config.get(
                     "exclude_param_names_regex", []
@@ -274,7 +235,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
                 param_dict[param_name]
             )
 
-        linear_modules_to_merge = get_modules_to_merge(
+        linear_modules_to_merge = regmean_utils.get_modules_to_merge(
             model=layer_to_merge,
             include_module_types=self._include_module_type,
         )
@@ -294,7 +255,7 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
             linear_modules_to_merge=linear_modules_to_merge,
         )
 
-        module_subset = get_param_names_to_merge(
+        module_subset = regmean_utils.get_param_names_to_merge(
             input_param_names=list(param_dict.keys()),
             exclude_param_names_regex=self.exclude_param_names_regex,
         )
--- a/fusion_bench/method/simple_average.py
+++ b/fusion_bench/method/simple_average.py
@@ -89,7 +89,7 @@ class SimpleAverageAlgorithm(
             modelpool = BaseModelPool(modelpool)
 
         log.info(
-            f"Fusing models using simple average on {len(modelpool.model_names)} models."
+            f"Fusing models using simple average on {len(modelpool.model_names)} models. "
             f"models: {modelpool.model_names}"
         )
         sd: Optional[StateDictType] = None
@@ -119,7 +119,7 @@ class SimpleAverageAlgorithm(
 
         if isinstance(forward_model, LazyStateDict):
             # if the model is a LazyStateDict, convert it to an empty module
-            forward_model = forward_model.meta_module.to_empty(
+            forward_model = deepcopy(forward_model.meta_module).to_empty(
                 device=forward_model._device
            )
             result = forward_model.load_state_dict(sd, strict=False)
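Note: the deepcopy in the hunk above is the substantive fix. nn.Module.to_empty mutates the module in place, so materializing meta_module directly would corrupt the shared skeleton that LazyStateDict keeps on the meta device. The underlying PyTorch pattern, as a standalone sketch (assuming only that meta_module wraps a model built under torch.device("meta")):

    from copy import deepcopy

    import torch
    from torch import nn

    # a parameter-shaped skeleton that allocates no real storage
    with torch.device("meta"):
        meta_module = nn.Linear(4, 4)

    state_dict = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

    # copy first so the shared skeleton is never mutated in place,
    # then give the copy real (uninitialized) storage and fill it
    model = deepcopy(meta_module).to_empty(device="cpu")
    result = model.load_state_dict(state_dict, strict=False)
    assert not result.missing_keys and not result.unexpected_keys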
--- a/fusion_bench/method/slerp/__init__.py
+++ b/fusion_bench/method/slerp/__init__.py
@@ -1,2 +1,2 @@
 # flake8: noqa F401
-from .slerp import SlerpMergeAlgorithm
+from .slerp import SlerpForCausalLM, SlerpMergeAlgorithm
--- a/fusion_bench/method/slerp/slerp.py
+++ b/fusion_bench/method/slerp/slerp.py
@@ -1,16 +1,24 @@
 import logging
-
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
 from torch import nn
+from tqdm import tqdm
 from typing_extensions import override
 
+from fusion_bench import LazyStateDict, create_default_model_card, timeit_context
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import BaseModelPool, CausalLMPool
 from fusion_bench.utils.type import StateDictType
 
 from .slerp_utils import slerp
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
 log = logging.getLogger(__name__)
 
 
@@ -21,6 +29,7 @@ def slerp_on_state_dicts(
     *,
     DOT_THRESHOLD: float = 0.9995,
     epsilon: float = 1e-8,
+    show_pbar: bool = False,
 ) -> StateDictType:
     """
     Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
@@ -36,7 +45,8 @@ def slerp_on_state_dicts(
         dict: The interpolated state dictionary.
     """
     state_dict = {}
-    for key in secondary_state_dict:
+    pbar = secondary_state_dict if not show_pbar else tqdm(secondary_state_dict)
+    for key in pbar:
         v0 = primary_state_dict[key]
         v1 = secondary_state_dict[key]
         if v0.shape != v1.shape:
@@ -49,18 +59,19 @@ def slerp_on_state_dicts(
     return state_dict
 
 
+@auto_register_config
 class SlerpMergeAlgorithm(BaseAlgorithm):
     """
     General purpose implementation of Slerp (Spherical Linear Interpolation) for PyTorch models.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "t": "t",
-        "DOT_THRESHOLD": "DOT_THRESHOLD",
-        "epsilon": "epsilon",
-    }
-
-    def __init__(self, t: float, DOT_THRESHOLD: float = 0.9995, epsilon: float = 1e-8):
+    def __init__(
+        self,
+        t: float,
+        DOT_THRESHOLD: float = 0.9995,
+        epsilon: float = 1e-8,
+        **kwargs,
+    ):
         """
         Initialize the SlerpMergeAlgorithm.
 
@@ -69,10 +80,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
             DOT_THRESHOLD (float, optional): The threshold for the dot product of the two vectors. Defaults to 0.9995.
             epsilon (float, optional): The epsilon value for numerical stability. Defaults to 1e-8.
         """
-        self.t = t
-        self.DOT_THRESHOLD = DOT_THRESHOLD
-        self.epsilon = epsilon
-        super().__init__()
+        super().__init__(**kwargs)
 
     @override
     def run(self, modelpool: BaseModelPool) -> nn.Module:
@@ -102,3 +110,91 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
 
         primary_model.load_state_dict(state_dict)
         return primary_model
+
+
+@auto_register_config
+class SlerpForCausalLM(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    """
+    Slerp (Spherical Linear Interpolation) for Causal Language Models.
+    """
+
+    def __init__(
+        self,
+        t: float,
+        DOT_THRESHOLD: float = 0.9995,
+        epsilon: float = 1e-8,
+        model_save_path: Optional[str] = None,
+        show_pbar: bool = False,
+        **kwargs,
+    ):
+        """
+        Initialize the SlerpForCausalLM algorithm.
+
+        Args:
+            t (float): The interpolation parameter. Must be in the range [0, 1].
+                t=0 returns the first model, t=1 returns the second model,
+                t=0.5 provides balanced interpolation.
+            DOT_THRESHOLD (float, optional): The threshold for the dot product of normalized vectors.
+                When the absolute dot product exceeds this threshold,
+                vectors are considered nearly collinear and linear
+                interpolation (LERP) is used instead of SLERP for
+                numerical stability. Defaults to 0.9995.
+            epsilon (float, optional): Small value used for numerical stability to avoid
+                division by zero during vector normalization.
+                Defaults to 1e-8.
+            model_save_path (Optional[str], optional): Path where the merged model should be saved.
+                If None, the model is not saved to disk.
+                Defaults to None.
+            show_pbar (bool, optional): Whether to display a progress bar during the interpolation
+                process. Useful for debugging or monitoring progress with
+                large models. Defaults to False.
+            **kwargs: Additional keyword arguments passed to the parent BaseAlgorithm class.
+        """
+        super().__init__(**kwargs)
+
+    @override
+    def run(self, modelpool: CausalLMPool):
+        assert len(modelpool.all_model_names) == 2, "Slerp expect exactly 2 models"
+        primary_model = modelpool.load_model(modelpool.all_model_names[0])
+        secondary_model = modelpool.load_model(modelpool.all_model_names[1])
+
+        with torch.no_grad():
+            primary_state_dict = primary_model.state_dict()
+            secondary_state_dict = secondary_model.state_dict()
+            state_dict = slerp_on_state_dicts(
+                self.t,
+                primary_state_dict,
+                secondary_state_dict,
+                DOT_THRESHOLD=self.DOT_THRESHOLD,
+                epsilon=self.epsilon,
+            )
+
+        if isinstance(primary_model, nn.Module):
+            model = primary_model
+            model.load_state_dict(state_dict)
+        elif isinstance(primary_model, LazyStateDict):
+            model: "PreTrainedModel" = deepcopy(primary_model.meta_module)
+            model.to(device=primary_model._device)
+            model.load_state_dict(state_dict)
+        else:
+            raise TypeError(
+                f"Unsupported model type: {type(primary_model)}. "
+                "Expected nn.Module or LazyStateDict."
+            )
+        if self.model_save_path is not None:
+            with timeit_context(f"Saving the model to {self.model_save_path}"):
+                tokenizer = modelpool.load_tokenizer()
+                tokenizer.save_pretrained(self.model_save_path)
+                model.save_pretrained(self.model_save_path)
+            model_card_str = create_default_model_card(
+                models=[modelpool.get_model_path(m) for m in modelpool.model_names],
+                description="Merged model using Slerp.",
+                algorithm_config=self.config,
+                modelpool_config=modelpool.config,
+            )
+            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
+                f.write(model_card_str)
+        return model
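Note: per tensor, slerp_on_state_dicts delegates to the slerp helper from .slerp_utils. What that computes is, in miniature (a sketch, not the library implementation): interpolate along the great circle between the two flattened weight vectors, falling back to plain LERP when they are nearly collinear, which is exactly what DOT_THRESHOLD guards:

    import torch

    def slerp_sketch(t, v0, v1, DOT_THRESHOLD=0.9995, eps=1e-8):
        u0 = v0 / (v0.norm() + eps)          # normalize for the angle only
        u1 = v1 / (v1.norm() + eps)
        dot = (u0 * u1).sum().clamp(-1.0, 1.0)
        if dot.abs() > DOT_THRESHOLD:        # nearly collinear: use LERP
            return (1 - t) * v0 + t * v1
        omega = torch.acos(dot)              # angle between the two vectors
        sin_omega = torch.sin(omega)
        return (torch.sin((1 - t) * omega) * v0 + torch.sin(t * omega) * v1) / sin_omega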
--- a/fusion_bench/method/task_arithmetic/task_arithmetic.py
+++ b/fusion_bench/method/task_arithmetic/task_arithmetic.py
@@ -6,11 +6,20 @@ http://arxiv.org/abs/2212.04089
 
 import logging
 from copy import deepcopy
-from typing import Dict, List, Mapping, Optional, TypeVar, Union  # noqa: F401
+from typing import (  # noqa: F401
+    TYPE_CHECKING,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    TypeVar,
+    Union,
+)
 
 import torch
 from torch import nn
 
+from fusion_bench import LazyStateDict
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
@@ -21,6 +30,8 @@ from fusion_bench.utils.state_dict_arithmetic import (
 )
 from fusion_bench.utils.type import StateDictType, TorchModelType
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
 log = logging.getLogger(__name__)
 
 
@@ -125,25 +136,39 @@ class TaskArithmeticAlgorithm(
         with self.profile("merge weights"):
             if task_vector is None:
                 task_vector = state_dict_sub(
-                    model.state_dict(keep_vars=True),
-                    pretrained_model.state_dict(keep_vars=True),
+                    model.state_dict(),
+                    pretrained_model.state_dict(),
                 )
             else:
                 task_vector = state_dict_add(
                     task_vector,
                     state_dict_sub(
-                        model.state_dict(keep_vars=True),
-                        pretrained_model.state_dict(keep_vars=True),
+                        model.state_dict(),
+                        pretrained_model.state_dict(),
                     ),
                 )
         with self.profile("merge weights"):
             # scale the task vector
             task_vector = state_dict_mul(task_vector, self.config.scaling_factor)
             # add the task vector to the pretrained model
-            state_dict = state_dict_add(
-                pretrained_model.state_dict(keep_vars=True), task_vector
-            )
+            state_dict = state_dict_add(pretrained_model.state_dict(), task_vector)
 
         self.print_profile_summary()
-        pretrained_model.load_state_dict(state_dict)
-        return pretrained_model
+
+        # apply state dict to model
+        if isinstance(pretrained_model, nn.Module):
+            model = pretrained_model
+            model.load_state_dict(state_dict)
+        elif isinstance(pretrained_model, LazyStateDict):
+            model = deepcopy(pretrained_model.meta_module)
+            model = model.to_empty(device=pretrained_model._device)
+            result = model.load_state_dict(state_dict, strict=False)
+            if result.unexpected_keys:
+                raise ValueError(
+                    f"Unexpected keys in state dict: {result.unexpected_keys}"
+                )
+            if result.missing_keys:
+                log.warning(f"Missing keys in state dict: {result.missing_keys}")
+        else:
+            raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
+        return model
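Note: the arithmetic itself is unchanged, theta_merged = theta_pre + lambda * sum_i (theta_i - theta_pre); the new code only changes how the result is applied when pretrained_model is a LazyStateDict. In state-dict terms (a minimal sketch with plain dicts, standing in for the state_dict_arithmetic helpers):

    import torch

    def task_arithmetic_sketch(pretrained, finetuned, scaling_factor):
        # task vector per model = finetuned - pretrained; sum, scale, add back
        merged = {}
        for key, base in pretrained.items():
            task_vector = sum(sd[key] - base for sd in finetuned)
            merged[key] = base + scaling_factor * task_vector
        return merged

    base = {"w": torch.zeros(2)}
    experts = [{"w": torch.tensor([1.0, 0.0])}, {"w": torch.tensor([0.0, 1.0])}]
    print(task_arithmetic_sketch(base, experts, 0.5))  # {'w': tensor([0.5000, 0.5000])}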
--- a/fusion_bench/method/ties_merging/ties_merging.py
+++ b/fusion_bench/method/ties_merging/ties_merging.py
@@ -9,11 +9,14 @@ Overview of Ties-Merging:
 """
 
 import logging
+from copy import deepcopy
 from typing import Any, Dict, List, Literal, Mapping, Union  # noqa: F401
 
 import torch
 from torch import Tensor, nn
+from transformers import PreTrainedModel
 
+from fusion_bench import LazyStateDict
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
@@ -98,12 +101,25 @@ class TiesMergingAlgorithm(
                 merge_func=merge_func,
             )
             merged_check = flat_ptm + scaling_factor * merged_tv
-            merged_state_dict = vector_to_state_dict(
+            state_dict = vector_to_state_dict(
                 merged_check, ptm_check, remove_keys=remove_keys
             )
-
-        # Load the merged state dict into the pretrained model
-        pretrained_model.load_state_dict(merged_state_dict)
-
         self.print_profile_summary()
-        return pretrained_model
+
+        # apply state dict to model
+        if isinstance(pretrained_model, nn.Module):
+            model = pretrained_model
+            model.load_state_dict(state_dict)
+        elif isinstance(pretrained_model, LazyStateDict):
+            model = deepcopy(pretrained_model.meta_module)
+            model = model.to_empty(device=pretrained_model._device)
+            result = model.load_state_dict(state_dict, strict=False)
+            if result.unexpected_keys:
+                raise ValueError(
+                    f"Unexpected keys in state dict: {result.unexpected_keys}"
+                )
+            if result.missing_keys:
+                log.warning(f"Missing keys in state dict: {result.missing_keys}")
+        else:
+            raise TypeError(f"Unsupported model type: {type(pretrained_model)}")
+        return model
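Note: for reference, the merged_tv that the hunk combines as flat_ptm + scaling_factor * merged_tv comes out of the usual TIES pipeline over stacked flat task vectors: trim small-magnitude entries, elect a sign per parameter, then average only the entries that agree with it. A compact sketch of those three steps (illustrative, not the library's implementation):

    import torch

    def ties_merge_sketch(task_vectors: torch.Tensor, k: float = 0.2) -> torch.Tensor:
        # task_vectors: (num_models, num_params) stacked flat task vectors
        # 1. trim: per model, keep only the top-k fraction of entries by magnitude
        magnitudes = task_vectors.abs()
        threshold = magnitudes.quantile(1 - k, dim=1, keepdim=True)
        trimmed = torch.where(
            magnitudes >= threshold, task_vectors, torch.zeros_like(task_vectors)
        )
        # 2. elect sign: majority sign by total mass at each parameter position
        elected = trimmed.sum(dim=0).sign()
        # 3. disjoint merge: mean over the entries whose sign matches the elected one
        agree = (trimmed.sign() == elected) & (trimmed != 0)
        count = agree.sum(dim=0).clamp(min=1)
        return (trimmed * agree).sum(dim=0) / count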