fusion-bench 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -3
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +2 -3
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +5 -9
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +5 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +24 -25
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +165 -134
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/_get_started/greeting_program.py
ADDED
@@ -0,0 +1,49 @@
+import logging
+from typing import Optional
+
+from omegaconf import DictConfig
+
+from fusion_bench.programs import BaseHydraProgram
+
+log = logging.getLogger(__name__)
+
+
+class GreetingProgram(BaseHydraProgram):
+    """
+    A simple program that greets users with a custom message.
+    """
+
+    _config_mapping = BaseHydraProgram._config_mapping | {
+        "message": "message",
+        "name": "name",
+        "repeat_count": "repeat_count",
+    }
+
+    def __init__(
+        self,
+        message: str = "Hello",
+        name: str = "World",
+        repeat_count: int = 1,
+        **kwargs,
+    ):
+        self.message = message
+        self.name = name
+        self.repeat_count = repeat_count
+        super().__init__(**kwargs)
+
+    def run(self):
+        """Execute the greeting workflow."""
+        log.info("Starting greeting program")
+
+        # Create the greeting
+        greeting = f"{self.message}, {self.name}!"
+
+        # Print the greeting multiple times
+        for i in range(self.repeat_count):
+            if self.repeat_count > 1:
+                print(f"[{i+1}/{self.repeat_count}] {greeting}")
+            else:
+                print(greeting)
+
+        log.info("Greeting program completed")
+        return greeting
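The matching Hydra entry point is fusion_bench_config/_get_started/greeting_program.yaml (listed above). A minimal sketch of driving the class directly, without Hydra; the printed output follows from the run logic in the diff:

from fusion_bench._get_started.greeting_program import GreetingProgram

program = GreetingProgram(message="Hi", name="FusionBench", repeat_count=2)
program.run()
# [1/2] Hi, FusionBench!
# [2/2] Hi, FusionBench!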
fusion_bench/compat/method/base_algorithm.py
CHANGED
@@ -36,6 +36,20 @@ class ModelFusionAlgorithm(ABC):
             algorithm_config = DictConfig({})
         self.config = algorithm_config
 
+    def on_run_start(self):
+        """
+        Hook method called at the start of the run.
+        Can be overridden by subclasses to perform initialization tasks.
+        """
+        pass
+
+    def on_run_end(self):
+        """
+        Hook method called at the end of the run.
+        Can be overridden by subclasses to perform cleanup tasks.
+        """
+        pass
+
     @abstractmethod
     def run(self, modelpool):
         """
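The new on_run_start/on_run_end hooks give compat-style algorithms explicit setup and teardown points around run. A minimal sketch of a subclass using them; the timing logic is illustrative, and invoking the hooks from run is this sketch's choice, not a package guarantee:

import time

from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm


class TimedMerge(ModelFusionAlgorithm):
    def on_run_start(self):
        # Illustrative: start a wall-clock timer before merging begins.
        self._start = time.time()

    def on_run_end(self):
        # Illustrative: report elapsed time once merging finishes.
        print(f"merge took {time.time() - self._start:.2f}s")

    def run(self, modelpool):
        self.on_run_start()
        merged = ...  # merge the models from `modelpool` here
        self.on_run_end()
        return merged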
fusion_bench/constants/clip_vision.py
CHANGED
@@ -1,4 +1,5 @@
-
+"Constants for CLIP Vision Model Merging"
+
 TASK_NAMES_TA8 = [
     "sun397",
     "stanford-cars",
@@ -9,7 +10,23 @@ TASK_NAMES_TA8 = [
     "mnist",
     "dtd",
 ]
-
+"The 8 tasks used in the Task Arithmetic paper."
+TASK_NAMES_TALL8 = TASK_NAMES_TA8
+"The 8 tasks used in the Tall Mask paper"
+TASK_NAMES_TALL10 = TASK_NAMES_TA8 + ["oxford_flowers102", "pcam"]
+TASK_NAMES_TALL12 = TASK_NAMES_TALL10 + [
+    "fer2013",
+    "oxford-iiit-pet",
+]
+TASK_NAMES_TALL14 = TASK_NAMES_TALL12 + [
+    "stl10",
+    "cifar100",
+]
+"The 14 tasks used in the TALL mask paper"
+TASK_NAMES_TALL16 = TASK_NAMES_TALL14 + ["cifar10", "food101"]
+TASK_NAMES_TALL18 = TASK_NAMES_TALL16 + ["fashion_mnist", "emnist_letters"]
+TASK_NAMES_TALL20 = TASK_NAMES_TALL18 + ["kmnist", "rendered-sst2"]
+"The 20 tasks used in the TALL mask paper"
 TASK_NAMES_TA8_CAP = [
     "SUN397",
     "Cars",
@@ -20,3 +37,10 @@ TASK_NAMES_TA8_CAP = [
     "MNIST",
     "DTD",
 ]
+TASK_NAMES_TALL8_CAP = TASK_NAMES_TA8_CAP
+TASK_NAMES_TALL10_CAP = TASK_NAMES_TALL8_CAP + ["Flowers102", "PCAM"]
+TASK_NAMES_TALL12_CAP = TASK_NAMES_TALL10_CAP + ["FER2013", "OxfordIIITPet"]
+TASK_NAMES_TALL14_CAP = TASK_NAMES_TALL12_CAP + ["STL10", "CIFAR100"]
+TASK_NAMES_TALL16_CAP = TASK_NAMES_TALL14_CAP + ["CIFAR10", "Food101"]
+TASK_NAMES_TALL18_CAP = TASK_NAMES_TALL16_CAP + ["FashionMNIST", "EMNIST"]
+TASK_NAMES_TALL20_CAP = TASK_NAMES_TALL18_CAP + ["KMNIST", "RenderedSST2"]
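The TALL task lists grow in steps of two on top of the Task Arithmetic 8, so the numeric suffix always matches the list length. A quick sanity-check sketch:

from fusion_bench.constants.clip_vision import (
    TASK_NAMES_TA8,
    TASK_NAMES_TALL14,
    TASK_NAMES_TALL20,
)

assert len(TASK_NAMES_TA8) == 8
assert len(TASK_NAMES_TALL14) == 14
assert len(TASK_NAMES_TALL20) == 20
assert set(TASK_NAMES_TA8) <= set(TASK_NAMES_TALL20)  # superset chain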
fusion_bench/constants/paths.py
CHANGED
@@ -7,10 +7,14 @@ log = logging.getLogger(__name__)
 __all__ = ["LIBRARY_PATH", "PROJECT_ROOT_PATH", "DEFAULT_CONFIG_PATH"]
 
 LIBRARY_PATH = Path(importlib.import_module("fusion_bench").__path__[0])
+"""Path to the library directory."""
+
 PROJECT_ROOT_PATH = LIBRARY_PATH.parent
+"""Path to the project root directory."""
 
 if (PROJECT_ROOT_PATH / "config").is_dir():
     DEFAULT_CONFIG_PATH = PROJECT_ROOT_PATH / "config"
+    """Path to the default config directory."""
 elif (PROJECT_ROOT_PATH / "fusion_bench_config").is_dir():
     DEFAULT_CONFIG_PATH = PROJECT_ROOT_PATH / "fusion_bench_config"
 else:
fusion_bench/dataset/clip_dataset.py
CHANGED
@@ -5,6 +5,7 @@ This module provides a class to convert a dataset whose object is a list of dict
 from typing import Optional, Tuple
 
 import torch
+from torch.utils.data import Dataset
 from transformers import CLIPProcessor, ProcessorMixin
 
 __all__ = ["CLIPDataset"]
@@ -28,7 +29,7 @@ class CLIPDataset(torch.utils.data.Dataset):
         processor (CLIPProcessor): The CLIP processor used for image preprocessing.
     """
 
-    def __init__(self, dataset, processor: Optional[CLIPProcessor] = None):
+    def __init__(self, dataset: Dataset, processor: Optional[CLIPProcessor] = None):
         self.dataset = dataset
         self.processor = processor
 
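A usage sketch for the newly annotated constructor; the wrapped dataset and the (image, label) item layout are assumptions for illustration, not taken from this diff:

from datasets import load_dataset
from transformers import CLIPProcessor

from fusion_bench.dataset.clip_dataset import CLIPDataset

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
raw = load_dataset("cifar10", split="test")  # illustrative image-label dataset
clip_ds = CLIPDataset(raw, processor=processor)
image_tensor, label = clip_ds[0]  # preprocessed pixel values plus label (assumed layout)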
fusion_bench/dataset/gpt2_glue.py
CHANGED
@@ -16,7 +16,7 @@ from functools import partial
 from pathlib import Path
 from typing import Literal
 
-from datasets import load_dataset, load_from_disk
+from datasets import Dataset, load_dataset, load_from_disk
 from transformers import PreTrainedTokenizer
 
 
@@ -147,7 +147,7 @@ class TokenizedGLUE:
         return glue_dataset_loaders[name]()
 
     @cache_dataset
-    def load_mrpc_dataset(self):
+    def load_mrpc_dataset(self) -> Dataset:
         """
         Load and tokenize the MRPC dataset.
 
@@ -166,7 +166,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_rte_dataset(self):
+    def load_rte_dataset(self) -> Dataset:
         """
         Load and tokenize the RTE dataset.
 
@@ -186,7 +186,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_wnli_dataset(self):
+    def load_wnli_dataset(self) -> Dataset:
         """
         Load and tokenize the WNLI dataset.
 
@@ -205,7 +205,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_qqp_dataset(self):
+    def load_qqp_dataset(self) -> Dataset:
         """
         Load and tokenize the QQP dataset.
 
@@ -224,7 +224,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_mnli_dataset(self):
+    def load_mnli_dataset(self) -> Dataset:
         """
         Load and tokenize the MNLI dataset.
 
@@ -243,7 +243,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_cola_dataset(self):
+    def load_cola_dataset(self) -> Dataset:
         """
         Load and tokenize the CoLA dataset.
 
@@ -262,7 +262,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_sst2_dataset(self):
+    def load_sst2_dataset(self) -> Dataset:
         """
         Load and tokenize the SST-2 dataset.
 
@@ -281,7 +281,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_qnli_dataset(self):
+    def load_qnli_dataset(self) -> Dataset:
         """
         Load and tokenize the QNLI dataset.
 
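With the loaders annotated to return a datasets.Dataset, call sites get precise types from the dispatcher shown in the first hunk (glue_dataset_loaders[name]()). A usage sketch; the constructor signature is assumed from the PreTrainedTokenizer import:

from transformers import GPT2Tokenizer

from fusion_bench.dataset.gpt2_glue import TokenizedGLUE

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
glue = TokenizedGLUE(tokenizer)  # signature assumed
mrpc = glue.load_mrpc_dataset()  # now typed as datasets.Dataset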
fusion_bench/dataset/image_corruption/__init__.py
ADDED
File without changes
fusion_bench/dataset/image_corruption/make_corruption.py
ADDED
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+import logging
+
+logger = logging.getLogger(__name__)
+
+import collections
+import warnings
+from io import BytesIO
+
+import cv2  # pip install opencv-python
+import numpy as np
+import skimage as sk
+import torch
+import torchvision.transforms as trn
+from PIL import Image
+from PIL import Image as PILImage
+from scipy.ndimage import zoom as scizoom
+from scipy.ndimage.interpolation import map_coordinates
+from skimage.filters import gaussian  # pip install scikit-image
+from tqdm import tqdm
+
+try:
+    from wand.api import library as wandlibrary
+    from wand.image import Image as WandImage
+except ImportError as e:
+    logger.error(
+        "Failed to import wand."
+        "Install it with `apt-get install libmagickwand-dev` and `pip install Wand`"
+        "For more information, refer to the documentation https://docs.wand-py.org/"
+    )
+    raise e
+
+# /////////////// Distortion Helpers ///////////////
+
+warnings.simplefilter("ignore", UserWarning)
+
+
+# /////////////// Distortions ///////////////
+class MotionImage(WandImage):
+    def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
+        wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
+
+
+def gaussian_noise(x, severity=1):
+    c = [0.04, 0.06, 0.08, 0.09, 0.10][severity - 1]
+
+    x = np.array(x) / 255.0
+    return np.clip(x + np.random.normal(size=x.shape, scale=c), 0, 1) * 255
+
+
+def impulse_noise(x, severity=1):
+    c = [0.01, 0.02, 0.03, 0.05, 0.07][severity - 1]
+
+    x = sk.util.random_noise(np.array(x) / 255.0, mode="s&p", amount=c)
+    return np.clip(x, 0, 1) * 255
+
+
+def motion_blur(x, severity=1):
+    c = [(6, 1), (6, 1.5), (6, 2), (8, 2), (9, 2.5)][severity - 1]
+
+    output = BytesIO()
+    x.save(output, format="PNG")
+    x = MotionImage(blob=output.getvalue())
+
+    x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
+
+    x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED)
+
+    if x.shape != (32, 32):
+        return np.clip(x[..., [2, 1, 0]], 0, 255)  # BGR to RGB
+    else:  # greyscale to RGB
+        return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 255)
+
+
+def spatter(x, severity=1):
+    c = [
+        (0.62, 0.1, 0.7, 0.7, 0.5, 0),
+        (0.65, 0.1, 0.8, 0.7, 0.5, 0),
+        (0.65, 0.3, 1, 0.69, 0.5, 0),
+        (0.65, 0.1, 0.7, 0.69, 0.6, 1),
+        (0.65, 0.1, 0.5, 0.68, 0.6, 1),
+    ][severity - 1]
+    x = np.array(x, dtype=np.float32) / 255.0
+
+    liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])
+
+    liquid_layer = gaussian(liquid_layer, sigma=c[2])
+    liquid_layer[liquid_layer < c[3]] = 0
+    if c[5] == 0:
+        liquid_layer = (liquid_layer * 255).astype(np.uint8)
+        dist = 255 - cv2.Canny(liquid_layer, 50, 150)
+        dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
+        _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
+        dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
+        dist = cv2.equalizeHist(dist)
+        # ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
+        # ker -= np.mean(ker)
+        ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
+        dist = cv2.filter2D(dist, cv2.CV_8U, ker)
+        dist = cv2.blur(dist, (3, 3)).astype(np.float32)
+
+        m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
+        m /= np.max(m, axis=(0, 1))
+        m *= c[4]
+
+        # water is pale turqouise
+        color = np.concatenate(
+            (
+                175 / 255.0 * np.ones_like(m[..., :1]),
+                238 / 255.0 * np.ones_like(m[..., :1]),
+                238 / 255.0 * np.ones_like(m[..., :1]),
+            ),
+            axis=2,
+        )
+
+        color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
+        x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)
+
+        return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * 255
+    else:
+        m = np.where(liquid_layer > c[3], 1, 0)
+        m = gaussian(m.astype(np.float32), sigma=c[4])
+        m[m < 0.8] = 0
+        # m = np.abs(m) ** (1/c[4])
+
+        # mud brown
+        color = np.concatenate(
+            (
+                63 / 255.0 * np.ones_like(x[..., :1]),
+                42 / 255.0 * np.ones_like(x[..., :1]),
+                20 / 255.0 * np.ones_like(x[..., :1]),
+            ),
+            axis=2,
+        )
+
+        color *= m[..., np.newaxis]
+        x *= 1 - m[..., np.newaxis]
+
+        return np.clip(x + color, 0, 1) * 255
+
+
+def contrast(x, severity=1):
+    c = [0.75, 0.5, 0.4, 0.3, 0.15][severity - 1]
+
+    x = np.array(x) / 255.0
+    means = np.mean(x, axis=(0, 1), keepdims=True)
+    return np.clip((x - means) * c + means, 0, 1) * 255
+
+
+def jpeg_compression(x, severity=1):
+    c = [80, 65, 58, 50, 40][severity - 1]
+
+    output = BytesIO()
+    x.save(output, "JPEG", quality=c)
+    x = PILImage.open(output)
+
+    return x
+
+
+def pixelate(x, severity=1):
+    c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]
+
+    x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
+    x = x.resize((32, 32), PILImage.BOX)
+
+    return x
+
+
+# /////////////// End Distortions ///////////////
+
+
+distortion_methods = collections.OrderedDict()
+distortion_methods["Gaussian Noise"] = gaussian_noise
+distortion_methods["Impulse Noise"] = impulse_noise
+distortion_methods["Motion Blur"] = motion_blur
+distortion_methods["Contrast"] = contrast
+distortion_methods["Pixelate"] = pixelate
+distortion_methods["JPEG"] = jpeg_compression
+distortion_methods["Spatter"] = spatter
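A short sketch of applying one of the registered corruptions to a PIL image; the sample image path is illustrative, and note that motion_blur and pixelate hard-code 32x32 inputs:

import numpy as np
from PIL import Image

from fusion_bench.dataset.image_corruption.make_corruption import distortion_methods

img = Image.open("sample.png").convert("RGB").resize((32, 32))  # illustrative input
corrupted = distortion_methods["Gaussian Noise"](img, severity=3)
# Array-returning distortions come back as float arrays in [0, 255].
corrupted_img = Image.fromarray(np.uint8(corrupted))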
fusion_bench/dataset/image_dataset.py
CHANGED
@@ -20,7 +20,7 @@ class TransformedImageDataset(Dataset):
         transform (Callable): The transform to be applied to the images.
     """
 
-    def __init__(self, dataset, transform: Callable):
+    def __init__(self, dataset: Dataset, transform: Callable):
         super().__init__()
         self.dataset = dataset
         self.transform = transform
fusion_bench/dataset/nyuv2.py
CHANGED
@@ -1,6 +1,6 @@
 import fnmatch
 import os
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -68,7 +68,7 @@ class NYUv2(Dataset):
         )
         self.noise = torch.rand(self.data_len, 1, 288, 384)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Retrieve an item from the dataset.
 
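The new annotation documents the multi-task sample layout: an image tensor paired with a dict of per-task targets. A consumption sketch; the constructor arguments and the exact task keys (e.g. segmentation/depth/normal) are assumptions based on the NYUv2 benchmark, not taken from this diff:

from fusion_bench.dataset.nyuv2 import NYUv2

dataset = NYUv2(root="data/nyuv2", train=True)  # constructor args assumed
image, targets = dataset[0]                     # matches the annotated return type
print(image.shape)                              # e.g. torch.Size([3, 288, 384])
for task_name, target in targets.items():
    print(task_name, tuple(target.shape))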
fusion_bench/method/__init__.py
CHANGED
@@ -37,11 +37,12 @@ _import_structure = {
     "ties_merging": ["TiesMergingAlgorithm"],
     "dare": ["DareSimpleAverage", "DareTaskArithmetic", "DareTiesMerging"],
     "fisher_merging": [
+        "FisherMergingAlgorithm",
         "FisherMergingForCLIPVisionModel",
         "FisherMergingAlgorithmForGPT2",
     ],
     "regmean": ["RegMeanAlgorithmForCLIP", "RegMeanAlgorithmForGPT2"],
-    "regmean_plusplus": ["RegMeanAlgorithmForCLIPPlusPlus"],
+    "regmean_plusplus": ["RegMeanAlgorithmPlusPlus", "RegMeanAlgorithmForCLIPPlusPlus"],
     "adamerging": [
         "CLIPTaskWiseAdaMergingAlgorithm",
         "CLIPLayerWiseAdaMergingAlgorithm",
@@ -69,6 +70,7 @@ _import_structure = {
         "FlanT5LayerWiseGossipAlgorithm",
     ],
     "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
+    "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -99,6 +101,8 @@ _import_structure = {
         "SmileUpscalingAlgorithm",
         "SingularProjectionMergingAlgorithm",
     ],
+    # task vector compression methods
+    "bitdelta": ["BitDeltaAlgorithm"],
     # pruning methods
     "pruning": [
         "MagnitudeDiffPruningAlgorithm",
@@ -126,6 +130,7 @@ if TYPE_CHECKING:
     from .adamerging import *
     from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
     from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
+    from .bitdelta import BitDeltaAlgorithm
     from .classification import (
         ContinualImageClassificationFineTuningForCLIP,
         ImageClassificationFineTuningForCLIP,
@@ -154,7 +159,11 @@ if TYPE_CHECKING:
         LayerWisePruningForMixtral,
         ProgressivePruningForMixtral,
     )
-    from .fisher_merging import
+    from .fisher_merging import (
+        FisherMergingAlgorithm,
+        FisherMergingAlgorithmForGPT2,
+        FisherMergingForCLIPVisionModel,
+    )
     from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
     from .gossip import (
         CLIPLayerWiseGossipAlgorithm,
@@ -196,7 +205,10 @@ if TYPE_CHECKING:
     )
     from .rankone_moe import CLIPRankOneMoEAlgorithm, RankOneMoEAlgorithm
     from .regmean import RegMeanAlgorithmForCLIP, RegMeanAlgorithmForGPT2
-    from .regmean_plusplus import
+    from .regmean_plusplus import (
+        RegMeanAlgorithmForCLIPPlusPlus,
+        RegMeanAlgorithmPlusPlus,
+    )
     from .simple_average import SimpleAverageAlgorithm
     from .slerp import SlerpMergeAlgorithm
     from .smile_upscaling import (
@@ -212,6 +224,7 @@ if TYPE_CHECKING:
         PCPSparseLoForLlama,
         SparseLoForLlama,
     )
+    from .tall_mask import TallMaskTaskArithmeticAlgorithm
     from .task_arithmetic import TaskArithmeticAlgorithm
     from .task_singular_vector import TaskSingularVectorMerging
     from .ties_merging import TiesMergingAlgorithm
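Both new entries are wired through the lazy _import_structure table, so they resolve on attribute access without importing the heavy submodules eagerly:

# Lazy imports through fusion_bench.method's _import_structure table.
from fusion_bench.method import (
    BitDeltaAlgorithm,
    TallMaskTaskArithmeticAlgorithm,
)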
fusion_bench/method/adamerging/clip_task_wise_adamerging.py
CHANGED
@@ -1,6 +1,7 @@
 import functools
 import logging
 import os
+from typing import Iterator
 
 import torch
 from omegaconf import DictConfig
@@ -42,7 +43,7 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         super().__init__(algorithm_config)
 
     @functools.cache
-    def get_test_dataset(self, task: str):
+    def get_test_dataset(self, task: str) -> CLIPDataset:
         """
         Load the test dataset for the task.
         This method is cached, so the dataset is loaded only once.
@@ -59,7 +60,7 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         return dataset
 
     @functools.cache
-    def get_shuffled_test_loader_iter(self, task: str):
+    def get_shuffled_test_loader_iter(self, task: str) -> Iterator:
         """
         Get an iterator over the shuffled test DataLoader for the task.
 
@@ -88,11 +89,14 @@
         classification head for each task.
         """
         clip_model_config = self.modelpool.get_model_config("_pretrained_")
-
-        clip_model_config
-
-
-
+        if isinstance(clip_model_config, str):
+            pretrained_path = clip_model_config
+        else:
+            pretrained_path = (
+                clip_model_config.pretrained_model_name_or_path
+                if hasattr(clip_model_config, "pretrained_model_name_or_path")
+                else clip_model_config.path
+            )
 
         with timeit_context("Loading CLIP processor and pretrained CLIP model."):
             self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
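The new branch accepts the pretrained model entry either as a plain string or as a config node; a sketch of both shapes and a helper mirroring the resolution logic above (the helper itself is illustrative, not part of the package):

from omegaconf import DictConfig

# Form 1: a plain string entry now works directly as the pretrained path.
cfg_a = "openai/clip-vit-base-patch32"

# Form 2: a config node; `pretrained_model_name_or_path` is preferred,
# with `path` kept as a fallback key for older configs.
cfg_b = DictConfig({"pretrained_model_name_or_path": "openai/clip-vit-base-patch32"})


def resolve(cfg):
    # Mirrors the branch added in the diff above.
    if isinstance(cfg, str):
        return cfg
    return getattr(cfg, "pretrained_model_name_or_path", None) or cfg.path


assert resolve(cfg_a) == resolve(cfg_b)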
fusion_bench/method/adamerging/layer_wise_adamerging.py
CHANGED
@@ -31,9 +31,9 @@ log = logging.getLogger(__name__)
 
 
 class LayerWiseAdaMergingAlgorithm(
-    ModelFusionAlgorithm,
     LightningFabricMixin,
     SimpleProfilerMixin,
+    ModelFusionAlgorithm,
 ):
     _program: "FabricModelFusionProgram"
     """The program that this algorithm is running on."""
@@ -55,7 +55,9 @@ class LayerWiseAdaMergingAlgorithm(
         super().__init__(algorithm_config)
 
     @torch.no_grad()
-    def construct_layer_wise_merged_model(
+    def construct_layer_wise_merged_model(
+        self, modelpool: "ModelPool"
+    ) -> LayerWiseMergedModel:
         """
         Constructs a wrapped layer-wise merged model from model pool.
 
@@ -125,7 +127,7 @@
         os.makedirs(os.path.dirname(save_path), exist_ok=True)
         torch.save(merging_weights.detach().cpu(), save_path)
 
-    def run(self, modelpool: ModelPool, **kwargs):
+    def run(self, modelpool: ModelPool, **kwargs) -> nn.Module:
         """
         Run the Layer-Wise AdaMerging Algorithm.
 
@@ -176,7 +178,9 @@
         pass
 
     @abstractmethod
-    def compute_logits(
+    def compute_logits(
+        self, module: LayerWiseMergedModel, images: Tensor, task: str
+    ) -> Tensor:
         """
         Compute the logits for the given images and task.
 
@@ -190,7 +194,9 @@
         """
         pass
 
-    def test_time_adaptation(
+    def test_time_adaptation(
+        self, module: "LayerWiseMergedModel[TorchModelType]"
+    ) -> "LayerWiseMergedModel[TorchModelType]":
         """
         Perform test-time adaptation on the merged model.
 
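The first hunk moves ModelFusionAlgorithm after the two mixins, which changes Python's method resolution order so mixin-provided attributes win lookups. A toy sketch of the effect; the classes here are illustrative, not the package's:

class Base:
    def setup(self):
        return "base"


class Mixin:
    def setup(self):
        return "mixin"


# Bases listed first take precedence under C3 linearization.
class Algo(Mixin, Base):
    pass


assert Algo().setup() == "mixin"
print([c.__name__ for c in Algo.__mro__])  # ['Algo', 'Mixin', 'Base', 'object']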