fusion-bench 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -3
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +2 -3
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  56. fusion_bench/method/simple_average.py +5 -9
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +5 -5
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/modelscope.py +127 -8
  122. fusion_bench/utils/parameters.py +2 -2
  123. fusion_bench/utils/rich_utils.py +3 -0
  124. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  125. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +24 -25
  126. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +165 -134
  127. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  128. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  129. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  130. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  131. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  132. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  133. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  134. fusion_bench_config/hydra/default.yaml +6 -2
  135. fusion_bench_config/llama_full_finetune.yaml +1 -0
  136. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  137. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  138. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  139. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  140. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  141. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  142. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  143. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  144. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  148. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  149. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  150. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  151. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  152. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  153. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  154. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  155. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  156. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  157. fusion_bench_config/nyuv2_config.yaml +3 -1
  158. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  159. fusion_bench_config/path/default.yaml +28 -0
  160. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  161. fusion_bench_config/method/adamerging.yaml +0 -23
  162. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  163. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  164. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  165. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  166. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  167. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  168. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  169. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/__init__.py CHANGED
@@ -20,6 +20,7 @@ from . import (
     utils,
 )
 from .method import BaseAlgorithm, BaseModelFusionAlgorithm
+from .mixins import auto_register_config
 from .modelpool import BaseModelPool
 from .models import separate_io
 from .taskpool import BaseTaskPool
fusion_bench/_get_started/__init__.py ADDED
@@ -0,0 +1,3 @@
+"""
+Tutorial module for FusionBench
+"""
fusion_bench/_get_started/greeting_program.py ADDED
@@ -0,0 +1,49 @@
+import logging
+from typing import Optional
+
+from omegaconf import DictConfig
+
+from fusion_bench.programs import BaseHydraProgram
+
+log = logging.getLogger(__name__)
+
+
+class GreetingProgram(BaseHydraProgram):
+    """
+    A simple program that greets users with a custom message.
+    """
+
+    _config_mapping = BaseHydraProgram._config_mapping | {
+        "message": "message",
+        "name": "name",
+        "repeat_count": "repeat_count",
+    }
+
+    def __init__(
+        self,
+        message: str = "Hello",
+        name: str = "World",
+        repeat_count: int = 1,
+        **kwargs,
+    ):
+        self.message = message
+        self.name = name
+        self.repeat_count = repeat_count
+        super().__init__(**kwargs)
+
+    def run(self):
+        """Execute the greeting workflow."""
+        log.info("Starting greeting program")
+
+        # Create the greeting
+        greeting = f"{self.message}, {self.name}!"
+
+        # Print the greeting multiple times
+        for i in range(self.repeat_count):
+            if self.repeat_count > 1:
+                print(f"[{i+1}/{self.repeat_count}] {greeting}")
+            else:
+                print(greeting)
+
+        log.info("Greeting program completed")
+        return greeting
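Alongside the bundled fusion_bench_config/_get_started/greeting_program.yaml config (listed above), the program can also be exercised directly in Python. A minimal sketch, assuming BaseHydraProgram needs no additional constructor arguments:

    from fusion_bench._get_started.greeting_program import GreetingProgram

    # Construct with the defaults shown in the diff, overriding a few fields.
    program = GreetingProgram(message="Hi", name="FusionBench", repeat_count=2)
    greeting = program.run()  # prints "[1/2] Hi, FusionBench!" then "[2/2] Hi, FusionBench!"
    assert greeting == "Hi, FusionBench!"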
fusion_bench/compat/method/base_algorithm.py CHANGED
@@ -36,6 +36,20 @@ class ModelFusionAlgorithm(ABC):
             algorithm_config = DictConfig({})
         self.config = algorithm_config
 
+    def on_run_start(self):
+        """
+        Hook method called at the start of the run.
+        Can be overridden by subclasses to perform initialization tasks.
+        """
+        pass
+
+    def on_run_end(self):
+        """
+        Hook method called at the end of the run.
+        Can be overridden by subclasses to perform cleanup tasks.
+        """
+        pass
+
     @abstractmethod
     def run(self, modelpool):
         """
fusion_bench/constants/__init__.py CHANGED
@@ -1,2 +1,7 @@
 # flake8: noqa F401
+import importlib.metadata
+
 from .paths import *
+
+# fusionbench version
+FUSION_BENCH_VERSION = importlib.metadata.version("fusion-bench")
fusion_bench/constants/clip_vision.py CHANGED
@@ -1,4 +1,5 @@
-# Constants for CLIP Vision Model Merging
+"Constants for CLIP Vision Model Merging"
+
 TASK_NAMES_TA8 = [
     "sun397",
     "stanford-cars",
@@ -9,7 +10,23 @@ TASK_NAMES_TA8 = [
     "mnist",
     "dtd",
 ]
-
+"The 8 tasks used in the Task Arithmetic paper."
+TASK_NAMES_TALL8 = TASK_NAMES_TA8
+"The 8 tasks used in the Tall Mask paper"
+TASK_NAMES_TALL10 = TASK_NAMES_TA8 + ["oxford_flowers102", "pcam"]
+TASK_NAMES_TALL12 = TASK_NAMES_TALL10 + [
+    "fer2013",
+    "oxford-iiit-pet",
+]
+TASK_NAMES_TALL14 = TASK_NAMES_TALL12 + [
+    "stl10",
+    "cifar100",
+]
+"The 14 tasks used in the TALL mask paper"
+TASK_NAMES_TALL16 = TASK_NAMES_TALL14 + ["cifar10", "food101"]
+TASK_NAMES_TALL18 = TASK_NAMES_TALL16 + ["fashion_mnist", "emnist_letters"]
+TASK_NAMES_TALL20 = TASK_NAMES_TALL18 + ["kmnist", "rendered-sst2"]
+"The 20 tasks used in the TALL mask paper"
 TASK_NAMES_TA8_CAP = [
     "SUN397",
     "Cars",
@@ -20,3 +37,10 @@ TASK_NAMES_TA8_CAP = [
     "MNIST",
     "DTD",
 ]
+TASK_NAMES_TALL8_CAP = TASK_NAMES_TA8_CAP
+TASK_NAMES_TALL10_CAP = TASK_NAMES_TALL8_CAP + ["Flowers102", "PCAM"]
+TASK_NAMES_TALL12_CAP = TASK_NAMES_TALL10_CAP + ["FER2013", "OxfordIIITPet"]
+TASK_NAMES_TALL14_CAP = TASK_NAMES_TALL12_CAP + ["STL10", "CIFAR100"]
+TASK_NAMES_TALL16_CAP = TASK_NAMES_TALL14_CAP + ["CIFAR10", "Food101"]
+TASK_NAMES_TALL18_CAP = TASK_NAMES_TALL16_CAP + ["FashionMNIST", "EMNIST"]
+TASK_NAMES_TALL20_CAP = TASK_NAMES_TALL18_CAP + ["KMNIST", "RenderedSST2"]
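A quick sketch of how the new constants compose; by construction each TALL list is a superset of the TA8 list:

    from fusion_bench.constants import FUSION_BENCH_VERSION
    from fusion_bench.constants.clip_vision import TASK_NAMES_TA8, TASK_NAMES_TALL14

    print(FUSION_BENCH_VERSION)  # e.g. "0.2.21"
    assert len(TASK_NAMES_TALL14) == 14
    assert set(TASK_NAMES_TA8) <= set(TASK_NAMES_TALL14)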
fusion_bench/constants/paths.py CHANGED
@@ -7,10 +7,14 @@ log = logging.getLogger(__name__)
 __all__ = ["LIBRARY_PATH", "PROJECT_ROOT_PATH", "DEFAULT_CONFIG_PATH"]
 
 LIBRARY_PATH = Path(importlib.import_module("fusion_bench").__path__[0])
+"""Path to the library directory."""
+
 PROJECT_ROOT_PATH = LIBRARY_PATH.parent
+"""Path to the project root directory."""
 
 if (PROJECT_ROOT_PATH / "config").is_dir():
     DEFAULT_CONFIG_PATH = PROJECT_ROOT_PATH / "config"
+    """Path to the default config directory."""
 elif (PROJECT_ROOT_PATH / "fusion_bench_config").is_dir():
     DEFAULT_CONFIG_PATH = PROJECT_ROOT_PATH / "fusion_bench_config"
 else:
fusion_bench/dataset/clip_dataset.py CHANGED
@@ -5,6 +5,7 @@ This module provides a class to convert a dataset whose object is a list of dict
 from typing import Optional, Tuple
 
 import torch
+from torch.utils.data import Dataset
 from transformers import CLIPProcessor, ProcessorMixin
 
 __all__ = ["CLIPDataset"]
@@ -28,7 +29,7 @@ class CLIPDataset(torch.utils.data.Dataset):
         processor (CLIPProcessor): The CLIP processor used for image preprocessing.
     """
 
-    def __init__(self, dataset, processor: Optional[CLIPProcessor] = None):
+    def __init__(self, dataset: Dataset, processor: Optional[CLIPProcessor] = None):
         self.dataset = dataset
         self.processor = processor
fusion_bench/dataset/gpt2_glue.py CHANGED
@@ -16,7 +16,7 @@ from functools import partial
 from pathlib import Path
 from typing import Literal
 
-from datasets import load_dataset, load_from_disk
+from datasets import Dataset, load_dataset, load_from_disk
 from transformers import PreTrainedTokenizer
 
 
@@ -147,7 +147,7 @@ class TokenizedGLUE:
         return glue_dataset_loaders[name]()
 
     @cache_dataset
-    def load_mrpc_dataset(self):
+    def load_mrpc_dataset(self) -> Dataset:
         """
         Load and tokenize the MRPC dataset.
 
@@ -166,7 +166,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_rte_dataset(self):
+    def load_rte_dataset(self) -> Dataset:
         """
         Load and tokenize the RTE dataset.
 
@@ -186,7 +186,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_wnli_dataset(self):
+    def load_wnli_dataset(self) -> Dataset:
         """
         Load and tokenize the WNLI dataset.
 
@@ -205,7 +205,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_qqp_dataset(self):
+    def load_qqp_dataset(self) -> Dataset:
         """
         Load and tokenize the QQP dataset.
 
@@ -224,7 +224,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_mnli_dataset(self):
+    def load_mnli_dataset(self) -> Dataset:
         """
         Load and tokenize the MNLI dataset.
 
@@ -243,7 +243,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_cola_dataset(self):
+    def load_cola_dataset(self) -> Dataset:
         """
         Load and tokenize the CoLA dataset.
 
@@ -262,7 +262,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_sst2_dataset(self):
+    def load_sst2_dataset(self) -> Dataset:
         """
         Load and tokenize the SST-2 dataset.
 
@@ -281,7 +281,7 @@ class TokenizedGLUE:
         return dataset
 
     @cache_dataset
-    def load_qnli_dataset(self):
+    def load_qnli_dataset(self) -> Dataset:
         """
         Load and tokenize the QNLI dataset.
 
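With these annotations every loader advertises a datasets.Dataset return value. A minimal sketch, assuming TokenizedGLUE is constructed with a tokenizer (consistent with the PreTrainedTokenizer import above):

    from transformers import AutoTokenizer
    from fusion_bench.dataset.gpt2_glue import TokenizedGLUE

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    glue = TokenizedGLUE(tokenizer)   # constructor signature assumed
    mrpc = glue.load_mrpc_dataset()   # returns a datasets.Dataset, cached by @cache_dataset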
fusion_bench/dataset/image_corruption/__init__.py ADDED
File without changes
fusion_bench/dataset/image_corruption/make_corruption.py ADDED
@@ -0,0 +1,179 @@
+# -*- coding: utf-8 -*-
+import logging
+
+logger = logging.getLogger(__name__)
+
+import collections
+import warnings
+from io import BytesIO
+
+import cv2  # pip install opencv-python
+import numpy as np
+import skimage as sk
+import torch
+import torchvision.transforms as trn
+from PIL import Image
+from PIL import Image as PILImage
+from scipy.ndimage import zoom as scizoom
+from scipy.ndimage.interpolation import map_coordinates
+from skimage.filters import gaussian  # pip install scikit-image
+from tqdm import tqdm
+
+try:
+    from wand.api import library as wandlibrary
+    from wand.image import Image as WandImage
+except ImportError as e:
+    logger.error(
+        "Failed to import wand."
+        "Install it with `apt-get install libmagickwand-dev` and `pip install Wand`"
+        "For more information, refer to the documentation https://docs.wand-py.org/"
+    )
+    raise e
+
+# /////////////// Distortion Helpers ///////////////
+
+warnings.simplefilter("ignore", UserWarning)
+
+
+# /////////////// Distortions ///////////////
+class MotionImage(WandImage):
+    def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
+        wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)
+
+
+def gaussian_noise(x, severity=1):
+    c = [0.04, 0.06, 0.08, 0.09, 0.10][severity - 1]
+
+    x = np.array(x) / 255.0
+    return np.clip(x + np.random.normal(size=x.shape, scale=c), 0, 1) * 255
+
+
+def impulse_noise(x, severity=1):
+    c = [0.01, 0.02, 0.03, 0.05, 0.07][severity - 1]
+
+    x = sk.util.random_noise(np.array(x) / 255.0, mode="s&p", amount=c)
+    return np.clip(x, 0, 1) * 255
+
+
+def motion_blur(x, severity=1):
+    c = [(6, 1), (6, 1.5), (6, 2), (8, 2), (9, 2.5)][severity - 1]
+
+    output = BytesIO()
+    x.save(output, format="PNG")
+    x = MotionImage(blob=output.getvalue())
+
+    x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))
+
+    x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8), cv2.IMREAD_UNCHANGED)
+
+    if x.shape != (32, 32):
+        return np.clip(x[..., [2, 1, 0]], 0, 255)  # BGR to RGB
+    else:  # greyscale to RGB
+        return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 255)
+
+
+def spatter(x, severity=1):
+    c = [
+        (0.62, 0.1, 0.7, 0.7, 0.5, 0),
+        (0.65, 0.1, 0.8, 0.7, 0.5, 0),
+        (0.65, 0.3, 1, 0.69, 0.5, 0),
+        (0.65, 0.1, 0.7, 0.69, 0.6, 1),
+        (0.65, 0.1, 0.5, 0.68, 0.6, 1),
+    ][severity - 1]
+    x = np.array(x, dtype=np.float32) / 255.0
+
+    liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])
+
+    liquid_layer = gaussian(liquid_layer, sigma=c[2])
+    liquid_layer[liquid_layer < c[3]] = 0
+    if c[5] == 0:
+        liquid_layer = (liquid_layer * 255).astype(np.uint8)
+        dist = 255 - cv2.Canny(liquid_layer, 50, 150)
+        dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
+        _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
+        dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
+        dist = cv2.equalizeHist(dist)
+        # ker = np.array([[-1,-2,-3],[-2,0,0],[-3,0,1]], dtype=np.float32)
+        # ker -= np.mean(ker)
+        ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
+        dist = cv2.filter2D(dist, cv2.CV_8U, ker)
+        dist = cv2.blur(dist, (3, 3)).astype(np.float32)
+
+        m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
+        m /= np.max(m, axis=(0, 1))
+        m *= c[4]
+
+        # water is pale turqouise
+        color = np.concatenate(
+            (
+                175 / 255.0 * np.ones_like(m[..., :1]),
+                238 / 255.0 * np.ones_like(m[..., :1]),
+                238 / 255.0 * np.ones_like(m[..., :1]),
+            ),
+            axis=2,
+        )
+
+        color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
+        x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)
+
+        return cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * 255
+    else:
+        m = np.where(liquid_layer > c[3], 1, 0)
+        m = gaussian(m.astype(np.float32), sigma=c[4])
+        m[m < 0.8] = 0
+        # m = np.abs(m) ** (1/c[4])
+
+        # mud brown
+        color = np.concatenate(
+            (
+                63 / 255.0 * np.ones_like(x[..., :1]),
+                42 / 255.0 * np.ones_like(x[..., :1]),
+                20 / 255.0 * np.ones_like(x[..., :1]),
+            ),
+            axis=2,
+        )
+
+        color *= m[..., np.newaxis]
+        x *= 1 - m[..., np.newaxis]
+
+        return np.clip(x + color, 0, 1) * 255
+
+
+def contrast(x, severity=1):
+    c = [0.75, 0.5, 0.4, 0.3, 0.15][severity - 1]
+
+    x = np.array(x) / 255.0
+    means = np.mean(x, axis=(0, 1), keepdims=True)
+    return np.clip((x - means) * c + means, 0, 1) * 255
+
+
+def jpeg_compression(x, severity=1):
+    c = [80, 65, 58, 50, 40][severity - 1]
+
+    output = BytesIO()
+    x.save(output, "JPEG", quality=c)
+    x = PILImage.open(output)
+
+    return x
+
+
+def pixelate(x, severity=1):
+    c = [0.95, 0.9, 0.85, 0.75, 0.65][severity - 1]
+
+    x = x.resize((int(32 * c), int(32 * c)), PILImage.BOX)
+    x = x.resize((32, 32), PILImage.BOX)
+
+    return x
+
+
+# /////////////// End Distortions ///////////////
+
+
+distortion_methods = collections.OrderedDict()
+distortion_methods["Gaussian Noise"] = gaussian_noise
+distortion_methods["Impulse Noise"] = impulse_noise
+distortion_methods["Motion Blur"] = motion_blur
+distortion_methods["Contrast"] = contrast
+distortion_methods["Pixelate"] = pixelate
+distortion_methods["JPEG"] = jpeg_compression
+distortion_methods["Spatter"] = spatter
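A minimal sketch of applying the registry to a 32x32 PIL image (the file path is a placeholder; note that motion_blur, jpeg_compression, and pixelate expect a PIL image, and several distortions return NumPy arrays rather than PIL images):

    from PIL import Image
    from fusion_bench.dataset.image_corruption.make_corruption import distortion_methods

    img = Image.open("example.png").convert("RGB").resize((32, 32))  # placeholder path
    for name, corrupt in distortion_methods.items():
        out = corrupt(img, severity=3)  # severity 1..5 indexes the parameter lists above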
fusion_bench/dataset/image_dataset.py CHANGED
@@ -20,7 +20,7 @@ class TransformedImageDataset(Dataset):
         transform (Callable): The transform to be applied to the images.
     """
 
-    def __init__(self, dataset, transform: Callable):
+    def __init__(self, dataset: Dataset, transform: Callable):
         super().__init__()
         self.dataset = dataset
         self.transform = transform
fusion_bench/dataset/nyuv2.py CHANGED
@@ -1,6 +1,6 @@
 import fnmatch
 import os
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -68,7 +68,7 @@ class NYUv2(Dataset):
         )
         self.noise = torch.rand(self.data_len, 1, 288, 384)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Retrieve an item from the dataset.
 
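The new return annotation documents the (image, targets) structure of each item. A sketch of indexing, with constructor arguments assumed for illustration:

    from fusion_bench.dataset.nyuv2 import NYUv2

    dataset = NYUv2(root="datasets/nyuv2", train=True)  # arguments assumed
    image, targets = dataset[0]  # Tuple[torch.Tensor, Dict[str, torch.Tensor]]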
fusion_bench/method/__init__.py CHANGED
@@ -37,11 +37,12 @@ _import_structure = {
     "ties_merging": ["TiesMergingAlgorithm"],
     "dare": ["DareSimpleAverage", "DareTaskArithmetic", "DareTiesMerging"],
     "fisher_merging": [
+        "FisherMergingAlgorithm",
         "FisherMergingForCLIPVisionModel",
         "FisherMergingAlgorithmForGPT2",
     ],
     "regmean": ["RegMeanAlgorithmForCLIP", "RegMeanAlgorithmForGPT2"],
-    "regmean_plusplus": ["RegMeanAlgorithmForCLIPPlusPlus"],
+    "regmean_plusplus": ["RegMeanAlgorithmPlusPlus", "RegMeanAlgorithmForCLIPPlusPlus"],
     "adamerging": [
         "CLIPTaskWiseAdaMergingAlgorithm",
         "CLIPLayerWiseAdaMergingAlgorithm",
@@ -69,6 +70,7 @@ _import_structure = {
         "FlanT5LayerWiseGossipAlgorithm",
     ],
     "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
+    "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -99,6 +101,8 @@ _import_structure = {
         "SmileUpscalingAlgorithm",
         "SingularProjectionMergingAlgorithm",
     ],
+    # task vector compression methods
+    "bitdelta": ["BitDeltaAlgorithm"],
     # pruning methods
     "pruning": [
         "MagnitudeDiffPruningAlgorithm",
@@ -126,6 +130,7 @@ if TYPE_CHECKING:
     from .adamerging import *
     from .analysis import TaskVectorCosSimilarity, TaskVectorViolinPlot
     from .base_algorithm import BaseAlgorithm, BaseModelFusionAlgorithm
+    from .bitdelta import BitDeltaAlgorithm
     from .classification import (
         ContinualImageClassificationFineTuningForCLIP,
         ImageClassificationFineTuningForCLIP,
@@ -154,7 +159,11 @@ if TYPE_CHECKING:
         LayerWisePruningForMixtral,
         ProgressivePruningForMixtral,
     )
-    from .fisher_merging import FisherMergingForCLIPVisionModel
+    from .fisher_merging import (
+        FisherMergingAlgorithm,
+        FisherMergingAlgorithmForGPT2,
+        FisherMergingForCLIPVisionModel,
+    )
     from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
     from .gossip import (
         CLIPLayerWiseGossipAlgorithm,
@@ -196,7 +205,10 @@ if TYPE_CHECKING:
     )
     from .rankone_moe import CLIPRankOneMoEAlgorithm, RankOneMoEAlgorithm
     from .regmean import RegMeanAlgorithmForCLIP, RegMeanAlgorithmForGPT2
-    from .regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus
+    from .regmean_plusplus import (
+        RegMeanAlgorithmForCLIPPlusPlus,
+        RegMeanAlgorithmPlusPlus,
+    )
     from .simple_average import SimpleAverageAlgorithm
     from .slerp import SlerpMergeAlgorithm
     from .smile_upscaling import (
@@ -212,6 +224,7 @@ if TYPE_CHECKING:
         PCPSparseLoForLlama,
         SparseLoForLlama,
    )
+    from .tall_mask import TallMaskTaskArithmeticAlgorithm
     from .task_arithmetic import TaskArithmeticAlgorithm
     from .task_singular_vector import TaskSingularVectorMerging
     from .ties_merging import TiesMergingAlgorithm
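After these changes, the newly registered algorithms resolve from the package root (lazily via _import_structure at runtime, eagerly under TYPE_CHECKING):

    from fusion_bench.method import (
        BitDeltaAlgorithm,
        FisherMergingAlgorithm,
        RegMeanAlgorithmPlusPlus,
        TallMaskTaskArithmeticAlgorithm,
    )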
fusion_bench/method/adamerging/clip_layer_wise_adamerging.py CHANGED
@@ -3,7 +3,7 @@ Example Usage:
 
 ```bash
 fusion_bench \
-    method=adamerging \
+    method=adamerging/clip \
     method.name=clip_layer_wise_adamerging \
     method.save_merging_weights=merging_weights.pt \
     modelpool=clip-vit-base-patch32_TA8 \
fusion_bench/method/adamerging/clip_task_wise_adamerging.py CHANGED
@@ -1,6 +1,7 @@
 import functools
 import logging
 import os
+from typing import Iterator
 
 import torch
 from omegaconf import DictConfig
@@ -42,7 +43,7 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         super().__init__(algorithm_config)
 
     @functools.cache
-    def get_test_dataset(self, task: str):
+    def get_test_dataset(self, task: str) -> CLIPDataset:
         """
         Load the test dataset for the task.
         This method is cached, so the dataset is loaded only once.
@@ -59,7 +60,7 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         return dataset
 
     @functools.cache
-    def get_shuffled_test_loader_iter(self, task: str):
+    def get_shuffled_test_loader_iter(self, task: str) -> Iterator:
         """
         Get an iterator over the shuffled test DataLoader for the task.
 
@@ -88,11 +89,14 @@ class CLIPTaskWiseAdaMergingAlgorithm(TaskWiseAdaMergingAlgorithm):
         classification head for each task.
         """
         clip_model_config = self.modelpool.get_model_config("_pretrained_")
-        pretrained_path = (
-            clip_model_config.pretrained_model_name_or_path
-            if hasattr(clip_model_config, "pretrained_model_name_or_path")
-            else clip_model_config.path
-        )
+        if isinstance(clip_model_config, str):
+            pretrained_path = clip_model_config
+        else:
+            pretrained_path = (
+                clip_model_config.pretrained_model_name_or_path
+                if hasattr(clip_model_config, "pretrained_model_name_or_path")
+                else clip_model_config.path
+            )
 
         with timeit_context("Loading CLIP processor and pretrained CLIP model."):
             self._clip_processor = CLIPProcessor.from_pretrained(pretrained_path)
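The new branch lets get_model_config("_pretrained_") return either a bare string or a structured config entry. A condensed sketch of the equivalent fallback logic (the helper name is illustrative, not part of the package):

    from omegaconf import DictConfig

    def resolve_pretrained_path(cfg):
        # A bare string is used as-is; otherwise prefer
        # pretrained_model_name_or_path and fall back to path.
        if isinstance(cfg, str):
            return cfg
        return getattr(cfg, "pretrained_model_name_or_path", None) or cfg.path

    assert resolve_pretrained_path("openai/clip-vit-base-patch32") == "openai/clip-vit-base-patch32"
    assert resolve_pretrained_path(DictConfig({"path": "models/clip"})) == "models/clip"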
fusion_bench/method/adamerging/layer_wise_adamerging.py CHANGED
@@ -31,9 +31,9 @@ log = logging.getLogger(__name__)
 
 
 class LayerWiseAdaMergingAlgorithm(
-    ModelFusionAlgorithm,
     LightningFabricMixin,
     SimpleProfilerMixin,
+    ModelFusionAlgorithm,
 ):
     _program: "FabricModelFusionProgram"
     """The program that this algorithm is running on."""
@@ -55,7 +55,9 @@ class LayerWiseAdaMergingAlgorithm(
         super().__init__(algorithm_config)
 
     @torch.no_grad()
-    def construct_layer_wise_merged_model(self, modelpool: "ModelPool"):
+    def construct_layer_wise_merged_model(
+        self, modelpool: "ModelPool"
+    ) -> LayerWiseMergedModel:
         """
         Constructs a wrapped layer-wise merged model from model pool.
 
@@ -125,7 +127,7 @@ class LayerWiseAdaMergingAlgorithm(
         os.makedirs(os.path.dirname(save_path), exist_ok=True)
         torch.save(merging_weights.detach().cpu(), save_path)
 
-    def run(self, modelpool: ModelPool, **kwargs):
+    def run(self, modelpool: ModelPool, **kwargs) -> nn.Module:
         """
         Run the Layer-Wise AdaMerging Algorithm.
 
@@ -176,7 +178,9 @@ class LayerWiseAdaMergingAlgorithm(
         pass
 
     @abstractmethod
-    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
+    def compute_logits(
+        self, module: LayerWiseMergedModel, images: Tensor, task: str
+    ) -> Tensor:
         """
         Compute the logits for the given images and task.
 
@@ -190,7 +194,9 @@ class LayerWiseAdaMergingAlgorithm(
         """
         pass
 
-    def test_time_adaptation(self, module: "LayerWiseMergedModel[TorchModelType]"):
+    def test_time_adaptation(
+        self, module: "LayerWiseMergedModel[TorchModelType]"
+    ) -> "LayerWiseMergedModel[TorchModelType]":
         """
         Perform test-time adaptation on the merged model.