fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
@@ -16,14 +16,18 @@ from transformers import CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPEncoderLayer
 from typing_extensions import override
 
-from fusion_bench.method.base_algorithm import BaseAlgorithm
+from fusion_bench import (
+    BaseAlgorithm,
+    auto_register_config,
+    print_parameters,
+    timeit_context,
+)
+from fusion_bench.dataset import CLIPDataset
 from fusion_bench.method.task_arithmetic import task_arithmetic_merge
 from fusion_bench.mixins.clip_classification import CLIPClassificationMixin
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import CLIPVisionModelPool
-from fusion_bench.utils import timeit_context
 from fusion_bench.utils.data import InfiniteDataLoader
-from fusion_bench.utils.parameters import print_parameters
 
 from .module import ParetoWeightEnsemblingModule
 from .utils import generate_simplex_grid
@@ -31,27 +35,13 @@ from .utils import generate_simplex_grid
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class PWEMoEAlgorithmForCLIP(
     BaseAlgorithm,
     SimpleProfilerMixin,
     CLIPClassificationMixin,
 ):
     modelpool: CLIPVisionModelPool = None
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "upscale_mlp": "upscale_mlp",
-        "upscale_attn": "upscale_attn",
-        "init_lambda": "init_lambda",
-        "router_hidden_layers": "router_hidden_layers",
-        "lr": "lr",
-        "num_steps": "num_steps",
-        "save_interval": "save_interval",
-        "alpha": "alpha",
-        "checkpoint_path": "checkpoint_path",
-        "eval_grid": "eval_grid",
-        "eval_grid_n": "eval_grid_n",
-        "eval_grid_m": "eval_grid_m",
-        "_dataloader_kwargs": "dataloader_kwargs",
-    }
 
     def __init__(
         self,
@@ -72,19 +62,6 @@ class PWEMoEAlgorithmForCLIP(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.upscale_mlp = upscale_mlp
-        self.upscale_attn = upscale_attn
-        self.init_lambda = init_lambda
-        self.router_hidden_layers = router_hidden_layers
-        self.lr = lr
-        self.num_steps = num_steps
-        self.save_interval = save_interval
-        self.alpha = alpha
-        self.checkpoint_path = checkpoint_path
-        self.eval_grid = eval_grid
-        self.eval_grid_n = eval_grid_n
-        self.eval_grid_m = eval_grid_m
-        self._dataloader_kwargs = dataloader_kwargs
 
     @override
     def run(self, modelpool: CLIPVisionModelPool):
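
The dominant refactor in this release: hand-maintained `_config_mapping` tables and the matching `self.x = x` assignments (deleted above) are replaced by the new `@auto_register_config` class decorator. The decorator's own source is not part of these hunks; the following is a minimal sketch of how such a decorator can be implemented, purely for illustration and not fusion-bench's actual code:

```python
import inspect
from functools import wraps


def auto_register_config(cls):
    """Hypothetical re-implementation: assign every named __init__ argument
    to ``self`` and record it in the class's ``_config_mapping``."""
    original_init = cls.__init__
    sig = inspect.signature(original_init)
    params = [
        name
        for name, p in sig.parameters.items()
        if name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    # merge with any mapping inherited from the base class
    cls._config_mapping = {
        **getattr(cls, "_config_mapping", {}),
        **{name: name for name in params},
    }

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name in params:
            setattr(self, name, bound.arguments[name])
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls
```
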
@@ -193,13 +170,14 @@ class PWEMoEAlgorithmForCLIP(
         Loads the datasets specified in the configuration.
         """
         train_datasets = {
-            dataset_name: self.modelpool.load_train_dataset(
-                dataset_name, self.clip_processor
+            dataset_name: CLIPDataset(
+                self.modelpool.load_train_dataset(dataset_name),
+                processor=self.clip_processor,
             )
             for dataset_name in self.modelpool.model_names
         }
         train_loaders = {
-            dataset_name: DataLoader(dataset, shuffle=True, **self._dataloader_kwargs)
+            dataset_name: DataLoader(dataset, shuffle=True, **self.dataloader_kwargs)
             for dataset_name, dataset in train_datasets.items()
         }
         train_loaders = {
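
Note the accompanying API shift: `load_train_dataset` no longer takes a processor and now returns a raw dataset, which the caller wraps in `CLIPDataset`. A sketch of what such a wrapper plausibly looks like (the real class lives in `fusion_bench/dataset/clip_dataset.py`; the details below are assumptions):

```python
from torch.utils.data import Dataset


class CLIPDataset(Dataset):
    """Apply a CLIP image processor lazily, item by item (illustrative)."""

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # the processor returns a batch of size 1; drop the batch dimension
        pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"][0]
        return pixel_values, label
```
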
@@ -5,9 +5,7 @@ import torch
 
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.parameters import count_parameters
-from fusion_bench.utils.state_dict_arithmetic import (
-    state_dict_mul,
-)
+from fusion_bench.utils.state_dict_arithmetic import state_dict_mul
 
 from .base_algorithm import SuperposedAlgorithmBase, compare_models
 
@@ -27,7 +27,7 @@ class RegMeanAlgorithmForCLIP(
 
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self._dataloader_kwargs = dataloader_kwargs
+        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()
@@ -60,7 +60,7 @@ class RegMeanAlgorithmForCLIP(
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
         train_dataloader = DataLoader(
-            train_dataset, shuffle=True, **self._dataloader_kwargs
+            train_dataset, shuffle=True, **self.dataloader_kwargs
         )
         train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
         model = self.fabric.setup(model)
@@ -15,7 +15,7 @@ from transformers import GPT2ForSequenceClassification, GPT2Model
 from transformers.data import default_data_collator
 from transformers.models.gpt2.modeling_gpt2 import Conv1D
 
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, auto_register_config
 from fusion_bench.utils import timeit_context
 
 from .regmean import RegMeanAlgorithm
@@ -23,22 +23,15 @@ from .regmean import RegMeanAlgorithm
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class RegMeanAlgorithmForGPT2(
-    RegMeanAlgorithm,
     LightningFabricMixin,
+    RegMeanAlgorithm,
 ):
     _include_module_type = [Conv1D]
     classifiers = {}
-    _config_mapping = RegMeanAlgorithm._config_mapping | {
-        "cache_dir": "cache_dir",
-        "batch_size": "batch_size",
-        "num_workers": "num_workers",
-    }
 
     def __init__(self, cache_dir: str, batch_size: int, num_workers: int, **kwargs):
-        self.cache_dir = cache_dir
-        self.batch_size = batch_size
-        self.num_workers = num_workers
         super().__init__(**kwargs)
 
     def on_regmean_start(self):
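
Besides dropping the `_config_mapping` boilerplate, this hunk swaps the base-class order so the Fabric mixin precedes `RegMeanAlgorithm` in the MRO (the `SimpleAverageAlgorithm` hunk later in this diff gets the same swap). With cooperative `super().__init__(**kwargs)` chains, that order decides whose initialization runs first, as this small self-contained example shows:

```python
class Mixin:
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        print("Mixin ready")


class Algorithm:
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        print("Algorithm ready")


class Merged(Mixin, Algorithm):
    pass


Merged()  # prints "Algorithm ready", then "Mixin ready"
print([c.__name__ for c in Merged.__mro__])
# ['Merged', 'Mixin', 'Algorithm', 'object']
```
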
@@ -13,7 +13,7 @@ from torch import Tensor, nn
 from tqdm.autonotebook import tqdm
 
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 
 log = logging.getLogger(__name__)
@@ -280,14 +280,9 @@ def regmean_merging(
     return merged_params
 
 
+@auto_register_config
 class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     _include_module_type = [nn.Linear]
-    _config_mapping = {
-        "num_regmean_examples": "num_regmean_examples",
-        "exclude_param_names_regex": "exclude_param_names_regex",
-        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
-        "weight_transpose": "weight_transpose",
-    }
 
     def __init__(
         self,
@@ -298,10 +293,6 @@ class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         weight_transpose: bool,
         **kwargs,
     ):
-        self.num_regmean_examples = num_regmean_examples
-        self.exclude_param_names_regex = exclude_param_names_regex
-        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
-        self.weight_transpose = weight_transpose
         super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool, **kwargs):
@@ -1,3 +1,3 @@
 # flake8: noqa F401
 from .clip_regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus
-from .regmean_plusplus import RegMeanAlgorithmPlusPlus
\ No newline at end of file
+from .regmean_plusplus import RegMeanAlgorithmPlusPlus
@@ -28,7 +28,7 @@ class RegMeanAlgorithmForCLIPPlusPlus(
 
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self._dataloader_kwargs = dataloader_kwargs
+        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()
@@ -125,27 +125,26 @@ class RegMeanAlgorithmForCLIPPlusPlus(
 
         param_dict = {}
         for name, param in model_to_merge_state_dict.items():
-            if name.startswith("vision_model.embeddings") or name.startswith("vision_model.pre_layrnorm"):
+            if name.startswith("vision_model.embeddings") or name.startswith(
+                "vision_model.pre_layrnorm"
+            ):
                 param_dict[name] = param
 
         for param_name in param_dict.keys():
-            models_to_merge_param_dict[param_name].append(
-                param_dict[param_name]
-            )
+            models_to_merge_param_dict[param_name].append(param_dict[param_name])
 
         # merge the parameters of the embedding layer
         merged_params_dict = {}
         for param_name, param_list in models_to_merge_param_dict.items():
             merged_params_dict[param_name] = torch.stack(param_list).mean(dim=0)
-        
+
         return merged_params_dict
-    
-    
+
     def get_input_for_first_layer(self, model: nn.Module, train_dataset):
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
         train_dataloader = DataLoader(
-            train_dataset, shuffle=True, **self._dataloader_kwargs
+            train_dataset, shuffle=True, **self.dataloader_kwargs
        )
         train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
         model = self.fabric.setup(model)
@@ -157,9 +156,9 @@ class RegMeanAlgorithmForCLIPPlusPlus(
             image_embeds = model.vision_model.embeddings(images)
             image_embeds = model.vision_model.pre_layrnorm(image_embeds)
             image_embeds = image_embeds.detach().cpu()
-            
+
             return image_embeds
-        
+
         num_computed_examples = 0
         num_regmean_examples = self.num_regmean_examples
 
@@ -169,24 +168,32 @@
                 break
             batches_input.append(compute_input(model, batch))
             num_computed_examples += batch[0].size(0)
-        
+
         return batches_input
 
     def get_layers(self, model: nn.Module):
         return model.vision_model.encoder.layers
-    
-    def update_merged_params_dict(self, merged_params_dict, new_merged_params, layer_idx):
+
+    def update_merged_params_dict(
+        self, merged_params_dict, new_merged_params, layer_idx
+    ):
         for key, value in new_merged_params.items():
             key = f"vision_model.encoder.layers.{layer_idx}.{key}"
             merged_params_dict[key] = value
 
         return merged_params_dict
-    
-    def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]) -> Tensor:
+
+    def layer_batches_forward(
+        self, layer: nn.Module, batches_input: List[Tensor]
+    ) -> Tensor:
         batches_output = []
         for batch in batches_input:
             device = next(layer.parameters()).device
             batch = batch.to(device)
-            logits = layer(batch, attention_mask=None, causal_attention_mask=None)[0].detach().cpu()
+            logits = (
+                layer(batch, attention_mask=None, causal_attention_mask=None)[0]
+                .detach()
+                .cpu()
+            )
             batches_output.append(logits)
         return batches_output
@@ -81,13 +81,11 @@ def regmean_params_merge(
     reduce_non_diagonal_ratio: float = 1.0,
     weight_transpose: bool = True,
     module_name: str = "",
-    device = "cpu"
+    device="cpu",
 ):
     # two lists with length num_models_to_merge
     param_multiplied_results, module_regmean_weights_list = [], []
-    for model_idx, module_regmean_weights in enumerate(
-        param_regmean_list
-    ):
+    for model_idx, module_regmean_weights in enumerate(param_regmean_list):
         # reduce non-diagonal elements
         module_regmean_weights = reduce_non_diagonal_elements(
             regmean_weights=module_regmean_weights,
@@ -113,9 +111,7 @@
     sum_param_multiplied_results = sum(param_multiplied_results)
 
     # get the inverse matrix
-    inv_sum_module_regmean_weights = torch.inverse(
-        sum_module_regmean_weights
-    )
+    inv_sum_module_regmean_weights = torch.inverse(sum_module_regmean_weights)
     # merge parameters with regmean
     merged_param = torch.matmul(
         inv_sum_module_regmean_weights, sum_param_multiplied_results
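
This hunk is cosmetic, but it touches the core of RegMean: for a linear layer the merged weight is the closed-form least-squares solution W* = (Σ_i G_i)^{-1} Σ_i (G_i W_i), where G_i = X_iᵀX_i is the Gram matrix of model i's layer inputs. A self-contained numeric sketch of that formula (illustrative shapes; note `nn.Linear` stores weights as (out, in), which is why the real code carries a `weight_transpose` flag):

```python
import torch

torch.manual_seed(0)

# calibration inputs and weights for two models; weights are (in, out) here
inputs = [torch.randn(32, 4) for _ in range(2)]
weights = [torch.randn(4, 8) for _ in range(2)]

# Gram matrices G_i = X_i^T X_i, shape (4, 4)
grams = [x.T @ x for x in inputs]

# RegMean closed form: W* = (sum_i G_i)^{-1} (sum_i G_i W_i)
sum_gram = sum(grams)
sum_prod = sum(g @ w for g, w in zip(grams, weights))
merged_weight = torch.linalg.solve(sum_gram, sum_prod)  # solve, instead of the explicit inverse used above

assert merged_weight.shape == (4, 8)
```
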
@@ -158,15 +154,19 @@ def merging_with_regmean_weights(
             device = param_value_list[model_idx].device
 
             # Tensor, shape (hidden_dim, hidden_dim)
-            module_regmean_weights = model_to_merge_regmean_weights[module_name].to(device)
+            module_regmean_weights = model_to_merge_regmean_weights[
+                module_name
+            ].to(device)
             module_regmean_weights_list.append(module_regmean_weights)
 
-        merged_params[param_name] = regmean_params_merge(param_weight_list=param_value_list,
-                                                         param_regmean_list=module_regmean_weights_list,
-                                                         reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
-                                                         weight_transpose=weight_transpose,
-                                                         module_name=module_name,
-                                                         device=device)
+        merged_params[param_name] = regmean_params_merge(
+            param_weight_list=param_value_list,
+            param_regmean_list=module_regmean_weights_list,
+            reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+            weight_transpose=weight_transpose,
+            module_name=module_name,
+            device=device,
+        )
 
         merged_by_regmean = True
     # use average merging for parameters whose names are not end with ".weight" or not in Linear module
@@ -205,7 +205,9 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
             modelpool = BaseModelPool(modelpool)
         self.modelpool = modelpool
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        models_to_merge_dict = {name: model.to(device) for name, model in modelpool.named_models()}
+        models_to_merge_dict = {
+            name: model.to(device) for name, model in modelpool.named_models()
+        }
         self.on_regmean_start()
 
         # initialize the merged models as the pretrained model
@@ -213,7 +215,9 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
         merged_params_dict = {}
 
         # 1. merge embedding layer
-        merged_embedding_dict = self.merge_embedding_layer(models_to_merge_dict=models_to_merge_dict)
+        merged_embedding_dict = self.merge_embedding_layer(
+            models_to_merge_dict=models_to_merge_dict
+        )
         merged_model.load_state_dict(merged_embedding_dict, strict=False)
 
         with torch.no_grad():
@@ -223,12 +227,13 @@
             self.profile("computing first layer input"),
         ):
             batches_input_dict = defaultdict(list)
-            for name in tqdm(models_to_merge_dict.keys(), desc="computing input for first layer"):
+            for name in tqdm(
+                models_to_merge_dict.keys(), desc="computing input for first layer"
+            ):
                 dataset = modelpool.load_train_dataset(name)
-                
+
                 batches_input_dict[name] = self.get_input_for_first_layer(
-                    merged_model,
-                    dataset
+                    merged_model, dataset
                 )
 
         # 2. iteratively merge layer by layer with regmean algorithm
@@ -240,9 +245,9 @@
             models_to_merge_layers_dict[name] = self.get_layers(model)
 
         param_names_to_merge = None
-        for layer_idx, backbone_layer in tqdm(enumerate(backbone_layers),
-                                              desc="merging layers",
-                                              total=num_layers):
+        for layer_idx, backbone_layer in tqdm(
+            enumerate(backbone_layers), desc="merging layers", total=num_layers
+        ):
             # dictionary of list, where key is the parameter name,
             # value is a list of the corresponding parameters of all the models that need to be merged
             models_to_merge_param_dict = defaultdict(list)
@@ -263,16 +268,19 @@
                     "exclude_param_names_regex", []
                 ),
             )
-            
+
             for param_name in param_names_to_merge:
                 models_to_merge_param_dict[param_name].append(
                     param_dict[param_name]
                 )
 
             linear_modules_to_merge = get_modules_to_merge(
-                model=layer_to_merge, include_module_types=self._include_module_type
+                model=layer_to_merge,
+                include_module_types=self._include_module_type,
             )
-            assert len(linear_modules_to_merge) > 0, "No linear modules to merge"
+            assert (
+                len(linear_modules_to_merge) > 0
+            ), "No linear modules to merge"
 
             # 2.1. compute regmean weights for each model
             with (
@@ -288,12 +296,19 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
 
                 module_subset = get_param_names_to_merge(
                     input_param_names=list(param_dict.keys()),
-                    exclude_param_names_regex=self.exclude_param_names_regex
+                    exclude_param_names_regex=self.exclude_param_names_regex,
                 )
-                module_subset = [name.replace(".weight", "").replace(".bias", "") for name in module_subset]
+                module_subset = [
+                    name.replace(".weight", "").replace(".bias", "")
+                    for name in module_subset
+                ]
                 module_subset = list(set(module_subset))
-                regmean_weights = {module_name: regmean_weights[module_name] for module_name in module_subset if module_name in regmean_weights}
-                
+                regmean_weights = {
+                    module_name: regmean_weights[module_name]
+                    for module_name in module_subset
+                    if module_name in regmean_weights
+                }
+
                 models_to_merge_regmean_weights_list.append(regmean_weights)
 
             # 2.2. merge parameters with regmean weights
@@ -318,21 +333,22 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
                 self.profile("forwarding next layer"),
             ):
                 if layer_idx < num_layers - 1:
-                    backbone_layer.load_state_dict(merged_layer_params, strict=False)
+                    backbone_layer.load_state_dict(
+                        merged_layer_params, strict=False
+                    )
                     batches_output_dict = defaultdict(list)
                     for name in models_to_merge_dict.keys():
                         batches_output_dict[name] = self.layer_batches_forward(
-                            backbone_layer,
-                            batches_input_dict[name]
+                            backbone_layer, batches_input_dict[name]
                         )
                     batches_input_dict = batches_output_dict
-        
+
         # 3. load state dict to the merged model
         merged_model.load_state_dict(merged_params_dict, strict=False)
 
         self.print_profile_summary()
         return merged_model
-    
+
     def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
         """
         Merge the embedding layer of the model with the merged model.
@@ -345,10 +361,12 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
 
     def get_layers(self, model: nn.Module):
         raise NotImplementedError
-    
-    def update_merged_params_dict(self, merged_params_dict, new_merged_params, layer_idx):
+
+    def update_merged_params_dict(
+        self, merged_params_dict, new_merged_params, layer_idx
+    ):
         raise NotImplementedError
-    
+
     def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]):
         raise NotImplementedError
 
@@ -6,7 +6,7 @@ import torch
 from torch import nn
 
 from fusion_bench.method.base_algorithm import BaseAlgorithm
-from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
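
For reference, the `simple_average`/`state_dict_avg` path that the class in the next hunk delegates to reduces to key-wise tensor averaging. A minimal sketch of the idea (the actual helper lives in `fusion_bench.utils.state_dict_arithmetic`; this re-implementation is illustrative only):

```python
from collections import OrderedDict

import torch


@torch.no_grad()
def average_state_dicts(state_dicts):
    """Key-wise mean of a list of state dicts (illustrative helper)."""
    avg = OrderedDict()
    for key in state_dicts[0]:
        avg[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
    return avg


# usage: average two identically-shaped linear layers
a, b = torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)
merged = average_state_dicts([a.state_dict(), b.state_dict()])
```
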
@@ -59,24 +59,20 @@ def simple_average(
     return state_dict_avg(modules)
 
 
+@auto_register_config
 class SimpleAverageAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "show_pbar": "show_pbar",
-    }
-
-    def __init__(self, show_pbar: bool = False):
+    def __init__(self, show_pbar: bool = False, **kwargs):
         """
         Args:
             show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
         """
-        super().__init__()
-        self.show_pbar = show_pbar
+        super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
         """
         Fuse the models in the given model pool using simple averaging.
 
@@ -124,13 +120,13 @@ class SimpleAverageAlgorithm(
             if isinstance(forward_model, LazyStateDict):
                 # if the model is a LazyStateDict, convert it to an empty module
                 forward_model = forward_model.meta_module.to_empty(
-                    device=(
-                        "cpu"
-                        if forward_model._torch_dtype is None
-                        else forward_model._torch_dtype
-                    )
+                    device=forward_model._device
                 )
-                forward_model.load_state_dict(sd)
+                result = forward_model.load_state_dict(sd, strict=False)
+                if result.unexpected_keys:
+                    raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
+                if result.missing_keys:
+                    log.warning(f"Missing keys in state dict: {result.missing_keys}")
         # print profile report and log the merged models
         self.print_profile_summary()
         log.info(f"merged {len(merged_model_names)} models:")
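
The behavioral fix in this hunk: `load_state_dict` was previously called with the default `strict=True` on a module rebuilt from a `LazyStateDict`, and its return value was ignored. The new code loads with `strict=False` and inspects the returned `NamedTuple`, treating unexpected keys as fatal and missing keys as a warning. A quick demonstration of how that return value behaves in PyTorch:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
sd = {"weight": torch.zeros(2, 4)}  # deliberately omit "bias"

# with strict=False, load_state_dict reports rather than raises
result = model.load_state_dict(sd, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []
```
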
@@ -1,10 +1,13 @@
 import logging
+from typing import Any, Dict
 
 import torch
+from torch import nn
 from typing_extensions import override
 
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils.type import StateDictType
 
 from .slerp_utils import slerp
 
@@ -18,7 +21,7 @@ def slerp_on_state_dicts(
     *,
     DOT_THRESHOLD: float = 0.9995,
     epsilon: float = 1e-8,
-):
+) -> StateDictType:
     """
     Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
 
@@ -72,7 +75,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
         super().__init__()
 
     @override
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
         Run the SlerpMergeAlgorithm on the given model pool.
 
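
These last two hunks only add return-type annotations; `slerp_on_state_dicts` applies spherical linear interpolation per parameter tensor, falling back to plain linear interpolation when the flattened vectors are nearly collinear (the `DOT_THRESHOLD` default shown above). A self-contained sketch of that math (illustrative, not the library's `slerp_utils` implementation):

```python
import torch


def slerp(t: float, v0: torch.Tensor, v1: torch.Tensor,
          dot_threshold: float = 0.9995, epsilon: float = 1e-8) -> torch.Tensor:
    """Spherical linear interpolation between two flattened tensors."""
    u0 = v0 / (v0.norm() + epsilon)
    u1 = v1 / (v1.norm() + epsilon)
    dot = (u0 * u1).sum()
    if dot.abs() > dot_threshold:
        # nearly collinear: the spherical formula degenerates, use lerp
        return (1 - t) * v0 + t * v1
    omega = torch.acos(dot.clamp(-1.0, 1.0))
    so = torch.sin(omega)
    return (torch.sin((1 - t) * omega) / so) * v0 + (torch.sin(t * omega) / so) * v1
```
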