fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +22 -2
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +6 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +24 -5
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +5 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +17 -13
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +12 -16
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +15 -45
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +275 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +7 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +160 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +46 -61
- fusion_bench/scripts/cli.py +38 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +7 -1
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/__init__.py
CHANGED
```diff
@@ -19,8 +19,28 @@ from . import (
     tasks,
     utils,
 )
+from .constants import RuntimeConstants
 from .method import BaseAlgorithm, BaseModelFusionAlgorithm
+from .mixins import auto_register_config
 from .modelpool import BaseModelPool
-from .models import
+from .models import (
+    create_default_model_card,
+    load_model_card_template,
+    save_pretrained_with_remote_code,
+    separate_io,
+)
+from .programs import BaseHydraProgram
 from .taskpool import BaseTaskPool
-from .utils import
+from .utils import (
+    cache_with_joblib,
+    get_rankzero_logger,
+    import_object,
+    instantiate,
+    parse_dtype,
+    print_parameters,
+    seed_everything_by_time,
+    set_default_cache_dir,
+    set_print_function_call,
+    set_print_function_call_permeanent,
+    timeit_context,
+)
```
fusion_bench/_get_started/greeting_program.py
ADDED

```diff
@@ -0,0 +1,49 @@
+import logging
+from typing import Optional
+
+from omegaconf import DictConfig
+
+from fusion_bench.programs import BaseHydraProgram
+
+log = logging.getLogger(__name__)
+
+
+class GreetingProgram(BaseHydraProgram):
+    """
+    A simple program that greets users with a custom message.
+    """
+
+    _config_mapping = BaseHydraProgram._config_mapping | {
+        "message": "message",
+        "name": "name",
+        "repeat_count": "repeat_count",
+    }
+
+    def __init__(
+        self,
+        message: str = "Hello",
+        name: str = "World",
+        repeat_count: int = 1,
+        **kwargs,
+    ):
+        self.message = message
+        self.name = name
+        self.repeat_count = repeat_count
+        super().__init__(**kwargs)
+
+    def run(self):
+        """Execute the greeting workflow."""
+        log.info("Starting greeting program")
+
+        # Create the greeting
+        greeting = f"{self.message}, {self.name}!"
+
+        # Print the greeting multiple times
+        for i in range(self.repeat_count):
+            if self.repeat_count > 1:
+                print(f"[{i+1}/{self.repeat_count}] {greeting}")
+            else:
+                print(greeting)
+
+        log.info("Greeting program completed")
+        return greeting
```
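A usage sketch for the new getting-started program. The direct instantiation below is an assumption (it presumes `BaseHydraProgram.__init__` needs no required arguments beyond `**kwargs`); in normal use the program would be driven by the accompanying `fusion_bench_config/_get_started/greeting_program.yaml` through the Hydra CLI.

```python
from fusion_bench._get_started.greeting_program import GreetingProgram

# Direct use (assumption: BaseHydraProgram accepts empty **kwargs).
program = GreetingProgram(message="Hi", name="FusionBench", repeat_count=2)
program.run()
# [1/2] Hi, FusionBench!
# [2/2] Hi, FusionBench!
```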
fusion_bench/compat/method/base_algorithm.py
CHANGED

```diff
@@ -36,6 +36,20 @@ class ModelFusionAlgorithm(ABC):
             algorithm_config = DictConfig({})
         self.config = algorithm_config
 
+    def on_run_start(self):
+        """
+        Hook method called at the start of the run.
+        Can be overridden by subclasses to perform initialization tasks.
+        """
+        pass
+
+    def on_run_end(self):
+        """
+        Hook method called at the end of the run.
+        Can be overridden by subclasses to perform cleanup tasks.
+        """
+        pass
+
     @abstractmethod
     def run(self, modelpool):
         """
```
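A minimal sketch of how a subclass might use the new lifecycle hooks. `ExampleAlgorithm` is hypothetical, and this hunk does not show whether the surrounding runner invokes the hooks automatically, so the sketch calls them explicitly.

```python
from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm


class ExampleAlgorithm(ModelFusionAlgorithm):  # hypothetical subclass
    def on_run_start(self):
        print("run starting")   # e.g. start timers, allocate buffers

    def on_run_end(self):
        print("run finished")   # e.g. release resources, flush logs

    def run(self, modelpool):
        self.on_run_start()
        try:
            merged = None  # ... fuse models drawn from `modelpool` here
        finally:
            self.on_run_end()
        return merged
```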
fusion_bench/constants/clip_vision.py
CHANGED

```diff
@@ -1,4 +1,5 @@
-
+"Constants for CLIP Vision Model Merging"
+
 TASK_NAMES_TA8 = [
     "sun397",
     "stanford-cars",
@@ -9,7 +10,23 @@ TASK_NAMES_TA8 = [
     "mnist",
     "dtd",
 ]
-
+"The 8 tasks used in the Task Arithmetic paper."
+TASK_NAMES_TALL8 = TASK_NAMES_TA8
+"The 8 tasks used in the Tall Mask paper"
+TASK_NAMES_TALL10 = TASK_NAMES_TA8 + ["oxford_flowers102", "pcam"]
+TASK_NAMES_TALL12 = TASK_NAMES_TALL10 + [
+    "fer2013",
+    "oxford-iiit-pet",
+]
+TASK_NAMES_TALL14 = TASK_NAMES_TALL12 + [
+    "stl10",
+    "cifar100",
+]
+"The 14 tasks used in the TALL mask paper"
+TASK_NAMES_TALL16 = TASK_NAMES_TALL14 + ["cifar10", "food101"]
+TASK_NAMES_TALL18 = TASK_NAMES_TALL16 + ["fashion_mnist", "emnist_letters"]
+TASK_NAMES_TALL20 = TASK_NAMES_TALL18 + ["kmnist", "rendered-sst2"]
 TASK_NAMES_TA8_CAP = [
     "SUN397",
     "Cars",
@@ -20,3 +37,10 @@ TASK_NAMES_TA8_CAP = [
     "MNIST",
     "DTD",
 ]
+TASK_NAMES_TALL8_CAP = TASK_NAMES_TA8_CAP
+TASK_NAMES_TALL10_CAP = TASK_NAMES_TALL8_CAP + ["Flowers102", "PCAM"]
+TASK_NAMES_TALL12_CAP = TASK_NAMES_TALL10_CAP + ["FER2013", "OxfordIIITPet"]
+TASK_NAMES_TALL14_CAP = TASK_NAMES_TALL12_CAP + ["STL10", "CIFAR100"]
+TASK_NAMES_TALL16_CAP = TASK_NAMES_TALL14_CAP + ["CIFAR10", "Food101"]
+TASK_NAMES_TALL18_CAP = TASK_NAMES_TALL16_CAP + ["FashionMNIST", "EMNIST"]
+TASK_NAMES_TALL20_CAP = TASK_NAMES_TALL18_CAP + ["KMNIST", "RenderedSST2"]
```
fusion_bench/constants/paths.py
CHANGED
```diff
@@ -7,10 +7,14 @@ log = logging.getLogger(__name__)
 __all__ = ["LIBRARY_PATH", "PROJECT_ROOT_PATH", "DEFAULT_CONFIG_PATH"]
 
 LIBRARY_PATH = Path(importlib.import_module("fusion_bench").__path__[0])
+"""Path to the library directory."""
+
 PROJECT_ROOT_PATH = LIBRARY_PATH.parent
+"""Path to the project root directory."""
 
 if (PROJECT_ROOT_PATH / "config").is_dir():
     DEFAULT_CONFIG_PATH = PROJECT_ROOT_PATH / "config"
+    """Path to the default config directory."""
 elif (PROJECT_ROOT_PATH / "fusion_bench_config").is_dir():
     DEFAULT_CONFIG_PATH = PROJECT_ROOT_PATH / "fusion_bench_config"
 else:
```
fusion_bench/constants/runtime.py
ADDED

```diff
@@ -0,0 +1,57 @@
+import threading
+from pathlib import Path
+from typing import Optional, Union
+
+
+class RuntimeConstants:
+    """
+    This class holds constants related to the runtime environment of the Fusion Bench framework.
+    It includes default values for cache directories and other runtime configurations.
+
+    Implemented as a thread-safe singleton to ensure consistent runtime configuration
+    across the entire application.
+    """
+
+    _instance: Optional["RuntimeConstants"] = None
+    _lock = threading.Lock()
+
+    def __new__(cls) -> "RuntimeConstants":
+        """Create a new instance using singleton pattern with thread safety."""
+        with cls._lock:
+            # Double-check locking pattern
+            if cls._instance is None:
+                cls._instance = super(RuntimeConstants, cls).__new__(cls)
+                cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        """Initialize the singleton instance only once."""
+        if not self._initialized:
+            # Add your runtime constants here
+            self._initialized = True
+
+    debug = False
+
+    @property
+    def cache_dir(self) -> Path:
+        from fusion_bench.utils.cache_utils import DEFAULT_CACHE_DIR
+
+        return DEFAULT_CACHE_DIR
+
+    @cache_dir.setter
+    def cache_dir(self, path: Union[str, Path]) -> None:
+        from fusion_bench.utils.cache_utils import set_default_cache_dir
+
+        set_default_cache_dir(path)
+
+    @property
+    def print_function_call(self) -> bool:
+        from fusion_bench.utils.instantiate_utils import PRINT_FUNCTION_CALL
+
+        return PRINT_FUNCTION_CALL
+
+    @print_function_call.setter
+    def print_function_call(self, enable: bool) -> None:
+        from fusion_bench.utils.instantiate_utils import set_print_function_call
+
+        set_print_function_call(enable)
```
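Because `RuntimeConstants` is a process-wide singleton whose properties delegate to the utility modules at access time, a setting made through one reference is visible through every other. An illustrative sketch:

```python
from fusion_bench.constants import RuntimeConstants  # re-exported in fusion_bench/__init__.py

rc_a = RuntimeConstants()
rc_b = RuntimeConstants()
assert rc_a is rc_b  # __new__ always returns the same instance

rc_a.cache_dir = "/tmp/fusion_bench_cache"  # delegates to set_default_cache_dir
# Assuming set_default_cache_dir rebinds the module-level default,
# the getter (which re-imports it per call) reflects the change:
print(rc_b.cache_dir)
```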
fusion_bench/dataset/clip_dataset.py
CHANGED

```diff
@@ -5,6 +5,7 @@ This module provides a class to convert a dataset whose object is a list of dict
 from typing import Optional, Tuple
 
 import torch
+from torch.utils.data import Dataset
 from transformers import CLIPProcessor, ProcessorMixin
 
 __all__ = ["CLIPDataset"]
@@ -28,7 +29,7 @@ class CLIPDataset(torch.utils.data.Dataset):
         processor (CLIPProcessor): The CLIP processor used for image preprocessing.
     """
 
-    def __init__(self, dataset, processor: Optional[CLIPProcessor] = None):
+    def __init__(self, dataset: Dataset, processor: Optional[CLIPProcessor] = None):
         self.dataset = dataset
         self.processor = processor
 
```
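An illustrative sketch of the annotated constructor, using a throwaway two-sample dataset; it assumes `CLIPDataset` yields `(pixel tensor, label)` pairs for `(image, label)` items.

```python
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPProcessor

from fusion_bench.dataset.clip_dataset import CLIPDataset


class ToyImageDataset(Dataset):
    """Hypothetical two-sample dataset of (PIL image, label) pairs."""

    def __len__(self):
        return 2

    def __getitem__(self, idx):
        return Image.new("RGB", (224, 224)), idx


processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_ds = CLIPDataset(ToyImageDataset(), processor=processor)
pixel_values, label = clip_ds[0]  # assumed return shape
```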
fusion_bench/dataset/gpt2_glue.py
CHANGED

```diff
@@ -16,7 +16,7 @@ from functools import partial
 from pathlib import Path
 from typing import Literal
 
-from datasets import load_dataset, load_from_disk
+from datasets import Dataset, load_dataset, load_from_disk
 from transformers import PreTrainedTokenizer
 
 
@@ -147,7 +147,7 @@ class TokenizedGLUE:
         return glue_dataset_loaders[name]()
 
     @cache_dataset
-    def load_mrpc_dataset(self):
+    def load_mrpc_dataset(self) -> Dataset:
         """
         Load and tokenize the MRPC dataset.
 
@@ -166,7 +166,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_rte_dataset(self):
+    def load_rte_dataset(self) -> Dataset:
         """
         Load and tokenize the RTE dataset.
 
@@ -186,7 +186,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_wnli_dataset(self):
+    def load_wnli_dataset(self) -> Dataset:
         """
         Load and tokenize the WNLI dataset.
 
@@ -205,7 +205,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_qqp_dataset(self):
+    def load_qqp_dataset(self) -> Dataset:
         """
         Load and tokenize the QQP dataset.
 
@@ -224,7 +224,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_mnli_dataset(self):
+    def load_mnli_dataset(self) -> Dataset:
         """
         Load and tokenize the MNLI dataset.
 
@@ -243,7 +243,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_cola_dataset(self):
+    def load_cola_dataset(self) -> Dataset:
         """
         Load and tokenize the CoLA dataset.
 
@@ -262,7 +262,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_sst2_dataset(self):
+    def load_sst2_dataset(self) -> Dataset:
         """
         Load and tokenize the SST-2 dataset.
 
@@ -281,7 +281,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_qnli_dataset(self):
+    def load_qnli_dataset(self) -> Dataset:
         """
         Load and tokenize the QNLI dataset.
 
```
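An illustrative call pattern for the newly annotated loaders; the `TokenizedGLUE(tokenizer)` constructor signature is assumed from context, not shown in this hunk.

```python
from transformers import GPT2Tokenizer

from fusion_bench.dataset.gpt2_glue import TokenizedGLUE

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
glue = TokenizedGLUE(tokenizer)   # assumed constructor signature
mrpc = glue.load_mrpc_dataset()   # now annotated -> datasets.Dataset
print(mrpc)
```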
fusion_bench/dataset/image_corruption/__init__.py
File without changes
fusion_bench/dataset/image_corruption/make_corruption.py
ADDED

```diff
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+import logging
+
+logger = logging.getLogger(__name__)
+
+import collections
+import warnings
+from io import BytesIO
+
+import cv2  # pip install opencv-python
+import numpy as np
+import skimage as sk
+import torch
+import torchvision.transforms as trn
+from PIL import Image
+from PIL import Image as PILImage
+from scipy.ndimage import zoom as scizoom
+from scipy.ndimage.interpolation import map_coordinates
+from skimage.filters import gaussian  # pip install scikit-image
+from tqdm import tqdm
+
+try:
+    from wand.api import library as wandlibrary
+    from wand.image import Image as WandImage
+except ImportError as e:
+    logger.error(
+        "Failed to import wand."
+        "Install it with `apt-get install libmagickwand-dev` and `pip install Wand`"
+        "For more information, refer to the documentation https://docs.wand-py.org/"
+    )
+    raise e
+
+# /////////////// Distortion Helpers ///////////////
+
+warnings.simplefilter("ignore", UserWarning)
+
+
+# /////////////// Distortions ///////////////
+class MotionImage(WandImage):
+    def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
+        wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
+
+
+def gaussian_noise(x, severity=1):
+    c = [0.04, 0.06, 0.08, 0.09, 0.10][severity - 1]
+
+    x = np.array(x) / 255.0
+    return np.clip(x + np.random.normal(size=x.shape, scale=c), 0, 1) * 255
+
+
+def impulse_noise(x, severity=1):
+    c = [0.01, 0.02, 0.03, 0.05, 0.07][severity - 1]
+
+    x = sk.util.random_noise(np.array(x) / 255.0, mode="s&p", amount=c)
+    return np.clip(x, 0, 1) * 255
+
+
+def motion_blur(x, severity=1):
+    c = [(6, 1), (6, 1.5), (6, 2), (8, 2), (9, 2.5)][severity - 1]
+
+    output = BytesIO()
+    x.save(output, format="PNG")
+    x = MotionImage(blob=output.getvalue())
+
+    x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
+
+    x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED)
+
+    if x.shape != (32, 32):
+        return np.clip(x[..., [2, 1, 0]], 0, 255)  # BGR to RGB
+    else:  # greyscale to RGB
+        return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 255)
+
+
+def spatter(x, severity=1):
+    c = [
+        (0.62, 0.1, 0.7, 0.7, 0.5, 0),
+        (0.65, 0.1, 0.8, 0.7, 0.5, 0),
+        (0.65, 0.3, 1, 0.69, 0.5, 0),
+        (0.65, 0.1, 0.7, 0.69, 0.6, 1),
+        (0.65, 0.1, 0.5, 0.68, 0.6, 1),
+    ][severity - 1]
+    x = np.array(x, dtype=np.float32) / 255.0
+
+    liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])
+
+    liquid_layer = gaussian(liquid_layer, sigma=c[2])
+    liquid_layer[liquid_layer < c[3]] = 0
+    if c[5] == 0:
+        liquid_layer = (liquid_layer * 255).astype(np.uint8)
+        dist = 255 - cv2.Canny(liquid_layer, 50, 150)
+        dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
+        _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
+        dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
+        dist = cv2.equalizeHist(dist)
+        # ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
+        # ker -= np.mean(ker)
+        ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
+        dist = cv2.filter2D(dist, cv2.CV_8U, ker)
+        dist = cv2.blur(dist, (3, 3)).astype(np.float32)
+
+        m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
+        m /= np.max(m, axis=(0, 1))
+        m *= c[4]
+
+        # water is pale turqouise
+        color = np.concatenate(
+            (
+                175 / 255.0 * np.ones_like(m[..., :1]),
+                238 / 255.0 * np.ones_like(m[..., :1]),
+                238 / 255.0 * np.ones_like(m[..., :1]),
+            ),
+            axis=2,
+        )
+
+        color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
+        x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)
+
+        return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * 255
+    else:
+        m = np.where(liquid_layer > c[3], 1, 0)
+        m = gaussian(m.astype(np.float32), sigma=c[4])
+        m[m < 0.8] = 0
+        # m = np.abs(m) ** (1/c[4])
+
+        # mud brown
+        color = np.concatenate(
+            (
+                63 / 255.0 * np.ones_like(x[..., :1]),
+                42 / 255.0 * np.ones_like(x[..., :1]),
+                20 / 255.0 * np.ones_like(x[..., :1]),
+            ),
+            axis=2,
+        )
+
+        color *= m[..., np.newaxis]
+        x *= 1 - m[..., np.newaxis]
+
+        return np.clip(x + color, 0, 1) * 255
+
+
+def contrast(x, severity=1):
+    c = [0.75, 0.5, 0.4, 0.3, 0.15][severity - 1]
+
+    x = np.array(x) / 255.0
+    means = np.mean(x, axis=(0, 1), keepdims=True)
+    return np.clip((x - means) * c + means, 0, 1) * 255
+
+
+def jpeg_compression(x, severity=1):
+    c = [80, 65, 58, 50, 40][severity - 1]
+
+    output = BytesIO()
+    x.save(output, "JPEG", quality=c)
+    x = PILImage.open(output)
+
+    return x
+
+
+def pixelate(x, severity=1):
+    c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]
+
+    x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
+    x = x.resize((32, 32), PILImage.BOX)
+
+    return x
+
+
+# /////////////// End Distortions ///////////////
+
+
+distortion_methods = collections.OrderedDict()
+distortion_methods["Gaussian Noise"] = gaussian_noise
+distortion_methods["Impulse Noise"] = impulse_noise
+distortion_methods["Motion Blur"] = motion_blur
+distortion_methods["Contrast"] = contrast
+distortion_methods["Pixelate"] = pixelate
+distortion_methods["JPEG"] = jpeg_compression
+distortion_methods["Spatter"] = spatter
```
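An illustrative use of the new corruption registry. Severity follows the usual 1-5 CIFAR-C-style scale, and the hard-coded sizes in `motion_blur` and `pixelate` suggest the functions are tuned for 32x32 inputs; most return float arrays in [0, 255] rather than PIL images.

```python
import numpy as np
from PIL import Image

from fusion_bench.dataset.image_corruption.make_corruption import distortion_methods

img = Image.new("RGB", (32, 32), color=(128, 128, 128))  # placeholder image
noisy = distortion_methods["Gaussian Noise"](img, severity=3)
noisy_img = Image.fromarray(np.uint8(noisy))  # convert the float array back to an image
```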
fusion_bench/dataset/image_dataset.py
CHANGED

```diff
@@ -20,7 +20,7 @@ class TransformedImageDataset(Dataset):
         transform (Callable): The transform to be applied to the images.
     """
 
-    def __init__(self, dataset, transform: Callable):
+    def __init__(self, dataset: Dataset, transform: Callable):
         super().__init__()
         self.dataset = dataset
         self.transform = transform
```
fusion_bench/dataset/nyuv2.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 import fnmatch
 import os
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -68,7 +68,7 @@ class NYUv2(Dataset):
         )
         self.noise = torch.rand(self.data_len, 1, 288, 384)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Retrieve an item from the dataset.
 
```
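An illustrative read of the new annotation; the constructor arguments and the target-dict keys below are assumptions (NYUv2 is conventionally segmentation/depth/normal), not shown in this hunk.

```python
from fusion_bench.dataset.nyuv2 import NYUv2

dataset = NYUv2(root="data/nyuv2", train=True)  # hypothetical constructor args
image, targets = dataset[0]                     # Tuple[torch.Tensor, Dict[str, torch.Tensor]]
for task_name, target in targets.items():
    print(task_name, tuple(target.shape))       # assumed keys: segmentation / depth / normal
```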
fusion_bench/method/__init__.py
CHANGED
```diff
@@ -37,11 +37,12 @@ _import_structure = {
     "ties_merging": ["TiesMergingAlgorithm"],
     "dare": ["DareSimpleAverage", "DareTaskArithmetic", "DareTiesMerging"],
     "fisher_merging": [
+        "FisherMergingAlgorithm",
         "FisherMergingForCLIPVisionModel",
         "FisherMergingAlgorithmForGPT2",
     ],
     "regmean": ["RegMeanAlgorithmForCLIP", "RegMeanAlgorithmForGPT2"],
-    "regmean_plusplus": ["RegMeanAlgorithmForCLIPPlusPlus"],
+    "regmean_plusplus": ["RegMeanAlgorithmPlusPlus", "RegMeanAlgorithmForCLIPPlusPlus"],
     "adamerging": [
         "CLIPTaskWiseAdaMergingAlgorithm",
         "CLIPLayerWiseAdaMergingAlgorithm",
@@ -69,6 +70,7 @@ _import_structure = {
         "FlanT5LayerWiseGossipAlgorithm",
     ],
     "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
+    "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -88,7 +90,10 @@ _import_structure = {
         "MixtralForCausalLMMergingAlgorithm",
     ],
     "dawe": ["DataAdaptiveWeightEnsemblingForCLIP"],
-    "we_moe": [
+    "we_moe": [
+        "CLIPWeightEnsemblingMoEAlgorithm",
+        "FlanT5WeightEnsemblingMoEAlgorithm",
+    ],
     "rankone_moe": ["CLIPRankOneMoEAlgorithm", "RankOneMoEAlgorithm"],
     "sparse_we_moe": [
         "SparseWeightEnsemblingMoEAlgorithm",
@@ -99,6 +104,8 @@ _import_structure = {
         "SmileUpscalingAlgorithm",
         "SingularProjectionMergingAlgorithm",
     ],
+    # task vector compression methods
+    "bitdelta": ["BitDeltaAlgorithm"],
     # pruning methods
     "pruning": [
         "MagnitudeDiffPruningAlgorithm",
@@ -126,6 +133,7 @@ if TYPE_CHECKING:
     from .adamerging import *
     from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
     from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
+    from .bitdelta import BitDeltaAlgorithm
     from .classification import (
         ContinualImageClassificationFineTuningForCLIP,
         ImageClassificationFineTuningForCLIP,
@@ -154,7 +162,11 @@ if TYPE_CHECKING:
         LayerWisePruningForMixtral,
         ProgressivePruningForMixtral,
     )
-    from .fisher_merging import
+    from .fisher_merging import (
+        FisherMergingAlgorithm,
+        FisherMergingAlgorithmForGPT2,
+        FisherMergingForCLIPVisionModel,
+    )
     from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
     from .gossip import (
         CLIPLayerWiseGossipAlgorithm,
@@ -196,7 +208,10 @@ if TYPE_CHECKING:
     )
     from .rankone_moe import CLIPRankOneMoEAlgorithm, RankOneMoEAlgorithm
     from .regmean import RegMeanAlgorithmForCLIP, RegMeanAlgorithmForGPT2
-    from .regmean_plusplus import
+    from .regmean_plusplus import (
+        RegMeanAlgorithmForCLIPPlusPlus,
+        RegMeanAlgorithmPlusPlus,
+    )
     from .simple_average import SimpleAverageAlgorithm
     from .slerp import SlerpMergeAlgorithm
     from .smile_upscaling import (
@@ -212,10 +227,14 @@ if TYPE_CHECKING:
         PCPSparseLoForLlama,
         SparseLoForLlama,
     )
+    from .tall_mask import TallMaskTaskArithmeticAlgorithm
     from .task_arithmetic import TaskArithmeticAlgorithm
     from .task_singular_vector import TaskSingularVectorMerging
     from .ties_merging import TiesMergingAlgorithm
-    from .we_moe import
+    from .we_moe import (
+        CLIPWeightEnsemblingMoEAlgorithm,
+        FlanT5WeightEnsemblingMoEAlgorithm,
+    )
     from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
 
 else:
```
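Since `_import_structure` feeds the package's lazy-import machinery (see `fusion_bench/utils/lazy_imports.py` in the file list), the new entries resolve on first attribute access, e.g.:

```python
# Illustrative: these names become importable from fusion_bench.method
# after this release, resolved lazily at runtime.
from fusion_bench.method import (
    BitDeltaAlgorithm,
    FisherMergingAlgorithm,
    TallMaskTaskArithmeticAlgorithm,
)
```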
fusion_bench/method/adamerging/clip_task_wise_adamerging.py
CHANGED

```diff
@@ -1,6 +1,7 @@
 import functools
 import logging
 import os
+from typing import Iterator
 
 import torch
 from omegaconf import DictConfig
@@ -42,7 +43,7 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         super().__init__(algorithm_config)
 
     @functools.cache
-    def get_test_dataset(self, task: str):
+    def get_test_dataset(self, task: str) -> CLIPDataset:
         """
         Load the test dataset for the task.
         This method is cached, so the dataset is loaded only once.
@@ -59,7 +60,7 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         return dataset
 
     @functools.cache
-    def get_shuffled_test_loader_iter(self, task: str):
+    def get_shuffled_test_loader_iter(self, task: str) -> Iterator:
         """
         Get an iterator over the shuffled test DataLoader for the task.
 
@@ -88,11 +89,14 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         classification head for each task.
         """
         clip_model_config = self.modelpool.get_model_config("_pretrained_")
-
-        clip_model_config
-
-
-
+        if isinstance(clip_model_config, str):
+            pretrained_path = clip_model_config
+        else:
+            pretrained_path = (
+                clip_model_config.pretrained_model_name_or_path
+                if hasattr(clip_model_config, "pretrained_model_name_or_path")
+                else clip_model_config.path
+            )
 
         with timeit_context("Loading CLIP processor and pretrained CLIP model."):
             self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
```