fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/__init__.py CHANGED
@@ -19,8 +19,28 @@ from . import (
     tasks,
     utils,
 )
+from .constants import RuntimeConstants
 from .method import BaseAlgorithm, BaseModelFusionAlgorithm
+from .mixins import auto_register_config
 from .modelpool import BaseModelPool
-from .models import separate_io
+from .models import (
+    create_default_model_card,
+    load_model_card_template,
+    save_pretrained_with_remote_code,
+    separate_io,
+)
+from .programs import BaseHydraProgram
 from .taskpool import BaseTaskPool
-from .utils import parse_dtype, print_parameters, timeit_context
+from .utils import (
+    cache_with_joblib,
+    get_rankzero_logger,
+    import_object,
+    instantiate,
+    parse_dtype,
+    print_parameters,
+    seed_everything_by_time,
+    set_default_cache_dir,
+    set_print_function_call,
+    set_print_function_call_permeanent,
+    timeit_context,
+)
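To illustrate the widened top-level API (a sketch, not part of the diff), the helpers re-exported above become reachable from the package root:

```python
import fusion_bench

# Sketch only: these names are newly re-exported at the package root in
# 0.2.22, per the import block above.
fusion_bench.set_default_cache_dir("/tmp/fusion_bench_cache")

with fusion_bench.timeit_context("timed block"):
    # import_object resolves a dotted path to an object (assumed semantics).
    linear_cls = fusion_bench.import_object("torch.nn.Linear")
```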
fusion_bench/_get_started/__init__.py ADDED
@@ -0,0 +1,3 @@
+"""
+Tutorial module for FusionBench
+"""
fusion_bench/_get_started/greeting_program.py ADDED
@@ -0,0 +1,49 @@
+import logging
+from typing import Optional
+
+from omegaconf import DictConfig
+
+from fusion_bench.programs import BaseHydraProgram
+
+log = logging.getLogger(__name__)
+
+
+class GreetingProgram(BaseHydraProgram):
+    """
+    A simple program that greets users with a custom message.
+    """
+
+    _config_mapping = BaseHydraProgram._config_mapping | {
+        "message": "message",
+        "name": "name",
+        "repeat_count": "repeat_count",
+    }
+
+    def __init__(
+        self,
+        message: str = "Hello",
+        name: str = "World",
+        repeat_count: int = 1,
+        **kwargs,
+    ):
+        self.message = message
+        self.name = name
+        self.repeat_count = repeat_count
+        super().__init__(**kwargs)
+
+    def run(self):
+        """Execute the greeting workflow."""
+        log.info("Starting greeting program")
+
+        # Create the greeting
+        greeting = f"{self.message}, {self.name}!"
+
+        # Print the greeting multiple times
+        for i in range(self.repeat_count):
+            if self.repeat_count > 1:
+                print(f"[{i+1}/{self.repeat_count}] {greeting}")
+            else:
+                print(greeting)
+
+        log.info("Greeting program completed")
+        return greeting
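As a quick illustration (not part of the diff), the new tutorial program can also be driven directly from Python; the keyword arguments mirror the `_config_mapping` above. This assumes `BaseHydraProgram` requires no additional constructor arguments; in normal use the class is instantiated by Hydra from `fusion_bench_config/_get_started/greeting_program.yaml`:

```python
from fusion_bench._get_started.greeting_program import GreetingProgram

# Hypothetical direct construction for illustration only.
program = GreetingProgram(message="Hi", name="FusionBench", repeat_count=2)
program.run()
# [1/2] Hi, FusionBench!
# [2/2] Hi, FusionBench!
```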
fusion_bench/compat/method/base_algorithm.py CHANGED
@@ -36,6 +36,20 @@ class ModelFusionAlgorithm(ABC):
             algorithm_config = DictConfig({})
         self.config = algorithm_config
 
+    def on_run_start(self):
+        """
+        Hook method called at the start of the run.
+        Can be overridden by subclasses to perform initialization tasks.
+        """
+        pass
+
+    def on_run_end(self):
+        """
+        Hook method called at the end of the run.
+        Can be overridden by subclasses to perform cleanup tasks.
+        """
+        pass
+
     @abstractmethod
     def run(self, modelpool):
         """
fusion_bench/constants/__init__.py CHANGED
@@ -1,2 +1,8 @@
 # flake8: noqa F401
+import importlib.metadata
+
 from .paths import *
+from .runtime import RuntimeConstants
+
+# fusionbench version
+FUSION_BENCH_VERSION = importlib.metadata.version("fusion-bench")
fusion_bench/constants/clip_vision.py CHANGED
@@ -1,4 +1,5 @@
-# Constants for CLIP Vision Model Merging
+"Constants for CLIP Vision Model Merging"
+
 TASK_NAMES_TA8 = [
     "sun397",
     "stanford-cars",
@@ -9,7 +10,23 @@ TASK_NAMES_TA8 = [
     "mnist",
     "dtd",
 ]
-
+"The 8 tasks used in the Task Arithmetic paper."
+TASK_NAMES_TALL8 = TASK_NAMES_TA8
+"The 8 tasks used in the Tall Mask paper"
+TASK_NAMES_TALL10 = TASK_NAMES_TA8 + ["oxford_flowers102", "pcam"]
+TASK_NAMES_TALL12 = TASK_NAMES_TALL10 + [
+    "fer2013",
+    "oxford-iiit-pet",
+]
+TASK_NAMES_TALL14 = TASK_NAMES_TALL12 + [
+    "stl10",
+    "cifar100",
+]
+"The 14 tasks used in the TALL mask paper"
+TASK_NAMES_TALL16 = TASK_NAMES_TALL14 + ["cifar10", "food101"]
+TASK_NAMES_TALL18 = TASK_NAMES_TALL16 + ["fashion_mnist", "emnist_letters"]
+TASK_NAMES_TALL20 = TASK_NAMES_TALL18 + ["kmnist", "rendered-sst2"]
+"The 20 tasks used in the TALL mask paper"
 TASK_NAMES_TA8_CAP = [
     "SUN397",
     "Cars",
@@ -20,3 +37,10 @@ TASK_NAMES_TA8_CAP = [
     "MNIST",
     "DTD",
 ]
+TASK_NAMES_TALL8_CAP = TASK_NAMES_TA8_CAP
+TASK_NAMES_TALL10_CAP = TASK_NAMES_TALL8_CAP + ["Flowers102", "PCAM"]
+TASK_NAMES_TALL12_CAP = TASK_NAMES_TALL10_CAP + ["FER2013", "OxfordIIITPet"]
+TASK_NAMES_TALL14_CAP = TASK_NAMES_TALL12_CAP + ["STL10", "CIFAR100"]
+TASK_NAMES_TALL16_CAP = TASK_NAMES_TALL14_CAP + ["CIFAR10", "Food101"]
+TASK_NAMES_TALL18_CAP = TASK_NAMES_TALL16_CAP + ["FashionMNIST", "EMNIST"]
+TASK_NAMES_TALL20_CAP = TASK_NAMES_TALL18_CAP + ["KMNIST", "RenderedSST2"]
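For reference (not part of the diff), the new task-name lists are built by concatenation, so each larger TALL benchmark is a strict superset of the smaller ones:

```python
from fusion_bench.constants.clip_vision import (
    TASK_NAMES_TALL8,
    TASK_NAMES_TALL14,
    TASK_NAMES_TALL20,
)

# Each list extends the previous one, so subset checks hold:
assert set(TASK_NAMES_TALL8) < set(TASK_NAMES_TALL14) < set(TASK_NAMES_TALL20)
assert len(TASK_NAMES_TALL20) == 20
```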
fusion_bench/constants/paths.py CHANGED
@@ -7,10 +7,14 @@ log = logging.getLogger(__name__)
 __all__ = ["LIBRARY_PATH", "PROJECT_ROOT_PATH", "DEFAULT_CONFIG_PATH"]
 
 LIBRARY_PATH = Path(importlib.import_module("fusion_bench").__path__[0])
+"""Path to the library directory."""
+
 PROJECT_ROOT_PATH = LIBRARY_PATH.parent
+"""Path to the project root directory."""
 
 if (PROJECT_ROOT_PATH / "config").is_dir():
     DEFAULT_CONFIG_PATH = PROJECT_ROOT_PATH / "config"
+    """Path to the default config directory."""
 elif (PROJECT_ROOT_PATH / "fusion_bench_config").is_dir():
     DEFAULT_CONFIG_PATH = PROJECT_ROOT_PATH / "fusion_bench_config"
 else:
fusion_bench/constants/runtime.py ADDED
@@ -0,0 +1,57 @@
+import threading
+from pathlib import Path
+from typing import Optional, Union
+
+
+class RuntimeConstants:
+    """
+    This class holds constants related to the runtime environment of the Fusion Bench framework.
+    It includes default values for cache directories and other runtime configurations.
+
+    Implemented as a thread-safe singleton to ensure consistent runtime configuration
+    across the entire application.
+    """
+
+    _instance: Optional["RuntimeConstants"] = None
+    _lock = threading.Lock()
+
+    def __new__(cls) -> "RuntimeConstants":
+        """Create a new instance using singleton pattern with thread safety."""
+        with cls._lock:
+            # Double-check locking pattern
+            if cls._instance is None:
+                cls._instance = super(RuntimeConstants, cls).__new__(cls)
+                cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        """Initialize the singleton instance only once."""
+        if not self._initialized:
+            # Add your runtime constants here
+            self._initialized = True
+
+    debug = False
+
+    @property
+    def cache_dir(self) -> Path:
+        from fusion_bench.utils.cache_utils import DEFAULT_CACHE_DIR
+
+        return DEFAULT_CACHE_DIR
+
+    @cache_dir.setter
+    def cache_dir(self, path: Union[str, Path]) -> None:
+        from fusion_bench.utils.cache_utils import set_default_cache_dir
+
+        set_default_cache_dir(path)
+
+    @property
+    def print_function_call(self) -> bool:
+        from fusion_bench.utils.instantiate_utils import PRINT_FUNCTION_CALL
+
+        return PRINT_FUNCTION_CALL
+
+    @print_function_call.setter
+    def print_function_call(self, enable: bool) -> None:
+        from fusion_bench.utils.instantiate_utils import set_print_function_call
+
+        set_print_function_call(enable)
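A quick usage sketch (illustrative, not from the diff): because the class is a singleton, every call site sees the same configuration object, and the properties delegate to the `utils` modules at access time:

```python
from fusion_bench import RuntimeConstants

rc1 = RuntimeConstants()
rc2 = RuntimeConstants()
assert rc1 is rc2  # thread-safe singleton: both names bind the same object

# Assigning here is equivalent to calling
# fusion_bench.utils.cache_utils.set_default_cache_dir(...).
rc1.cache_dir = "/tmp/fusion_bench_cache"
print(rc2.cache_dir)  # reflects the update made through rc1
```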
fusion_bench/dataset/clip_dataset.py CHANGED
@@ -5,6 +5,7 @@ This module provides a class to convert a dataset whose object is a list of dict
 from typing import Optional, Tuple
 
 import torch
+from torch.utils.data import Dataset
 from transformers import CLIPProcessor, ProcessorMixin
 
 __all__ = ["CLIPDataset"]
@@ -28,7 +29,7 @@ class CLIPDataset(torch.utils.data.Dataset):
         processor (CLIPProcessor): The CLIP processor used for image preprocessing.
     """
 
-    def __init__(self, dataset, processor: Optional[CLIPProcessor] = None):
+    def __init__(self, dataset: Dataset, processor: Optional[CLIPProcessor] = None):
         self.dataset = dataset
         self.processor = processor
 
fusion_bench/dataset/gpt2_glue.py CHANGED
@@ -16,7 +16,7 @@ from functools import partial
 from pathlib import Path
 from typing import Literal
 
-from datasets import load_dataset, load_from_disk
+from datasets import Dataset, load_dataset, load_from_disk
 from transformers import PreTrainedTokenizer
 
 
@@ -147,7 +147,7 @@ class TokenizedGLUE:
         return glue_dataset_loaders[name]()
 
     @cache_dataset
-    def load_mrpc_dataset(self):
+    def load_mrpc_dataset(self) -> Dataset:
         """
         Load and tokenize the MRPC dataset.
 
@@ -166,7 +166,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_rte_dataset(self):
+    def load_rte_dataset(self) -> Dataset:
         """
         Load and tokenize the RTE dataset.
 
@@ -186,7 +186,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_wnli_dataset(self):
+    def load_wnli_dataset(self) -> Dataset:
         """
         Load and tokenize the WNLI dataset.
 
@@ -205,7 +205,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_qqp_dataset(self):
+    def load_qqp_dataset(self) -> Dataset:
         """
         Load and tokenize the QQP dataset.
 
@@ -224,7 +224,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_mnli_dataset(self):
+    def load_mnli_dataset(self) -> Dataset:
         """
         Load and tokenize the MNLI dataset.
 
@@ -243,7 +243,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_cola_dataset(self):
+    def load_cola_dataset(self) -> Dataset:
         """
         Load and tokenize the CoLA dataset.
 
@@ -262,7 +262,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_sst2_dataset(self):
+    def load_sst2_dataset(self) -> Dataset:
         """
         Load and tokenize the SST-2 dataset.
 
@@ -281,7 +281,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_qnli_dataset(self):
+    def load_qnli_dataset(self) -> Dataset:
         """
         Load and tokenize the QNLI dataset.
 
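As a quick illustration of the annotated loaders (a sketch, not from the diff; the constructor signature is an assumption inferred from the `PreTrainedTokenizer` import in this module):

```python
from transformers import AutoTokenizer

from fusion_bench.dataset.gpt2_glue import TokenizedGLUE

# Assumption: TokenizedGLUE wraps a tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
glue = TokenizedGLUE(tokenizer)
mrpc = glue.load_mrpc_dataset()  # annotated in 0.2.22 to return datasets.Dataset
```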
fusion_bench/dataset/image_corruption/__init__.py ADDED
File without changes (new empty file).
fusion_bench/dataset/image_corruption/make_corruption.py ADDED
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+import logging
+
+logger = logging.getLogger(__name__)
+
+import collections
+import warnings
+from io import BytesIO
+
+import cv2  # pip install opencv-python
+import numpy as np
+import skimage as sk
+import torch
+import torchvision.transforms as trn
+from PIL import Image
+from PIL import Image as PILImage
+from scipy.ndimage import zoom as scizoom
+from scipy.ndimage.interpolation import map_coordinates
+from skimage.filters import gaussian  # pip install scikit-image
+from tqdm import tqdm
+
+try:
+    from wand.api import library as wandlibrary
+    from wand.image import Image as WandImage
+except ImportError as e:
+    logger.error(
+        "Failed to import wand."
+        "Install it with `apt-get install libmagickwand-dev` and `pip install Wand`"
+        "For more information, refer to the documentation https://docs.wand-py.org/"
+    )
+    raise e
+
+# /////////////// Distortion Helpers ///////////////
+
+warnings.simplefilter("ignore", UserWarning)
+
+
+# /////////////// Distortions ///////////////
+class MotionImage(WandImage):
+    def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
+        wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
+
+
+def gaussian_noise(x, severity=1):
+    c = [0.04, 0.06, 0.08, 0.09, 0.10][severity - 1]
+
+    x = np.array(x) / 255.0
+    return np.clip(x + np.random.normal(size=x.shape, scale=c), 0, 1) * 255
+
+
+def impulse_noise(x, severity=1):
+    c = [0.01, 0.02, 0.03, 0.05, 0.07][severity - 1]
+
+    x = sk.util.random_noise(np.array(x) / 255.0, mode="s&p", amount=c)
+    return np.clip(x, 0, 1) * 255
+
+
+def motion_blur(x, severity=1):
+    c = [(6, 1), (6, 1.5), (6, 2), (8, 2), (9, 2.5)][severity - 1]
+
+    output = BytesIO()
+    x.save(output, format="PNG")
+    x = MotionImage(blob=output.getvalue())
+
+    x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
+
+    x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED)
+
+    if x.shape != (32, 32):
+        return np.clip(x[..., [2, 1, 0]], 0, 255)  # BGR to RGB
+    else:  # greyscale to RGB
+        return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 255)
+
+
+def spatter(x, severity=1):
+    c = [
+        (0.62, 0.1, 0.7, 0.7, 0.5, 0),
+        (0.65, 0.1, 0.8, 0.7, 0.5, 0),
+        (0.65, 0.3, 1, 0.69, 0.5, 0),
+        (0.65, 0.1, 0.7, 0.69, 0.6, 1),
+        (0.65, 0.1, 0.5, 0.68, 0.6, 1),
+    ][severity - 1]
+    x = np.array(x, dtype=np.float32) / 255.0
+
+    liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])
+
+    liquid_layer = gaussian(liquid_layer, sigma=c[2])
+    liquid_layer[liquid_layer < c[3]] = 0
+    if c[5] == 0:
+        liquid_layer = (liquid_layer * 255).astype(np.uint8)
+        dist = 255 - cv2.Canny(liquid_layer, 50, 150)
+        dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
+        _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
+        dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
+        dist = cv2.equalizeHist(dist)
+        # ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
+        # ker -= np.mean(ker)
+        ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
+        dist = cv2.filter2D(dist, cv2.CV_8U, ker)
+        dist = cv2.blur(dist, (3, 3)).astype(np.float32)
+
+        m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
+        m /= np.max(m, axis=(0, 1))
+        m *= c[4]
+
+        # water is pale turqouise
+        color = np.concatenate(
+            (
+                175 / 255.0 * np.ones_like(m[..., :1]),
+                238 / 255.0 * np.ones_like(m[..., :1]),
+                238 / 255.0 * np.ones_like(m[..., :1]),
+            ),
+            axis=2,
+        )
+
+        color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
+        x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)
+
+        return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * 255
+    else:
+        m = np.where(liquid_layer > c[3], 1, 0)
+        m = gaussian(m.astype(np.float32), sigma=c[4])
+        m[m < 0.8] = 0
+        # m = np.abs(m) ** (1/c[4])
+
+        # mud brown
+        color = np.concatenate(
+            (
+                63 / 255.0 * np.ones_like(x[..., :1]),
+                42 / 255.0 * np.ones_like(x[..., :1]),
+                20 / 255.0 * np.ones_like(x[..., :1]),
+            ),
+            axis=2,
+        )
+
+        color *= m[..., np.newaxis]
+        x *= 1 - m[..., np.newaxis]
+
+        return np.clip(x + color, 0, 1) * 255
+
+
+def contrast(x, severity=1):
+    c = [0.75, 0.5, 0.4, 0.3, 0.15][severity - 1]
+
+    x = np.array(x) / 255.0
+    means = np.mean(x, axis=(0, 1), keepdims=True)
+    return np.clip((x - means) * c + means, 0, 1) * 255
+
+
+def jpeg_compression(x, severity=1):
+    c = [80, 65, 58, 50, 40][severity - 1]
+
+    output = BytesIO()
+    x.save(output, "JPEG", quality=c)
+    x = PILImage.open(output)
+
+    return x
+
+
+def pixelate(x, severity=1):
+    c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]
+
+    x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
+    x = x.resize((32, 32), PILImage.BOX)
+
+    return x
+
+
+# /////////////// End Distortions ///////////////
+
+
+distortion_methods = collections.OrderedDict()
+distortion_methods["Gaussian Noise"] = gaussian_noise
+distortion_methods["Impulse Noise"] = impulse_noise
+distortion_methods["Motion Blur"] = motion_blur
+distortion_methods["Contrast"] = contrast
+distortion_methods["Pixelate"] = pixelate
+distortion_methods["JPEG"] = jpeg_compression
+distortion_methods["Spatter"] = spatter
fusion_bench/dataset/image_dataset.py CHANGED
@@ -20,7 +20,7 @@ class TransformedImageDataset(Dataset):
         transform (Callable): The transform to be applied to the images.
     """
 
-    def __init__(self, dataset, transform: Callable):
+    def __init__(self, dataset: Dataset, transform: Callable):
         super().__init__()
         self.dataset = dataset
         self.transform = transform
fusion_bench/dataset/nyuv2.py CHANGED
@@ -1,6 +1,6 @@
 import fnmatch
 import os
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -68,7 +68,7 @@ class NYUv2(Dataset):
         )
         self.noise = torch.rand(self.data_len, 1, 288, 384)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Retrieve an item from the dataset.
 
fusion_bench/method/__init__.py CHANGED
@@ -37,11 +37,12 @@ _import_structure = {
     "ties_merging": ["TiesMergingAlgorithm"],
     "dare": ["DareSimpleAverage", "DareTaskArithmetic", "DareTiesMerging"],
     "fisher_merging": [
+        "FisherMergingAlgorithm",
        "FisherMergingForCLIPVisionModel",
        "FisherMergingAlgorithmForGPT2",
     ],
     "regmean": ["RegMeanAlgorithmForCLIP", "RegMeanAlgorithmForGPT2"],
-    "regmean_plusplus": ["RegMeanAlgorithmForCLIPPlusPlus"],
+    "regmean_plusplus": ["RegMeanAlgorithmPlusPlus", "RegMeanAlgorithmForCLIPPlusPlus"],
     "adamerging": [
         "CLIPTaskWiseAdaMergingAlgorithm",
         "CLIPLayerWiseAdaMergingAlgorithm",
@@ -69,6 +70,7 @@ _import_structure = {
         "FlanT5LayerWiseGossipAlgorithm",
     ],
     "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
+    "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -88,7 +90,10 @@ _import_structure = {
         "MixtralForCausalLMMergingAlgorithm",
     ],
     "dawe": ["DataAdaptiveWeightEnsemblingForCLIP"],
-    "we_moe": ["CLIPWeightEnsemblingMoEAlgorithm"],
+    "we_moe": [
+        "CLIPWeightEnsemblingMoEAlgorithm",
+        "FlanT5WeightEnsemblingMoEAlgorithm",
+    ],
     "rankone_moe": ["CLIPRankOneMoEAlgorithm", "RankOneMoEAlgorithm"],
     "sparse_we_moe": [
         "SparseWeightEnsemblingMoEAlgorithm",
@@ -99,6 +104,8 @@ _import_structure = {
         "SmileUpscalingAlgorithm",
         "SingularProjectionMergingAlgorithm",
     ],
+    # task vector compression methods
+    "bitdelta": ["BitDeltaAlgorithm"],
     # pruning methods
     "pruning": [
         "MagnitudeDiffPruningAlgorithm",
@@ -126,6 +133,7 @@ if TYPE_CHECKING:
     from .adamerging import *
     from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
     from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
+    from .bitdelta import BitDeltaAlgorithm
     from .classification import (
         ContinualImageClassificationFineTuningForCLIP,
         ImageClassificationFineTuningForCLIP,
@@ -154,7 +162,11 @@ if TYPE_CHECKING:
         LayerWisePruningForMixtral,
         ProgressivePruningForMixtral,
     )
-    from .fisher_merging import FisherMergingForCLIPVisionModel
+    from .fisher_merging import (
+        FisherMergingAlgorithm,
+        FisherMergingAlgorithmForGPT2,
+        FisherMergingForCLIPVisionModel,
+    )
     from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
     from .gossip import (
         CLIPLayerWiseGossipAlgorithm,
@@ -196,7 +208,10 @@ if TYPE_CHECKING:
     )
     from .rankone_moe import CLIPRankOneMoEAlgorithm, RankOneMoEAlgorithm
     from .regmean import RegMeanAlgorithmForCLIP, RegMeanAlgorithmForGPT2
-    from .regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus
+    from .regmean_plusplus import (
+        RegMeanAlgorithmForCLIPPlusPlus,
+        RegMeanAlgorithmPlusPlus,
+    )
     from .simple_average import SimpleAverageAlgorithm
     from .slerp import SlerpMergeAlgorithm
     from .smile_upscaling import (
@@ -212,10 +227,14 @@ if TYPE_CHECKING:
         PCPSparseLoForLlama,
         SparseLoForLlama,
     )
+    from .tall_mask import TallMaskTaskArithmeticAlgorithm
     from .task_arithmetic import TaskArithmeticAlgorithm
     from .task_singular_vector import TaskSingularVectorMerging
     from .ties_merging import TiesMergingAlgorithm
-    from .we_moe import CLIPWeightEnsemblingMoEAlgorithm
+    from .we_moe import (
+        CLIPWeightEnsemblingMoEAlgorithm,
+        FlanT5WeightEnsemblingMoEAlgorithm,
+    )
     from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
 
 else:
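For context (a sketch, not from the diff): `_import_structure` is the standard lazy-import mapping, so the new entries only pay their import cost on first attribute access, while the `TYPE_CHECKING` block mirrors them for static analyzers. This assumes the usual `_LazyModule` wiring in the `else` branch, which the hunk does not show:

```python
import fusion_bench.method as method

# First attribute access imports fusion_bench.method.bitdelta on demand.
algo_cls = method.BitDeltaAlgorithm
```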
fusion_bench/method/adamerging/clip_layer_wise_adamerging.py CHANGED
@@ -3,7 +3,7 @@ Example Usage:
 
 ```bash
 fusion_bench \
-    method=adamerging \
+    method=adamerging/clip \
     method.name=clip_layer_wise_adamerging \
     method.save_merging_weights=merging_weights.pt \
     modelpool=clip-vit-base-patch32_TA8 \
fusion_bench/method/adamerging/clip_task_wise_adamerging.py CHANGED
@@ -1,6 +1,7 @@
 import functools
 import logging
 import os
+from typing import Iterator
 
 import torch
 from omegaconf import DictConfig
@@ -42,7 +43,7 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         super().__init__(algorithm_config)
 
     @functools.cache
-    def get_test_dataset(self, task: str):
+    def get_test_dataset(self, task: str) -> CLIPDataset:
         """
         Load the test dataset for the task.
         This method is cached, so the dataset is loaded only once.
@@ -59,7 +60,7 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         return dataset
 
     @functools.cache
-    def get_shuffled_test_loader_iter(self, task: str):
+    def get_shuffled_test_loader_iter(self, task: str) -> Iterator:
         """
         Get an iterator over the shuffled test DataLoader for the task.
 
@@ -88,11 +89,14 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         classification head for each task.
         """
         clip_model_config = self.modelpool.get_model_config("_pretrained_")
-        pretrained_path = (
-            clip_model_config.pretrained_model_name_or_path
-            if hasattr(clip_model_config, "pretrained_model_name_or_path")
-            else clip_model_config.path
-        )
+        if isinstance(clip_model_config, str):
+            pretrained_path = clip_model_config
+        else:
+            pretrained_path = (
+                clip_model_config.pretrained_model_name_or_path
+                if hasattr(clip_model_config, "pretrained_model_name_or_path")
+                else clip_model_config.path
+            )
 
         with timeit_context("Loading CLIP processor and pretrained CLIP model."):
             self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
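For context (illustrative, not from the diff), the new `isinstance` branch lets a modelpool entry be either a plain string or a structured config; both forms now resolve to the same pretrained path. A minimal sketch mirroring the logic added above:

```python
from omegaconf import DictConfig


def resolve(clip_model_config):
    """Mirror of the resolution logic added above (sketch)."""
    if isinstance(clip_model_config, str):
        return clip_model_config
    return (
        clip_model_config.pretrained_model_name_or_path
        if hasattr(clip_model_config, "pretrained_model_name_or_path")
        else clip_model_config.path
    )


# A plain-string entry (newly supported) and a structured entry (as before)
# resolve identically:
assert resolve("openai/clip-vit-base-patch32") == resolve(
    DictConfig({"pretrained_model_name_or_path": "openai/clip-vit-base-patch32"})
)
```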