fusion-bench 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -3
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +2 -3
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +5 -9
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +5 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +24 -25
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +165 -134
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/taskpool/dummy.py
CHANGED
|
@@ -4,13 +4,13 @@ This is the dummy task pool that is used for debugging purposes.
|
|
|
4
4
|
|
|
5
5
|
from typing import Optional
|
|
6
6
|
|
|
7
|
+
from lightning.pytorch.utilities import rank_zero_only
|
|
7
8
|
from torch import nn
|
|
8
9
|
|
|
9
10
|
from fusion_bench.models.separate_io import separate_save
|
|
10
11
|
from fusion_bench.taskpool.base_pool import BaseTaskPool
|
|
11
12
|
from fusion_bench.utils import timeit_context
|
|
12
13
|
from fusion_bench.utils.parameters import count_parameters, print_parameters
|
|
13
|
-
from lightning.pytorch.utilities import rank_zero_only
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def get_model_summary(model: nn.Module) -> dict:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
-
from typing import List, Literal, Optional, Union
|
|
3
|
+
from typing import TYPE_CHECKING, List, Literal, Optional, Union
|
|
4
4
|
|
|
5
5
|
import lightning.fabric
|
|
6
6
|
import lm_eval
|
|
@@ -12,7 +12,6 @@ from fusion_bench import BaseTaskPool
|
|
|
12
12
|
from fusion_bench.mixins import LightningFabricMixin
|
|
13
13
|
from fusion_bench.utils.strenum import _version
|
|
14
14
|
|
|
15
|
-
|
|
16
15
|
log = logging.getLogger(__name__)
|
|
17
16
|
|
|
18
17
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import importlib
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import Any, Callable, Dict, List
|
|
3
|
+
from typing import Any, Callable, Dict, List, Tuple
|
|
4
4
|
|
|
5
5
|
from datasets import load_dataset
|
|
6
6
|
|
|
@@ -79,7 +79,9 @@ class CLIPTemplateFactory:
|
|
|
79
79
|
}
|
|
80
80
|
|
|
81
81
|
@staticmethod
|
|
82
|
-
def get_classnames_and_templates(
|
|
82
|
+
def get_classnames_and_templates(
|
|
83
|
+
dataset_name: str,
|
|
84
|
+
) -> Tuple[List[str], List[Callable]]:
|
|
83
85
|
"""
|
|
84
86
|
Retrieves class names and templates for the specified dataset.
|
|
85
87
|
|
|
@@ -169,7 +171,7 @@ class CLIPTemplateFactory:
|
|
|
169
171
|
CLIPTemplateFactory._dataset_mapping[dataset_name] = dataset_info
|
|
170
172
|
|
|
171
173
|
@staticmethod
|
|
172
|
-
def get_available_datasets():
|
|
174
|
+
def get_available_datasets() -> List[str]:
|
|
173
175
|
"""
|
|
174
176
|
Get a list of all available dataset names.
|
|
175
177
|
|
|
@@ -179,5 +181,5 @@ class CLIPTemplateFactory:
|
|
|
179
181
|
return list(CLIPTemplateFactory._dataset_mapping.keys())
|
|
180
182
|
|
|
181
183
|
|
|
182
|
-
def get_classnames_and_templates(dataset_name: str):
|
|
184
|
+
def get_classnames_and_templates(dataset_name: str) -> Tuple[List[str], List[Callable]]:
|
|
183
185
|
return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
|
fusion_bench/utils/__init__.py
CHANGED
|
@@ -7,7 +7,12 @@ from .cache_utils import *
|
|
|
7
7
|
from .devices import *
|
|
8
8
|
from .dtype import parse_dtype
|
|
9
9
|
from .fabric import seed_everything_by_time
|
|
10
|
-
from .instantiate_utils import
|
|
10
|
+
from .instantiate_utils import (
|
|
11
|
+
instantiate,
|
|
12
|
+
is_instantiable,
|
|
13
|
+
set_print_function_call,
|
|
14
|
+
set_print_function_call_permeanent,
|
|
15
|
+
)
|
|
11
16
|
from .json import load_from_json, save_to_json
|
|
12
17
|
from .lazy_state_dict import LazyStateDict
|
|
13
18
|
from .misc import *
|
fusion_bench/utils/devices.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import gc
|
|
2
|
+
import logging
|
|
2
3
|
import os
|
|
3
4
|
from typing import List, Optional, Union
|
|
4
5
|
|
|
@@ -12,7 +13,7 @@ from transformers.utils import (
|
|
|
12
13
|
)
|
|
13
14
|
|
|
14
15
|
__all__ = [
|
|
15
|
-
"
|
|
16
|
+
"clear_cuda_cache",
|
|
16
17
|
"to_device",
|
|
17
18
|
"num_devices",
|
|
18
19
|
"get_device",
|
|
@@ -21,10 +22,19 @@ __all__ = [
|
|
|
21
22
|
"get_device_capabilities",
|
|
22
23
|
]
|
|
23
24
|
|
|
25
|
+
log = logging.getLogger(__name__)
|
|
24
26
|
|
|
25
|
-
|
|
27
|
+
|
|
28
|
+
def clear_cuda_cache():
|
|
29
|
+
"""
|
|
30
|
+
Clears the CUDA memory cache to free up GPU memory.
|
|
31
|
+
Works only if CUDA is available.
|
|
32
|
+
"""
|
|
26
33
|
gc.collect()
|
|
27
|
-
torch.cuda.
|
|
34
|
+
if torch.cuda.is_available():
|
|
35
|
+
torch.cuda.empty_cache()
|
|
36
|
+
else:
|
|
37
|
+
log.warning("CUDA is not available. No cache to clear.")
|
|
28
38
|
|
|
29
39
|
|
|
30
40
|
def to_device(obj, device: Optional[torch.device], **kwargs):
|
|
@@ -75,7 +85,7 @@ def num_devices(devices: Union[int, List[int], str]) -> int:
|
|
|
75
85
|
Return the number of devices.
|
|
76
86
|
|
|
77
87
|
Args:
|
|
78
|
-
devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3]
|
|
88
|
+
devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3], or a str of device ids, e.g. "0,1,2,3" and "[0, 1, 2]".
|
|
79
89
|
|
|
80
90
|
Returns:
|
|
81
91
|
The number of devices.
|
|
@@ -28,7 +28,7 @@ PRINT_FUNCTION_CALL_FUNC = print
|
|
|
28
28
|
Function to be used for printing function calls.
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
|
-
CATCH_EXCEPTION =
|
|
31
|
+
CATCH_EXCEPTION = False
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
@contextmanager
|
|
@@ -41,10 +41,12 @@ def set_print_function_call(value: bool):
|
|
|
41
41
|
finally:
|
|
42
42
|
PRINT_FUNCTION_CALL = old_value
|
|
43
43
|
|
|
44
|
+
|
|
44
45
|
def set_print_function_call_permeanent(value: bool):
|
|
45
46
|
global PRINT_FUNCTION_CALL
|
|
46
47
|
PRINT_FUNCTION_CALL = value
|
|
47
48
|
|
|
49
|
+
|
|
48
50
|
def is_instantiable(config: Union[DictConfig, Any]) -> bool:
|
|
49
51
|
if OmegaConf.is_dict(config):
|
|
50
52
|
return "_target_" in config
|
fusion_bench/utils/modelscope.py
CHANGED
|
@@ -2,27 +2,37 @@ import os
|
|
|
2
2
|
from typing import Literal, Optional
|
|
3
3
|
|
|
4
4
|
from datasets import load_dataset as datasets_load_dataset
|
|
5
|
+
|
|
5
6
|
from fusion_bench.utils import validate_and_suggest_corrections
|
|
6
7
|
|
|
7
8
|
try:
|
|
9
|
+
from modelscope import dataset_file_download as modelscope_dataset_file_download
|
|
10
|
+
from modelscope import model_file_download as modelscope_model_file_download
|
|
8
11
|
from modelscope import snapshot_download as modelscope_snapshot_download
|
|
12
|
+
|
|
9
13
|
except ImportError:
|
|
10
14
|
|
|
11
|
-
def
|
|
15
|
+
def _raise_modelscope_not_installed_error(*args, **kwargs):
|
|
12
16
|
raise ImportError(
|
|
13
17
|
"ModelScope is not installed. Please install it using `pip install modelscope` to use ModelScope models."
|
|
14
18
|
)
|
|
15
19
|
|
|
20
|
+
modelscope_snapshot_download = _raise_modelscope_not_installed_error
|
|
21
|
+
modelscope_model_file_download = _raise_modelscope_not_installed_error
|
|
22
|
+
modelscope_dataset_file_download = _raise_modelscope_not_installed_error
|
|
16
23
|
|
|
17
24
|
try:
|
|
25
|
+
from huggingface_hub import hf_hub_download
|
|
18
26
|
from huggingface_hub import snapshot_download as huggingface_snapshot_download
|
|
19
27
|
except ImportError:
|
|
20
28
|
|
|
21
|
-
def
|
|
29
|
+
def _raise_hugggingface_not_installed_error(*args, **kwargs):
|
|
22
30
|
raise ImportError(
|
|
23
31
|
"Hugging Face Hub is not installed. Please install it using `pip install huggingface_hub` to use Hugging Face models."
|
|
24
32
|
)
|
|
25
33
|
|
|
34
|
+
huggingface_snapshot_download = _raise_hugggingface_not_installed_error
|
|
35
|
+
hf_hub_download = _raise_hugggingface_not_installed_error
|
|
26
36
|
|
|
27
37
|
__all__ = [
|
|
28
38
|
"load_dataset",
|
|
@@ -32,6 +42,12 @@ __all__ = [
|
|
|
32
42
|
AVAILABLE_PLATFORMS = ["hf", "huggingface", "modelscope"]
|
|
33
43
|
|
|
34
44
|
|
|
45
|
+
def _raise_unknown_platform_error(platform: str):
|
|
46
|
+
raise ValueError(
|
|
47
|
+
f"Unsupported platform: {platform}. Supported platforms are 'hf', 'huggingface', and 'modelscope'."
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
35
51
|
def load_dataset(
|
|
36
52
|
name: str,
|
|
37
53
|
split: str = "train",
|
|
@@ -55,9 +71,7 @@ def load_dataset(
|
|
|
55
71
|
dataset_dir = modelscope_snapshot_download(name, repo_type="dataset")
|
|
56
72
|
return datasets_load_dataset(dataset_dir, split=split)
|
|
57
73
|
else:
|
|
58
|
-
|
|
59
|
-
f"Unsupported platform: {platform}. Supported platforms are 'hf', 'huggingface', and 'modelscope'."
|
|
60
|
-
)
|
|
74
|
+
_raise_unknown_platform_error(platform)
|
|
61
75
|
|
|
62
76
|
|
|
63
77
|
def resolve_repo_path(
|
|
@@ -138,9 +152,114 @@ def resolve_repo_path(
|
|
|
138
152
|
repo_id=repo_id, repo_type=repo_type, **kwargs
|
|
139
153
|
)
|
|
140
154
|
else:
|
|
141
|
-
|
|
142
|
-
f"Unsupported platform: {platform}. Supported platforms are 'hf', 'huggingface', and 'modelscope'."
|
|
143
|
-
)
|
|
155
|
+
_raise_unknown_platform_error(platform)
|
|
144
156
|
return local_path
|
|
145
157
|
except Exception as e:
|
|
146
158
|
raise FileNotFoundError(f"Could not resolve checkpoint: {repo_id}. Error: {e}")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def resolve_file_path(
|
|
162
|
+
repo_id: str,
|
|
163
|
+
filename: str,
|
|
164
|
+
repo_type: Literal["model", "dataset"] = "model",
|
|
165
|
+
platform: Literal["hf", "huggingface", "modelscope"] = "hf",
|
|
166
|
+
**kwargs,
|
|
167
|
+
) -> str:
|
|
168
|
+
"""
|
|
169
|
+
Resolve and download a specific file from a repository across multiple platforms.
|
|
170
|
+
|
|
171
|
+
This function downloads a specific file from repositories hosted on various platforms
|
|
172
|
+
including local paths, Hugging Face Hub, and ModelScope. It handles platform-specific
|
|
173
|
+
URL prefixes and automatically determines the appropriate download method.
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
repo_id (str): Repository identifier. Can be:
|
|
177
|
+
- Local directory path (file will be joined with this path if it exists)
|
|
178
|
+
- Hugging Face model/dataset ID (e.g., "bert-base-uncased")
|
|
179
|
+
- ModelScope model/dataset ID
|
|
180
|
+
- URL-prefixed ID (e.g., "hf://model-name", "modelscope://model-name").
|
|
181
|
+
The prefix will override the platform argument.
|
|
182
|
+
filename (str): The specific file to download from the repository.
|
|
183
|
+
repo_type (Literal["model", "dataset"], optional): Type of repository.
|
|
184
|
+
Defaults to "model". Used for ModelScope platform to determine the
|
|
185
|
+
correct download function.
|
|
186
|
+
platform (Literal["hf", "huggingface", "modelscope"], optional):
|
|
187
|
+
Platform to download from. Defaults to "hf". Options:
|
|
188
|
+
- "hf" or "huggingface": Hugging Face Hub
|
|
189
|
+
- "modelscope": ModelScope platform
|
|
190
|
+
**kwargs: Additional arguments passed to the underlying download functions
|
|
191
|
+
(e.g., cache_dir, force_download, use_auth_token).
|
|
192
|
+
|
|
193
|
+
Returns:
|
|
194
|
+
str: Local path to the downloaded file.
|
|
195
|
+
|
|
196
|
+
Raises:
|
|
197
|
+
ValueError: If an unsupported repo_type is specified for ModelScope platform.
|
|
198
|
+
ImportError: If required dependencies for the specified platform are not installed.
|
|
199
|
+
FileNotFoundError: If the file cannot be found or downloaded.
|
|
200
|
+
|
|
201
|
+
Examples:
|
|
202
|
+
>>> # Download config.json from a Hugging Face model
|
|
203
|
+
>>> resolve_file_path("bert-base-uncased", "config.json")
|
|
204
|
+
"/home/user/.cache/huggingface/hub/models--bert-base-uncased/.../config.json"
|
|
205
|
+
|
|
206
|
+
>>> # Download from ModelScope
|
|
207
|
+
>>> resolve_file_path(
|
|
208
|
+
... "damo/nlp_bert_backbone_base_std",
|
|
209
|
+
... "pytorch_model.bin",
|
|
210
|
+
... platform="modelscope"
|
|
211
|
+
... )
|
|
212
|
+
"/home/user/.cache/modelscope/hub/.../pytorch_model.bin"
|
|
213
|
+
|
|
214
|
+
>>> # Local file path
|
|
215
|
+
>>> resolve_file_path("/path/to/local/model", "config.json")
|
|
216
|
+
"/path/to/local/model/config.json"
|
|
217
|
+
|
|
218
|
+
>>> # URL-prefixed repository
|
|
219
|
+
>>> resolve_file_path("hf://microsoft/DialoGPT-medium", "config.json")
|
|
220
|
+
"/home/user/.cache/huggingface/hub/.../config.json"
|
|
221
|
+
|
|
222
|
+
>>> # Download dataset file from ModelScope
|
|
223
|
+
>>> resolve_file_path(
|
|
224
|
+
... "DAMO_NLP/jd",
|
|
225
|
+
... "train.json",
|
|
226
|
+
... repo_type="dataset",
|
|
227
|
+
... platform="modelscope"
|
|
228
|
+
... )
|
|
229
|
+
"/home/user/.cache/modelscope/datasets/.../train.json"
|
|
230
|
+
"""
|
|
231
|
+
# If it's a HuggingFace Hub model id, download snapshot
|
|
232
|
+
if repo_id.startswith("hf://") or repo_id.startswith("huggingface://"):
|
|
233
|
+
repo_id = repo_id.replace("hf://", "").replace("huggingface://", "")
|
|
234
|
+
platform = "hf"
|
|
235
|
+
# If it's a ModelScope model id, download snapshot
|
|
236
|
+
elif repo_id.startswith("modelscope://"):
|
|
237
|
+
repo_id = repo_id.replace("modelscope://", "")
|
|
238
|
+
platform = "modelscope"
|
|
239
|
+
|
|
240
|
+
# If it's a local file or directory, return as is
|
|
241
|
+
if os.path.exists(repo_id):
|
|
242
|
+
return os.path.join(repo_id, filename)
|
|
243
|
+
|
|
244
|
+
if platform in ["hf", "huggingface"]:
|
|
245
|
+
return hf_hub_download(
|
|
246
|
+
repo_id=repo_id,
|
|
247
|
+
filename=filename,
|
|
248
|
+
repo_type=repo_type,
|
|
249
|
+
**kwargs,
|
|
250
|
+
)
|
|
251
|
+
elif platform == "modelscope":
|
|
252
|
+
if repo_type == "model":
|
|
253
|
+
return modelscope_model_file_download(
|
|
254
|
+
model_id=repo_id, file_path=filename, **kwargs
|
|
255
|
+
)
|
|
256
|
+
elif repo_type == "dataset":
|
|
257
|
+
return modelscope_dataset_file_download(
|
|
258
|
+
dataset_id=repo_id, file_path=filename, **kwargs
|
|
259
|
+
)
|
|
260
|
+
else:
|
|
261
|
+
raise ValueError(
|
|
262
|
+
f"Unsupported repo_type: {repo_type}. Supported types are 'model' and 'dataset'."
|
|
263
|
+
)
|
|
264
|
+
else:
|
|
265
|
+
_raise_unknown_platform_error(platform)
|
fusion_bench/utils/parameters.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
from collections import OrderedDict
|
|
3
|
-
from typing import List, Mapping, Optional, Union
|
|
3
|
+
from typing import Dict, List, Mapping, Optional, Union
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from torch import nn
|
|
@@ -83,7 +83,7 @@ def vector_to_state_dict(
|
|
|
83
83
|
vector: torch.Tensor,
|
|
84
84
|
state_dict: Union[StateDictType, nn.Module],
|
|
85
85
|
remove_keys: Optional[List[str]] = None,
|
|
86
|
-
):
|
|
86
|
+
) -> Dict[str, torch.Tensor]:
|
|
87
87
|
"""
|
|
88
88
|
Convert a vector to a state dictionary.
|
|
89
89
|
|
fusion_bench/utils/rich_utils.py
CHANGED
|
@@ -44,7 +44,7 @@ def state_dicts_check_keys(state_dicts: List[StateDictType]):
|
|
|
44
44
|
assert keys == set(state_dict.keys()), "keys of state_dicts are not equal"
|
|
45
45
|
|
|
46
46
|
|
|
47
|
-
def num_params_of_state_dict(state_dict: StateDictType):
|
|
47
|
+
def num_params_of_state_dict(state_dict: StateDictType) -> int:
|
|
48
48
|
"""
|
|
49
49
|
Returns the number of parameters in a state dict.
|
|
50
50
|
|
|
@@ -57,7 +57,7 @@ def num_params_of_state_dict(state_dict: StateDictType):
|
|
|
57
57
|
return sum([state_dict[key].numel() for key in state_dict])
|
|
58
58
|
|
|
59
59
|
|
|
60
|
-
def state_dict_flatten(state_dict: Dict[str, Tensor]):
|
|
60
|
+
def state_dict_flatten(state_dict: Dict[str, Tensor]) -> Tensor:
|
|
61
61
|
"""
|
|
62
62
|
Flattens a state dict.
|
|
63
63
|
|
|
@@ -73,7 +73,7 @@ def state_dict_flatten(state_dict: Dict[str, Tensor]):
|
|
|
73
73
|
return torch.cat(flattened_state_dict)
|
|
74
74
|
|
|
75
75
|
|
|
76
|
-
def state_dict_avg(state_dicts: List[StateDictType]):
|
|
76
|
+
def state_dict_avg(state_dicts: List[StateDictType]) -> StateDictType:
|
|
77
77
|
"""
|
|
78
78
|
Returns the average of a list of state dicts.
|
|
79
79
|
|
|
@@ -100,7 +100,7 @@ def state_dict_avg(state_dicts: List[StateDictType]):
|
|
|
100
100
|
|
|
101
101
|
def state_dict_sub(
|
|
102
102
|
a: StateDictType, b: StateDictType, strict: bool = True, device=None
|
|
103
|
-
):
|
|
103
|
+
) -> StateDictType:
|
|
104
104
|
"""
|
|
105
105
|
Returns the difference between two state dicts `a-b`.
|
|
106
106
|
|
|
@@ -130,7 +130,7 @@ def state_dict_add(
|
|
|
130
130
|
strict: bool = True,
|
|
131
131
|
device=None,
|
|
132
132
|
show_pbar: bool = False,
|
|
133
|
-
):
|
|
133
|
+
) -> StateDictType:
|
|
134
134
|
"""
|
|
135
135
|
Returns the sum of two state dicts.
|
|
136
136
|
|
|
@@ -156,14 +156,14 @@ def state_dict_add(
|
|
|
156
156
|
return ans
|
|
157
157
|
|
|
158
158
|
|
|
159
|
-
def state_dict_add_scalar(a: StateDictType, scalar: Number):
|
|
159
|
+
def state_dict_add_scalar(a: StateDictType, scalar: Number) -> StateDictType:
|
|
160
160
|
ans = OrderedDict()
|
|
161
161
|
for key in a:
|
|
162
162
|
ans[key] = a[key] + scalar
|
|
163
163
|
return ans
|
|
164
164
|
|
|
165
165
|
|
|
166
|
-
def state_dict_mul(state_dict: StateDictType, scalar: float):
|
|
166
|
+
def state_dict_mul(state_dict: StateDictType, scalar: float) -> StateDictType:
|
|
167
167
|
"""
|
|
168
168
|
Returns the product of a state dict and a scalar.
|
|
169
169
|
|
|
@@ -180,7 +180,9 @@ def state_dict_mul(state_dict: StateDictType, scalar: float):
|
|
|
180
180
|
return diff
|
|
181
181
|
|
|
182
182
|
|
|
183
|
-
def state_dict_div(state_dict: StateDictType, scalar: float, show_pbar: bool = False):
|
|
183
|
+
def state_dict_div(
|
|
184
|
+
state_dict: StateDictType, scalar: float, show_pbar: bool = False
|
|
185
|
+
) -> StateDictType:
|
|
184
186
|
"""
|
|
185
187
|
Returns the division of a state dict by a scalar.
|
|
186
188
|
|
|
@@ -197,16 +199,16 @@ def state_dict_div(state_dict: StateDictType, scalar: float, show_pbar: bool = F
|
|
|
197
199
|
return diff
|
|
198
200
|
|
|
199
201
|
|
|
200
|
-
def state_dict_power(state_dict: Dict[str, Tensor], p: float):
|
|
202
|
+
def state_dict_power(state_dict: StateDictType, p: float) -> StateDictType:
|
|
201
203
|
"""
|
|
202
204
|
Returns the power of a state dict.
|
|
203
205
|
|
|
204
206
|
Args:
|
|
205
|
-
state_dict (
|
|
207
|
+
state_dict (StateDictType): The state dict to be powered.
|
|
206
208
|
p (float): The power to raise the state dict to.
|
|
207
209
|
|
|
208
210
|
Returns:
|
|
209
|
-
|
|
211
|
+
StateDictType: The powered state dict.
|
|
210
212
|
"""
|
|
211
213
|
powered_state_dict = {}
|
|
212
214
|
for key in state_dict:
|
|
@@ -215,17 +217,17 @@ def state_dict_power(state_dict: Dict[str, Tensor], p: float):
|
|
|
215
217
|
|
|
216
218
|
|
|
217
219
|
def state_dict_interpolation(
|
|
218
|
-
state_dicts: List[
|
|
219
|
-
):
|
|
220
|
+
state_dicts: List[StateDictType], scalars: List[float]
|
|
221
|
+
) -> StateDictType:
|
|
220
222
|
"""
|
|
221
223
|
Interpolates between a list of state dicts using a list of scalars.
|
|
222
224
|
|
|
223
225
|
Args:
|
|
224
|
-
state_dicts (List[
|
|
226
|
+
state_dicts (List[StateDictType]): The list of state dicts to interpolate between.
|
|
225
227
|
scalars (List[float]): The list of scalars to use for interpolation.
|
|
226
228
|
|
|
227
229
|
Returns:
|
|
228
|
-
|
|
230
|
+
StateDictType: The interpolated state dict.
|
|
229
231
|
"""
|
|
230
232
|
assert len(state_dicts) == len(
|
|
231
233
|
scalars
|
|
@@ -243,15 +245,15 @@ def state_dict_interpolation(
|
|
|
243
245
|
return interpolated_state_dict
|
|
244
246
|
|
|
245
247
|
|
|
246
|
-
def state_dict_sum(state_dicts: List[StateDictType]):
|
|
248
|
+
def state_dict_sum(state_dicts: List[StateDictType]) -> StateDictType:
|
|
247
249
|
"""
|
|
248
250
|
Returns the sum of a list of state dicts.
|
|
249
251
|
|
|
250
252
|
Args:
|
|
251
|
-
state_dicts (List[
|
|
253
|
+
state_dicts (List[StateDictType]): The list of state dicts to sum.
|
|
252
254
|
|
|
253
255
|
Returns:
|
|
254
|
-
|
|
256
|
+
StateDictType: The sum of the state dicts.
|
|
255
257
|
"""
|
|
256
258
|
assert len(state_dicts) > 0, "The number of state_dicts must be greater than 0"
|
|
257
259
|
assert all(
|
|
@@ -267,17 +269,17 @@ def state_dict_sum(state_dicts: List[StateDictType]):
|
|
|
267
269
|
|
|
268
270
|
|
|
269
271
|
def state_dict_weighted_sum(
|
|
270
|
-
state_dicts: List[
|
|
271
|
-
):
|
|
272
|
+
state_dicts: List[StateDictType], weights: List[float], device=None
|
|
273
|
+
) -> StateDictType:
|
|
272
274
|
"""
|
|
273
275
|
Returns the weighted sum of a list of state dicts.
|
|
274
276
|
|
|
275
277
|
Args:
|
|
276
|
-
state_dicts (List[
|
|
278
|
+
state_dicts (List[StateDictType]): The list of state dicts to interpolate between.
|
|
277
279
|
weights (List[float]): The list of weights to use for the weighted sum.
|
|
278
280
|
|
|
279
281
|
Returns:
|
|
280
|
-
|
|
282
|
+
StateDictType: The weighted sum of the state dicts.
|
|
281
283
|
"""
|
|
282
284
|
assert len(state_dicts) == len(
|
|
283
285
|
weights
|
|
@@ -302,7 +304,7 @@ def state_dict_weighted_sum(
|
|
|
302
304
|
return weighted_sum_state_dict
|
|
303
305
|
|
|
304
306
|
|
|
305
|
-
def state_dict_diff_abs(a: StateDictType, b: StateDictType):
|
|
307
|
+
def state_dict_diff_abs(a: StateDictType, b: StateDictType) -> StateDictType:
|
|
306
308
|
"""
|
|
307
309
|
Returns the per-layer abs of the difference between two state dicts.
|
|
308
310
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: fusion_bench
|
|
3
|
-
Version: 0.2.20
|
|
3
|
+
Version: 0.2.21
|
|
4
4
|
Summary: A Comprehensive Benchmark of Deep Model Fusion
|
|
5
5
|
Author-email: Anke Tang <tang.anke@foxmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -45,13 +45,17 @@ Requires-Dist: rich
|
|
|
45
45
|
Requires-Dist: scipy
|
|
46
46
|
Requires-Dist: h5py
|
|
47
47
|
Requires-Dist: pytest
|
|
48
|
+
Requires-Dist: transformers!=4.49
|
|
49
|
+
Requires-Dist: pillow!=11.2.1
|
|
48
50
|
Provides-Extra: lm-eval-harness
|
|
49
51
|
Requires-Dist: lm-eval; extra == "lm-eval-harness"
|
|
52
|
+
Requires-Dist: immutabledict; extra == "lm-eval-harness"
|
|
53
|
+
Requires-Dist: langdetect; extra == "lm-eval-harness"
|
|
50
54
|
Dynamic: license-file
|
|
51
55
|
|
|
52
56
|
<div align='center'>
|
|
53
57
|
|
|
54
|
-
# FusionBench: A Comprehensive Benchmark/
|
|
58
|
+
# FusionBench: A Comprehensive Benchmark/Toolkit of Deep Model Fusion
|
|
55
59
|
|
|
56
60
|
[arXiv:2406.03280](http://arxiv.org/abs/2406.03280)
|
|
57
61
|
[License: MIT](https://github.com/tanganke/fusion_bench/blob/main/LICENSE)
|
|
@@ -75,7 +79,7 @@ Projects based on FusionBench and news from the community (descending order of d
|
|
|
75
79
|
<details>
|
|
76
80
|
<summary>The-Hai Nguyen, Dang Huu-Tien, Takeshi Suzuki, and Le-Minh Nguyen. RegMean++: Enhancing Effectiveness and Generalization of Regression Mean for Model Merging. Aug, 2025. https://www.arxiv.org/abs/2508.03121</summary>
|
|
77
81
|
|
|
78
|
-
Regression Mean (RegMean), an approach that formulates model merging as a linear regression problem, aims to find the optimal weights for each linear layer in the merge model by minimizing the discrepancy in predictions between the merge and candidate models. RegMean provides a precise closed-form solution for the merging problem; therefore, it offers explainability and computational efficiency. However, RegMean merges each linear layer independently, overlooking how the features and information in the earlier layers propagate through the layers and influence the final prediction in the merge model. In this paper, we introduce RegMean++, a simple yet effective alternative to RegMean, that explicitly incorporates both intra- and cross-layer dependencies between merge models' layers into RegMean's objective. By accounting for these dependencies, RegMean++ better captures the behaviors of the merge model. Extensive experiments demonstrate that RegMean++ consistently outperforms RegMean across diverse settings, including in-domain (ID) and out-of-domain (OOD) generalization, sequential merging, large-scale tasks, and robustness under several types of distribution shifts. Furthermore, RegMean++ achieves competitive or state-of-the-art performance compared to various recent advanced model merging methods.
|
|
82
|
+
Regression Mean (RegMean), an approach that formulates model merging as a linear regression problem, aims to find the optimal weights for each linear layer in the merge model by minimizing the discrepancy in predictions between the merge and candidate models. RegMean provides a precise closed-form solution for the merging problem; therefore, it offers explainability and computational efficiency. However, RegMean merges each linear layer independently, overlooking how the features and information in the earlier layers propagate through the layers and influence the final prediction in the merge model. In this paper, we introduce RegMean++, a simple yet effective alternative to RegMean, that explicitly incorporates both intra- and cross-layer dependencies between merge models' layers into RegMean's objective. By accounting for these dependencies, RegMean++ better captures the behaviors of the merge model. Extensive experiments demonstrate that RegMean++ consistently outperforms RegMean across diverse settings, including in-domain (ID) and out-of-domain (OOD) generalization, sequential merging, large-scale tasks, and robustness under several types of distribution shifts. Furthermore, RegMean++ achieves competitive or state-of-the-art performance compared to various recent advanced model merging methods.
|
|
79
83
|
|
|
80
84
|
<img width="1000" alt="image" src="docs/algorithms/images/regmean_vs_regmean_plusplus.png">
|
|
81
85
|
</details>
|
|
@@ -89,7 +93,7 @@ Model merging has emerged as a promising approach for multi-task learning (MTL),
|
|
|
89
93
|
<details>
|
|
90
94
|
<summary>Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. Feb 2025. https://arxiv.org/abs/2502.04959</summary>
|
|
91
95
|
|
|
92
|
-
Model merging integrates the weights of multiple task-specific models into a single multi-task model. Despite recent interest in the problem, a significant performance gap between the combined and single-task models remains. In this paper, we investigate the key characteristics of task matrices -- weight update matrices applied to a pre-trained model -- that enable effective merging. We show that alignment between singular components of task-specific and merged matrices strongly correlates with performance improvement over the pre-trained model. Based on this, we propose an isotropic merging framework that flattens the singular value spectrum of task matrices, enhances alignment, and reduces the performance gap. Additionally, we incorporate both common and task-specific subspaces to further improve alignment and performance. Our proposed approach achieves state-of-the-art performance across multiple scenarios, including various sets of tasks and model scales. This work advances the understanding of model merging dynamics, offering an effective methodology to merge models without requiring additional training.
|
|
96
|
+
Model merging integrates the weights of multiple task-specific models into a single multi-task model. Despite recent interest in the problem, a significant performance gap between the combined and single-task models remains. In this paper, we investigate the key characteristics of task matrices -- weight update matrices applied to a pre-trained model -- that enable effective merging. We show that alignment between singular components of task-specific and merged matrices strongly correlates with performance improvement over the pre-trained model. Based on this, we propose an isotropic merging framework that flattens the singular value spectrum of task matrices, enhances alignment, and reduces the performance gap. Additionally, we incorporate both common and task-specific subspaces to further improve alignment and performance. Our proposed approach achieves state-of-the-art performance across multiple scenarios, including various sets of tasks and model scales. This work advances the understanding of model merging dynamics, offering an effective methodology to merge models without requiring additional training.
|
|
93
97
|
</details>
|
|
94
98
|
|
|
95
99
|
<details>
|
|
@@ -107,12 +111,12 @@ Merging multiple expert models offers a promising approach for performing multi-
|
|
|
107
111
|
<details>
|
|
108
112
|
<summary>Hongling Zheng, Li Shen, Anke Tang, Yong Luo et al. Learn From Model Beyond Fine-Tuning: A Survey. Nature Machine Intelligence. Jan, 2025. https://www.nature.com/articles/s42256-024-00961-0</summary>
|
|
109
113
|
|
|
110
|
-
> Foundation models (FM) have demonstrated remarkable performance across a wide range of tasks (especially in the fields of natural language processing and computer vision), primarily attributed to their ability to comprehend instructions and access extensive, high-quality data. This not only showcases their current effectiveness but also sets a promising trajectory towards the development of artificial general intelligence. Unfortunately, due to multiple constraints, the raw data of the model used for large model training are often inaccessible, so the use of end-to-end models for downstream tasks has become a new research trend, which we call Learn From Model (LFM) in this article. LFM focuses on the research, modification, and design of FM based on the model interface, so as to better understand the model structure and weights (in a black box environment), and to generalize the model to downstream tasks. The study of LFM techniques can be broadly categorized into five major areas: model tuning, model distillation, model reuse, meta learning and model editing. Each category encompasses a repertoire of methods and strategies that aim to enhance the capabilities and performance of FM. This paper gives a comprehensive review of the current methods based on FM from the perspective of LFM, in order to help readers better understand the current research status and ideas. To conclude, we summarize the survey by highlighting several critical areas for future exploration and addressing open issues that require further attention from the research community. The relevant papers we investigated in this article can be accessed at https://github.com/ruthless-man/Awesome-Learn-from-Model
|
|
114
|
+
> Foundation models (FM) have demonstrated remarkable performance across a wide range of tasks (especially in the fields of natural language processing and computer vision), primarily attributed to their ability to comprehend instructions and access extensive, high-quality data. This not only showcases their current effectiveness but also sets a promising trajectory towards the development of artificial general intelligence. Unfortunately, due to multiple constraints, the raw data of the model used for large model training are often inaccessible, so the use of end-to-end models for downstream tasks has become a new research trend, which we call Learn From Model (LFM) in this article. LFM focuses on the research, modification, and design of FM based on the model interface, so as to better understand the model structure and weights (in a black box environment), and to generalize the model to downstream tasks. The study of LFM techniques can be broadly categorized into five major areas: model tuning, model distillation, model reuse, meta learning and model editing. Each category encompasses a repertoire of methods and strategies that aim to enhance the capabilities and performance of FM. This paper gives a comprehensive review of the current methods based on FM from the perspective of LFM, in order to help readers better understand the current research status and ideas. To conclude, we summarize the survey by highlighting several critical areas for future exploration and addressing open issues that require further attention from the research community. The relevant papers we investigated in this article can be accessed at <https://github.com/ruthless-man/Awesome-Learn-from-Model>.
|
|
111
115
|
</details>
|
|
112
116
|
|
|
113
117
|
<details>
|
|
114
118
|
<summary>Li Shen, Anke Tang, Enneng Yang et al. Efficient and Effective Weight-Ensembling Mixture of Experts for Multi-Task Model Merging. Oct, 2024. https://github.com/EnnengYang/Efficient-WEMoE</summary>
|
|
115
|
-
|
|
119
|
+
|
|
116
120
|
<img width="1018" alt="image" src="https://github.com/user-attachments/assets/b7e1279e-87fc-4016-8867-1bff7700e271">
|
|
117
121
|
|
|
118
122
|
</details>
|
|
@@ -138,7 +142,7 @@ Install from PyPI:
|
|
|
138
142
|
pip install fusion-bench
|
|
139
143
|
```
|
|
140
144
|
|
|
141
|
-
or install the latest version in development from
|
|
145
|
+
or install the latest version in development from the GitHub repository
|
|
142
146
|
|
|
143
147
|
```bash
|
|
144
148
|
git clone https://github.com/tanganke/fusion_bench.git
|
|
@@ -155,7 +159,6 @@ pip install -e . # install the package in editable mode
|
|
|
155
159
|
|
|
156
160
|
[DOI: 10.5281/zenodo.10256836](https://doi.org/10.5281/zenodo.10256836)
|
|
157
161
|
|
|
158
|
-
|
|
159
162
|
```bash
|
|
160
163
|
pip install "fusion-bench[lm-eval-harness]"
|
|
161
164
|
```
|
|
@@ -205,8 +208,8 @@ The project is structured as follows:
|
|
|
205
208
|
|
|
206
209
|
## A Unified Command Line Interface
|
|
207
210
|
|
|
208
|
-
The `fusion_bench` command-line interface is a powerful tool for researchers and practitioners in the field of model fusion. It provides a streamlined way to experiment with various fusion algorithms, model combinations, and evaluation tasks.
|
|
209
|
-
By leveraging Hydra's configuration management, fusion_bench offers flexibility in setting up experiments and reproducibility in results.
|
|
211
|
+
The `fusion_bench` command-line interface is a powerful tool for researchers and practitioners in the field of model fusion. It provides a streamlined way to experiment with various fusion algorithms, model combinations, and evaluation tasks.
|
|
212
|
+
By leveraging Hydra's configuration management, fusion_bench offers flexibility in setting up experiments and reproducibility in results.
|
|
210
213
|
The CLI's design allows for easy extension to new fusion methods, model types, and tasks, making it a versatile platform for advancing research in model fusion techniques.
|
|
211
214
|
|
|
212
215
|
Read the [CLI documentation](https://tanganke.github.io/fusion_bench/cli/fusion_bench/) for more information.
|
|
@@ -245,7 +248,7 @@ class DerivedModelFusionAlgorithm(BaseModelFusionAlgorithm):
|
|
|
245
248
|
)
|
|
246
249
|
```
|
|
247
250
|
|
|
248
|
-
A corresponding configuration file should be created to specify the class and hyperparameters of the algorithm.
|
|
251
|
+
A corresponding configuration file should be created to specify the class and hyperparameters of the algorithm.
|
|
249
252
|
Here we assume the configuration file is placed at `config/method/your_algorithm_config.yaml`.
|
|
250
253
|
|
|
251
254
|
> [!NOTE]
|
|
@@ -280,7 +283,7 @@ Click on [<kbd>Use this template</kbd>](https://github.com/fusion-bench/fusion-b
|
|
|
280
283
|
|
|
281
284
|
### FusionBench Command Generator WebUI (for v0.1.x)
|
|
282
285
|
|
|
283
|
-
FusionBench Command Generator is a user-friendly web interface for generating FusionBench commands based on configuration files.
|
|
286
|
+
FusionBench Command Generator is a user-friendly web interface for generating FusionBench commands based on configuration files.
|
|
284
287
|
It provides an interactive way to select and customize FusionBench configurations, making it easier to run experiments with different settings.
|
|
285
288
|
[Read more here](https://tanganke.github.io/fusion_bench/cli/fusion_bench_webui/).
|
|
286
289
|
|
|
@@ -291,18 +294,14 @@ It provides an interactive way to select and customize FusionBench configuration
|
|
|
291
294
|
If you find this benchmark useful, please consider citing our work:
|
|
292
295
|
|
|
293
296
|
```bibtex
|
|
294
|
-
@
|
|
295
|
-
title
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
year
|
|
299
|
-
month = jun,
|
|
300
|
-
number = {arXiv:2406.03280},
|
|
301
|
-
eprint = {2406.03280},
|
|
302
|
-
publisher = {arXiv},
|
|
303
|
-
url = {http://arxiv.org/abs/2406.03280},
|
|
304
|
-
archiveprefix = {arxiv},
|
|
305
|
-
langid = {english},
|
|
306
|
-
keywords = {Computer Science - Artificial Intelligence,Computer Science - Computation and Language,Computer Science - Machine Learning}
|
|
297
|
+
@article{tang2024fusionbench,
|
|
298
|
+
title={Fusionbench: A comprehensive benchmark of deep model fusion},
|
|
299
|
+
author={Tang, Anke and Shen, Li and Luo, Yong and Hu, Han and Du, Bo and Tao, Dacheng},
|
|
300
|
+
journal={arXiv preprint arXiv:2406.03280},
|
|
301
|
+
year={2024}
|
|
307
302
|
}
|
|
308
303
|
```
|
|
304
|
+
|
|
305
|
+
## Star History
|
|
306
|
+
|
|
307
|
+
[Star History Chart](https://www.star-history.com/#tanganke/fusion_bench&Date)
|