fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
|
@@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Literal, Union, cast
|
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from torch import Tensor
|
|
7
|
+
from tqdm.auto import tqdm
|
|
7
8
|
|
|
8
9
|
from .parameters import check_parameters_all_equal
|
|
9
10
|
from .type import BoolStateDictType, StateDictType
|
|
@@ -43,7 +44,7 @@ def state_dicts_check_keys(state_dicts: List[StateDictType]):
|
|
|
43
44
|
assert keys == set(state_dict.keys()), "keys of state_dicts are not equal"
|
|
44
45
|
|
|
45
46
|
|
|
46
|
-
def num_params_of_state_dict(state_dict: StateDictType):
|
|
47
|
+
def num_params_of_state_dict(state_dict: StateDictType) -> int:
|
|
47
48
|
"""
|
|
48
49
|
Returns the number of parameters in a state dict.
|
|
49
50
|
|
|
@@ -56,7 +57,7 @@ def num_params_of_state_dict(state_dict: StateDictType):
|
|
|
56
57
|
return sum([state_dict[key].numel() for key in state_dict])
|
|
57
58
|
|
|
58
59
|
|
|
59
|
-
def state_dict_flatten(state_dict: Dict[str, Tensor]):
|
|
60
|
+
def state_dict_flatten(state_dict: Dict[str, Tensor]) -> Tensor:
|
|
60
61
|
"""
|
|
61
62
|
Flattens a state dict.
|
|
62
63
|
|
|
@@ -72,7 +73,7 @@ def state_dict_flatten(state_dict: Dict[str, Tensor]):
|
|
|
72
73
|
return torch.cat(flattened_state_dict)
|
|
73
74
|
|
|
74
75
|
|
|
75
|
-
def state_dict_avg(state_dicts: List[StateDictType]):
|
|
76
|
+
def state_dict_avg(state_dicts: List[StateDictType]) -> StateDictType:
|
|
76
77
|
"""
|
|
77
78
|
Returns the average of a list of state dicts.
|
|
78
79
|
|
|
@@ -99,7 +100,7 @@ def state_dict_avg(state_dicts: List[StateDictType]):
|
|
|
99
100
|
|
|
100
101
|
def state_dict_sub(
|
|
101
102
|
a: StateDictType, b: StateDictType, strict: bool = True, device=None
|
|
102
|
-
):
|
|
103
|
+
) -> StateDictType:
|
|
103
104
|
"""
|
|
104
105
|
Returns the difference between two state dicts `a-b`.
|
|
105
106
|
|
|
@@ -124,8 +125,12 @@ def state_dict_sub(
|
|
|
124
125
|
|
|
125
126
|
|
|
126
127
|
def state_dict_add(
|
|
127
|
-
a: StateDictType,
|
|
128
|
-
|
|
128
|
+
a: StateDictType,
|
|
129
|
+
b: StateDictType,
|
|
130
|
+
strict: bool = True,
|
|
131
|
+
device=None,
|
|
132
|
+
show_pbar: bool = False,
|
|
133
|
+
) -> StateDictType:
|
|
129
134
|
"""
|
|
130
135
|
Returns the sum of two state dicts.
|
|
131
136
|
|
|
@@ -140,10 +145,10 @@ def state_dict_add(
|
|
|
140
145
|
ans = {}
|
|
141
146
|
if strict:
|
|
142
147
|
check_parameters_all_equal([a, b])
|
|
143
|
-
for key in a:
|
|
148
|
+
for key in tqdm(tuple(a.keys())) if show_pbar else a:
|
|
144
149
|
ans[key] = a[key] + b[key]
|
|
145
150
|
else:
|
|
146
|
-
for key in a:
|
|
151
|
+
for key in tqdm(tuple(a.keys())) if show_pbar else a:
|
|
147
152
|
if key in b:
|
|
148
153
|
ans[key] = a[key] + b[key]
|
|
149
154
|
if device is not None:
|
|
@@ -151,14 +156,14 @@ def state_dict_add(
|
|
|
151
156
|
return ans
|
|
152
157
|
|
|
153
158
|
|
|
154
|
-
def state_dict_add_scalar(a: StateDictType, scalar: Number):
|
|
159
|
+
def state_dict_add_scalar(a: StateDictType, scalar: Number) -> StateDictType:
|
|
155
160
|
ans = OrderedDict()
|
|
156
161
|
for key in a:
|
|
157
162
|
ans[key] = a[key] + scalar
|
|
158
163
|
return ans
|
|
159
164
|
|
|
160
165
|
|
|
161
|
-
def state_dict_mul(state_dict: StateDictType, scalar: float):
|
|
166
|
+
def state_dict_mul(state_dict: StateDictType, scalar: float) -> StateDictType:
|
|
162
167
|
"""
|
|
163
168
|
Returns the product of a state dict and a scalar.
|
|
164
169
|
|
|
@@ -175,7 +180,9 @@ def state_dict_mul(state_dict: StateDictType, scalar: float):
|
|
|
175
180
|
return diff
|
|
176
181
|
|
|
177
182
|
|
|
178
|
-
def state_dict_div(state_dict: StateDictType, scalar: float):
|
|
183
|
+
def state_dict_div(
|
|
184
|
+
state_dict: StateDictType, scalar: float, show_pbar: bool = False
|
|
185
|
+
) -> StateDictType:
|
|
179
186
|
"""
|
|
180
187
|
Returns the division of a state dict by a scalar.
|
|
181
188
|
|
|
@@ -187,21 +194,21 @@ def state_dict_div(state_dict: StateDictType, scalar: float):
|
|
|
187
194
|
Dict: The division of the state dict by the scalar.
|
|
188
195
|
"""
|
|
189
196
|
diff = OrderedDict()
|
|
190
|
-
for k in state_dict:
|
|
197
|
+
for k in tqdm(tuple(state_dict.keys())) if show_pbar else state_dict:
|
|
191
198
|
diff[k] = state_dict[k] / scalar
|
|
192
199
|
return diff
|
|
193
200
|
|
|
194
201
|
|
|
195
|
-
def state_dict_power(state_dict: Dict[str, Tensor], p: float):
|
|
202
|
+
def state_dict_power(state_dict: StateDictType, p: float) -> StateDictType:
|
|
196
203
|
"""
|
|
197
204
|
Returns the power of a state dict.
|
|
198
205
|
|
|
199
206
|
Args:
|
|
200
|
-
state_dict (
|
|
207
|
+
state_dict (StateDictType): The state dict to be powered.
|
|
201
208
|
p (float): The power to raise the state dict to.
|
|
202
209
|
|
|
203
210
|
Returns:
|
|
204
|
-
|
|
211
|
+
StateDictType: The powered state dict.
|
|
205
212
|
"""
|
|
206
213
|
powered_state_dict = {}
|
|
207
214
|
for key in state_dict:
|
|
@@ -210,17 +217,17 @@ def state_dict_power(state_dict: Dict[str, Tensor], p: float):
|
|
|
210
217
|
|
|
211
218
|
|
|
212
219
|
def state_dict_interpolation(
|
|
213
|
-
state_dicts: List[
|
|
214
|
-
):
|
|
220
|
+
state_dicts: List[StateDictType], scalars: List[float]
|
|
221
|
+
) -> StateDictType:
|
|
215
222
|
"""
|
|
216
223
|
Interpolates between a list of state dicts using a list of scalars.
|
|
217
224
|
|
|
218
225
|
Args:
|
|
219
|
-
state_dicts (List[
|
|
226
|
+
state_dicts (List[StateDictType]): The list of state dicts to interpolate between.
|
|
220
227
|
scalars (List[float]): The list of scalars to use for interpolation.
|
|
221
228
|
|
|
222
229
|
Returns:
|
|
223
|
-
|
|
230
|
+
StateDictType: The interpolated state dict.
|
|
224
231
|
"""
|
|
225
232
|
assert len(state_dicts) == len(
|
|
226
233
|
scalars
|
|
@@ -238,15 +245,15 @@ def state_dict_interpolation(
|
|
|
238
245
|
return interpolated_state_dict
|
|
239
246
|
|
|
240
247
|
|
|
241
|
-
def state_dict_sum(state_dicts: List[StateDictType]):
|
|
248
|
+
def state_dict_sum(state_dicts: List[StateDictType]) -> StateDictType:
|
|
242
249
|
"""
|
|
243
250
|
Returns the sum of a list of state dicts.
|
|
244
251
|
|
|
245
252
|
Args:
|
|
246
|
-
state_dicts (List[
|
|
253
|
+
state_dicts (List[StateDictType]): The list of state dicts to sum.
|
|
247
254
|
|
|
248
255
|
Returns:
|
|
249
|
-
|
|
256
|
+
StateDictType: The sum of the state dicts.
|
|
250
257
|
"""
|
|
251
258
|
assert len(state_dicts) > 0, "The number of state_dicts must be greater than 0"
|
|
252
259
|
assert all(
|
|
@@ -262,17 +269,17 @@ def state_dict_sum(state_dicts: List[StateDictType]):
|
|
|
262
269
|
|
|
263
270
|
|
|
264
271
|
def state_dict_weighted_sum(
|
|
265
|
-
state_dicts: List[
|
|
266
|
-
):
|
|
272
|
+
state_dicts: List[StateDictType], weights: List[float], device=None
|
|
273
|
+
) -> StateDictType:
|
|
267
274
|
"""
|
|
268
275
|
Returns the weighted sum of a list of state dicts.
|
|
269
276
|
|
|
270
277
|
Args:
|
|
271
|
-
state_dicts (List[
|
|
278
|
+
state_dicts (List[StateDictType]): The list of state dicts to interpolate between.
|
|
272
279
|
weights (List[float]): The list of weights to use for the weighted sum.
|
|
273
280
|
|
|
274
281
|
Returns:
|
|
275
|
-
|
|
282
|
+
StateDictType: The weighted sum of the state dicts.
|
|
276
283
|
"""
|
|
277
284
|
assert len(state_dicts) == len(
|
|
278
285
|
weights
|
|
@@ -297,7 +304,7 @@ def state_dict_weighted_sum(
|
|
|
297
304
|
return weighted_sum_state_dict
|
|
298
305
|
|
|
299
306
|
|
|
300
|
-
def state_dict_diff_abs(a: StateDictType, b: StateDictType):
|
|
307
|
+
def state_dict_diff_abs(a: StateDictType, b: StateDictType) -> StateDictType:
|
|
301
308
|
"""
|
|
302
309
|
Returns the per-layer abs of the difference between two state dicts.
|
|
303
310
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: fusion_bench
|
|
3
|
-
Version: 0.2.19
|
|
3
|
+
Version: 0.2.21
|
|
4
4
|
Summary: A Comprehensive Benchmark of Deep Model Fusion
|
|
5
5
|
Author-email: Anke Tang <tang.anke@foxmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -45,13 +45,17 @@ Requires-Dist: rich
|
|
|
45
45
|
Requires-Dist: scipy
|
|
46
46
|
Requires-Dist: h5py
|
|
47
47
|
Requires-Dist: pytest
|
|
48
|
+
Requires-Dist: transformers!=4.49
|
|
49
|
+
Requires-Dist: pillow!=11.2.1
|
|
48
50
|
Provides-Extra: lm-eval-harness
|
|
49
51
|
Requires-Dist: lm-eval; extra == "lm-eval-harness"
|
|
52
|
+
Requires-Dist: immutabledict; extra == "lm-eval-harness"
|
|
53
|
+
Requires-Dist: langdetect; extra == "lm-eval-harness"
|
|
50
54
|
Dynamic: license-file
|
|
51
55
|
|
|
52
56
|
<div align='center'>
|
|
53
57
|
|
|
54
|
-
# FusionBench: A Comprehensive Benchmark/
|
|
58
|
+
# FusionBench: A Comprehensive Benchmark/Toolkit of Deep Model Fusion
|
|
55
59
|
|
|
56
60
|
[](http://arxiv.org/abs/2406.03280)
|
|
57
61
|
[](https://github.com/tanganke/fusion_bench/blob/main/LICENSE)
|
|
@@ -72,6 +76,14 @@ FusionBench is a benchmark suite designed to evaluate the performance of various
|
|
|
72
76
|
|
|
73
77
|
Projects based on FusionBench and news from the community (descending order of date. If you have any work based on FusionBench, please feel free to let us know, we are willing to add it to the list. :partying_face:):
|
|
74
78
|
|
|
79
|
+
<details>
|
|
80
|
+
<summary>The-Hai Nguyen, Dang Huu-Tien, Takeshi Suzuki, and Le-Minh Nguyen. RegMean++: Enhancing Effectiveness and Generalization of Regression Mean for Model Merging. Aug, 2025. https://www.arxiv.org/abs/2508.03121</summary>
|
|
81
|
+
|
|
82
|
+
Regression Mean (RegMean), an approach that formulates model merging as a linear regression problem, aims to find the optimal weights for each linear layer in the merge model by minimizing the discrepancy in predictions between the merge and candidate models. RegMean provides a precise closed-form solution for the merging problem; therefore, it offers explainability and computational efficiency. However, RegMean merges each linear layer independently, overlooking how the features and information in the earlier layers propagate through the layers and influence the final prediction in the merge model. In this paper, we introduce RegMean++, a simple yet effective alternative to RegMean, that explicitly incorporates both intra- and cross-layer dependencies between merge models' layers into RegMean's objective. By accounting for these dependencies, RegMean++ better captures the behaviors of the merge model. Extensive experiments demonstrate that RegMean++ consistently outperforms RegMean across diverse settings, including in-domain (ID) and out-of-domain (OOD) generalization, sequential merging, large-scale tasks, and robustness under several types of distribution shifts. Furthermore, RegMean++ achieves competitive or state-of-the-art performance compared to various recent advanced model merging methods.
|
|
83
|
+
|
|
84
|
+
<img width="1000" alt="image" src="docs/algorithms/images/regmean_vs_regmean_plusplus.png">
|
|
85
|
+
</details>
|
|
86
|
+
|
|
75
87
|
<details>
|
|
76
88
|
<summary>Hao Mark Chen, et al. FW-Merging: Scaling Model Merging with Frank-Wolfe Optimization. Mar 2025. https://arxiv.org/abs/2503.12649</summary>
|
|
77
89
|
|
|
@@ -81,7 +93,7 @@ Model merging has emerged as a promising approach for multi-task learning (MTL),
|
|
|
81
93
|
<details>
|
|
82
94
|
<summary>Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. Feb 2025. https://arxiv.org/abs/2502.04959</summary>
|
|
83
95
|
|
|
84
|
-
Model merging integrates the weights of multiple task-specific models into a single multi-task model. Despite recent interest in the problem, a significant performance gap between the combined and single-task models remains. In this paper, we investigate the key characteristics of task matrices -- weight update matrices applied to a pre-trained model -- that enable effective merging. We show that alignment between singular components of task-specific and merged matrices strongly correlates with performance improvement over the pre-trained model. Based on this, we propose an isotropic merging framework that flattens the singular value spectrum of task matrices, enhances alignment, and reduces the performance gap. Additionally, we incorporate both common and task-specific subspaces to further improve alignment and performance. Our proposed approach achieves state-of-the-art performance across multiple scenarios, including various sets of tasks and model scales. This work advances the understanding of model merging dynamics, offering an effective methodology to merge models without requiring additional training.
|
|
96
|
+
Model merging integrates the weights of multiple task-specific models into a single multi-task model. Despite recent interest in the problem, a significant performance gap between the combined and single-task models remains. In this paper, we investigate the key characteristics of task matrices -- weight update matrices applied to a pre-trained model -- that enable effective merging. We show that alignment between singular components of task-specific and merged matrices strongly correlates with performance improvement over the pre-trained model. Based on this, we propose an isotropic merging framework that flattens the singular value spectrum of task matrices, enhances alignment, and reduces the performance gap. Additionally, we incorporate both common and task-specific subspaces to further improve alignment and performance. Our proposed approach achieves state-of-the-art performance across multiple scenarios, including various sets of tasks and model scales. This work advances the understanding of model merging dynamics, offering an effective methodology to merge models without requiring additional training.
|
|
85
97
|
</details>
|
|
86
98
|
|
|
87
99
|
<details>
|
|
@@ -99,12 +111,12 @@ Merging multiple expert models offers a promising approach for performing multi-
|
|
|
99
111
|
<details>
|
|
100
112
|
<summary>Hongling Zheng, Li Shen, Anke Tang, Yong Luo et al. Learn From Model Beyond Fine-Tuning: A Survey. Nature Machine Intelligence. Jan, 2025. https://www.nature.com/articles/s42256-024-00961-0</summary>
|
|
101
113
|
|
|
102
|
-
> Foundation models (FM) have demonstrated remarkable performance across a wide range of tasks (especially in the fields of natural language processing and computer vision), primarily attributed to their ability to comprehend instructions and access extensive, high-quality data. This not only showcases their current effectiveness but also sets a promising trajectory towards the development of artificial general intelligence. Unfortunately, due to multiple constraints, the raw data of the model used for large model training are often inaccessible, so the use of end-to-end models for downstream tasks has become a new research trend, which we call Learn From Model (LFM) in this article. LFM focuses on the research, modification, and design of FM based on the model interface, so as to better understand the model structure and weights (in a black box environment), and to generalize the model to downstream tasks. The study of LFM techniques can be broadly categorized into five major areas: model tuning, model distillation, model reuse, meta learning and model editing. Each category encompasses a repertoire of methods and strategies that aim to enhance the capabilities and performance of FM. This paper gives a comprehensive review of the current methods based on FM from the perspective of LFM, in order to help readers better understand the current research status and ideas. To conclude, we summarize the survey by highlighting several critical areas for future exploration and addressing open issues that require further attention from the research community. The relevant papers we investigated in this article can be accessed at https://github.com/ruthless-man/Awesome-Learn-from-Model
|
|
114
|
+
> Foundation models (FM) have demonstrated remarkable performance across a wide range of tasks (especially in the fields of natural language processing and computer vision), primarily attributed to their ability to comprehend instructions and access extensive, high-quality data. This not only showcases their current effectiveness but also sets a promising trajectory towards the development of artificial general intelligence. Unfortunately, due to multiple constraints, the raw data of the model used for large model training are often inaccessible, so the use of end-to-end models for downstream tasks has become a new research trend, which we call Learn From Model (LFM) in this article. LFM focuses on the research, modification, and design of FM based on the model interface, so as to better understand the model structure and weights (in a black box environment), and to generalize the model to downstream tasks. The study of LFM techniques can be broadly categorized into five major areas: model tuning, model distillation, model reuse, meta learning and model editing. Each category encompasses a repertoire of methods and strategies that aim to enhance the capabilities and performance of FM. This paper gives a comprehensive review of the current methods based on FM from the perspective of LFM, in order to help readers better understand the current research status and ideas. To conclude, we summarize the survey by highlighting several critical areas for future exploration and addressing open issues that require further attention from the research community. The relevant papers we investigated in this article can be accessed at <https://github.com/ruthless-man/Awesome-Learn-from-Model>.
|
|
103
115
|
</details>
|
|
104
116
|
|
|
105
117
|
<details>
|
|
106
118
|
<summary>Li Shen, Anke Tang, Enneng Yang et al. Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging. Oct, 2024. https://github.com/EnnengYang/Efficient-WEMoE</summary>
|
|
107
|
-
|
|
119
|
+
|
|
108
120
|
<img width="1018" alt="image" src="https://github.com/user-attachments/assets/b7e1279e-87fc-4016-8867-1bff7700e271">
|
|
109
121
|
|
|
110
122
|
</details>
|
|
@@ -130,7 +142,7 @@ Install from PyPI:
|
|
|
130
142
|
pip install fusion-bench
|
|
131
143
|
```
|
|
132
144
|
|
|
133
|
-
or install the latest version in development from
|
|
145
|
+
or install the latest version in development from the GitHub repository
|
|
134
146
|
|
|
135
147
|
```bash
|
|
136
148
|
git clone https://github.com/tanganke/fusion_bench.git
|
|
@@ -147,7 +159,6 @@ pip install -e . # install the package in editable mode
|
|
|
147
159
|
|
|
148
160
|
[](https://doi.org/10.5281/zenodo.10256836)
|
|
149
161
|
|
|
150
|
-
|
|
151
162
|
```bash
|
|
152
163
|
pip install "fusion-bench[lm-eval-harness]"
|
|
153
164
|
```
|
|
@@ -197,8 +208,8 @@ The project is structured as follows:
|
|
|
197
208
|
|
|
198
209
|
## A Unified Command Line Interface
|
|
199
210
|
|
|
200
|
-
The `fusion_bench` command-line interface is a powerful tool for researchers and practitioners in the field of model fusion. It provides a streamlined way to experiment with various fusion algorithms, model combinations, and evaluation tasks.
|
|
201
|
-
By leveraging Hydra's configuration management, fusion_bench offers flexibility in setting up experiments and reproducibility in results.
|
|
211
|
+
The `fusion_bench` command-line interface is a powerful tool for researchers and practitioners in the field of model fusion. It provides a streamlined way to experiment with various fusion algorithms, model combinations, and evaluation tasks.
|
|
212
|
+
By leveraging Hydra's configuration management, fusion_bench offers flexibility in setting up experiments and reproducibility in results.
|
|
202
213
|
The CLI's design allows for easy extension to new fusion methods, model types, and tasks, making it a versatile platform for advancing research in model fusion techniques.
|
|
203
214
|
|
|
204
215
|
Read the [CLI documentation](https://tanganke.github.io/fusion_bench/cli/fusion_bench/) for more information.
|
|
@@ -237,7 +248,7 @@ class DerivedModelFusionAlgorithm(BaseModelFusionAlgorithm):
|
|
|
237
248
|
)
|
|
238
249
|
```
|
|
239
250
|
|
|
240
|
-
A corresponding configuration file should be created to specify the class and hyperparameters of the algorithm.
|
|
251
|
+
A corresponding configuration file should be created to specify the class and hyperparameters of the algorithm.
|
|
241
252
|
Here we assume the configuration file is placed at `config/method/your_algorithm_config.yaml`.
|
|
242
253
|
|
|
243
254
|
> [!NOTE]
|
|
@@ -272,7 +283,7 @@ Click on [<kbd>Use this template</kbd>](https://github.com/fusion-bench/fusion-b
|
|
|
272
283
|
|
|
273
284
|
### FusionBench Command Generator WebUI (for v0.1.x)
|
|
274
285
|
|
|
275
|
-
FusionBench Command Generator is a user-friendly web interface for generating FusionBench commands based on configuration files.
|
|
286
|
+
FusionBench Command Generator is a user-friendly web interface for generating FusionBench commands based on configuration files.
|
|
276
287
|
It provides an interactive way to select and customize FusionBench configurations, making it easier to run experiments with different settings.
|
|
277
288
|
[Read more here](https://tanganke.github.io/fusion_bench/cli/fusion_bench_webui/).
|
|
278
289
|
|
|
@@ -283,18 +294,14 @@ It provides an interactive way to select and customize FusionBench configuration
|
|
|
283
294
|
If you find this benchmark useful, please consider citing our work:
|
|
284
295
|
|
|
285
296
|
```bibtex
|
|
286
|
-
@
|
|
287
|
-
title
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
year
|
|
291
|
-
month = jun,
|
|
292
|
-
number = {arXiv:2406.03280},
|
|
293
|
-
eprint = {2406.03280},
|
|
294
|
-
publisher = {arXiv},
|
|
295
|
-
url = {http://arxiv.org/abs/2406.03280},
|
|
296
|
-
archiveprefix = {arxiv},
|
|
297
|
-
langid = {english},
|
|
298
|
-
keywords = {Computer Science - Artificial Intelligence,Computer Science - Computation and Language,Computer Science - Machine Learning}
|
|
297
|
+
@article{tang2024fusionbench,
|
|
298
|
+
title={Fusionbench: A comprehensive benchmark of deep model fusion},
|
|
299
|
+
author={Tang, Anke and Shen, Li and Luo, Yong and Hu, Han and Du, Bo and Tao, Dacheng},
|
|
300
|
+
journal={arXiv preprint arXiv:2406.03280},
|
|
301
|
+
year={2024}
|
|
299
302
|
}
|
|
300
303
|
```
|
|
304
|
+
|
|
305
|
+
## Star History
|
|
306
|
+
|
|
307
|
+
[](https://www.star-history.com/#tanganke/fusion_bench&Date)
|