fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (193)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -1
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +16 -6
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
  56. fusion_bench/method/simple_average.py +16 -4
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +6 -6
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/misc.py +48 -2
  122. fusion_bench/utils/modelscope.py +265 -0
  123. fusion_bench/utils/parameters.py +2 -2
  124. fusion_bench/utils/rich_utils.py +3 -0
  125. fusion_bench/utils/state_dict_arithmetic.py +34 -27
  126. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
  127. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
  128. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  129. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  130. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  131. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  132. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  133. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  134. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  135. fusion_bench_config/hydra/default.yaml +6 -2
  136. fusion_bench_config/llama_full_finetune.yaml +1 -0
  137. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  138. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  139. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  140. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  141. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  142. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  143. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  144. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  171. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  172. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  173. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  174. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  175. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  178. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  179. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  180. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  181. fusion_bench_config/nyuv2_config.yaml +3 -1
  182. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  183. fusion_bench_config/path/default.yaml +28 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  185. fusion_bench_config/method/adamerging.yaml +0 -23
  186. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  187. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  188. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  189. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  190. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  191. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  192. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  193. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/regmean/clip_regmean.py
@@ -27,7 +27,7 @@ class RegMeanAlgorithmForCLIP(
 
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self._dataloader_kwargs = dataloader_kwargs
+        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()
@@ -60,7 +60,7 @@ class RegMeanAlgorithmForCLIP(
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
         train_dataloader = DataLoader(
-            train_dataset, shuffle=True, **self._dataloader_kwargs
+            train_dataset, shuffle=True, **self.dataloader_kwargs
        )
         train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
         model = self.fabric.setup(model)
fusion_bench/method/regmean/gpt2_regmean.py
@@ -15,7 +15,7 @@ from transformers import GPT2ForSequenceClassification, GPT2Model
 from transformers.data import default_data_collator
 from transformers.models.gpt2.modeling_gpt2 import Conv1D
 
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, auto_register_config
 from fusion_bench.utils import timeit_context
 
 from .regmean import RegMeanAlgorithm
@@ -23,22 +23,15 @@ from .regmean import RegMeanAlgorithm
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class RegMeanAlgorithmForGPT2(
-    RegMeanAlgorithm,
     LightningFabricMixin,
+    RegMeanAlgorithm,
 ):
     _include_module_type = [Conv1D]
     classifiers = {}
-    _config_mapping = RegMeanAlgorithm._config_mapping | {
-        "cache_dir": "cache_dir",
-        "batch_size": "batch_size",
-        "num_workers": "num_workers",
-    }
 
     def __init__(self, cache_dir: str, batch_size: int, num_workers: int, **kwargs):
-        self.cache_dir = cache_dir
-        self.batch_size = batch_size
-        self.num_workers = num_workers
         super().__init__(**kwargs)
 
     def on_regmean_start(self):
fusion_bench/method/regmean/regmean.py
@@ -13,7 +13,7 @@ from torch import Tensor, nn
 from tqdm.autonotebook import tqdm
 
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 
 log = logging.getLogger(__name__)
@@ -280,14 +280,9 @@ def regmean_merging(
     return merged_params
 
 
+@auto_register_config
 class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     _include_module_type = [nn.Linear]
-    _config_mapping = {
-        "num_regmean_examples": "num_regmean_examples",
-        "exclude_param_names_regex": "exclude_param_names_regex",
-        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
-        "weight_transpose": "weight_transpose",
-    }
 
     def __init__(
         self,
@@ -298,10 +293,6 @@ class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         weight_transpose: bool,
         **kwargs,
     ):
-        self.num_regmean_examples = num_regmean_examples
-        self.exclude_param_names_regex = exclude_param_names_regex
-        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
-        self.weight_transpose = weight_transpose
         super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool, **kwargs):
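In gpt2_regmean.py and regmean.py the change follows a single pattern: the hand-written `_config_mapping` dictionaries and the per-argument `self.x = x` assignments in `__init__` are dropped in favor of an `@auto_register_config` class decorator exported from `fusion_bench.mixins`. The sketch below only illustrates the general idea of such a decorator and is an assumption, not the library's implementation: it inspects the wrapped `__init__`, copies each named argument onto the instance, and records the argument names so the algorithm's configuration can be serialized later.

    import functools
    import inspect


    def auto_register_config_sketch(cls):
        """Illustrative stand-in for fusion_bench.mixins.auto_register_config (assumption)."""
        original_init = cls.__init__
        signature = inspect.signature(original_init)

        @functools.wraps(original_init)
        def wrapped_init(self, *args, **kwargs):
            bound = signature.bind_partial(self, *args, **kwargs)
            bound.apply_defaults()
            # Store every named __init__ argument on the instance and register it
            # in the config mapping, replacing the hand-written _config_mapping dicts.
            mapping = dict(getattr(cls, "_config_mapping", {}))
            for name, value in bound.arguments.items():
                if name in ("self", "args", "kwargs"):
                    continue
                setattr(self, name, value)
                mapping[name] = name
            cls._config_mapping = mapping
            original_init(self, *args, **kwargs)

        cls.__init__ = wrapped_init
        return cls

With something like this in place, `RegMeanAlgorithm.__init__` only needs to declare its keyword arguments and call `super().__init__(**kwargs)`, which is exactly what the hunks above reduce it to.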
fusion_bench/method/regmean_plusplus/__init__.py (new file)
@@ -0,0 +1,3 @@
+# flake8: noqa F401
+from .clip_regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus
+from .regmean_plusplus import RegMeanAlgorithmPlusPlus
fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py (new file)
@@ -0,0 +1,199 @@
+import logging
+from collections import defaultdict
+from typing import Dict, List, cast  # noqa: F401
+
+import torch
+import torch.utils.data
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.nn.modules import Module
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.mixins import CLIPClassificationMixin
+
+from .regmean_plusplus import RegMeanAlgorithmPlusPlus
+
+log = logging.getLogger(__name__)
+
+
+class RegMeanAlgorithmForCLIPPlusPlus(
+    RegMeanAlgorithmPlusPlus,
+    CLIPClassificationMixin,
+):
+    _config_mapping = {
+        "_dataloader_kwargs": "dataloader_kwargs",
+    }
+
+    def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.dataloader_kwargs = dataloader_kwargs
+
+    def on_regmean_start(self):
+        self.setup_zero_shot_classification_head()
+
+    def compute_logits(self, module, batch, task: str) -> Tensor:
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image
+
+    def get_regmean_weights(
+        self,
+        model_name: str,
+        layer: Module,
+        batches_input: List[Tensor],
+        linear_modules_to_merge: Dict[str, Module],
+    ):
+        layer = self.fabric.setup(layer)
+
+        def compute_regmean_weights(module_name: str):
+            """
+            compute the regmean weights, a hook function to deal with each module's input
+            :param module_name: str, module name
+            :return:
+            """
+
+            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
+                # Tensor, shape (batch_size, sequence_length, hidden_dim)
+                x = cast(Tensor, input[0]).detach()
+                batch_num_actual_examples = x.shape[0]
+                # Tensor, shape (batch_size * sequence_length, hidden_dim)
+                x = x.reshape(-1, x.shape[-1])
+                # Tensor, shape (hidden_dim, hidden_dim)
+                xtx = torch.matmul(x.transpose(0, 1), x)
+                # store the averaged weights in regmean_weights
+                if module_name not in regmean_weights.keys():
+                    regmean_weights[module_name] = xtx / x.shape[0]
+                    num_computed_examples[module_name] = x.shape[0]
+                    num_actual_examples[module_name] = batch_num_actual_examples
+                else:
+                    regmean_weights[module_name] = (
+                        regmean_weights[module_name]
+                        * num_computed_examples[module_name]
+                        + xtx
+                    ) / (num_computed_examples[module_name] + x.shape[0])
+                    num_computed_examples[module_name] += x.shape[0]
+                    num_actual_examples[module_name] += batch_num_actual_examples
+
+            return hook
+
+        handles = []
+        # dictionary, regmean matrices for each linear module inputs
+        regmean_weights = {}
+        # dictionary, number of examples (multiplied the sequence length) used for computing regmean matrices
+        num_computed_examples = {}
+        # dictionary, number of actual examples used for computing regmean matrices
+        num_actual_examples = {}
+
+        for module_name, linear_module_to_merge in linear_modules_to_merge.items():
+            # register a hook in the forward process
+            handle = linear_module_to_merge.register_forward_hook(
+                compute_regmean_weights(module_name=module_name)
+            )
+            handles.append(handle)
+        _ = self.layer_batches_forward(layer, batches_input)
+
+        # remove the added hook
+        for handle in handles:
+            handle.remove()
+
+        for module_name in regmean_weights.keys():
+            regmean_weights[module_name] = regmean_weights[module_name].detach().cpu()
+
+        return regmean_weights
+
+    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
+        models_to_merge_param_dict = defaultdict(list)
+
+        # get the parameters of the embedding layer from each model
+        for model_to_merge in models_to_merge_dict.values():
+            model_to_merge_state_dict = model_to_merge.state_dict()
+
+            param_dict = {}
+            for name, param in model_to_merge_state_dict.items():
+                if name.startswith("vision_model.embeddings") or name.startswith(
+                    "vision_model.pre_layrnorm"
+                ):
+                    param_dict[name] = param
+
+            for param_name in param_dict.keys():
+                models_to_merge_param_dict[param_name].append(param_dict[param_name])
+
+        # merge the parameters of the embedding layer
+        merged_params_dict = {}
+        for param_name, param_list in models_to_merge_param_dict.items():
+            merged_params_dict[param_name] = torch.stack(param_list).mean(dim=0)
+
+        return merged_params_dict
+
+    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
+        # setup dataloader
+        train_dataset = CLIPDataset(train_dataset, self.clip_processor)
+        train_dataloader = DataLoader(
+            train_dataset, shuffle=True, **self.dataloader_kwargs
+        )
+        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
+        model = self.fabric.setup(model)
+
+        def compute_input(model, batch):
+            images, _ = batch
+
+            images = images.to(model.device)
+            image_embeds = model.vision_model.embeddings(images)
+            image_embeds = model.vision_model.pre_layrnorm(image_embeds)
+            image_embeds = image_embeds.detach().cpu()
+
+            return image_embeds
+
+        num_computed_examples = 0
+        num_regmean_examples = self.num_regmean_examples
+
+        batches_input = []
+        for batch in train_dataloader:
+            if num_computed_examples >= num_regmean_examples:
+                break
+            batches_input.append(compute_input(model, batch))
+            num_computed_examples += batch[0].size(0)
+
+        return batches_input
+
+    def get_layers(self, model: nn.Module):
+        return model.vision_model.encoder.layers
+
+    def update_merged_params_dict(
+        self, merged_params_dict, new_merged_params, layer_idx
+    ):
+        for key, value in new_merged_params.items():
+            key = f"vision_model.encoder.layers.{layer_idx}.{key}"
+            merged_params_dict[key] = value
+
+        return merged_params_dict
+
+    def layer_batches_forward(
+        self, layer: nn.Module, batches_input: List[Tensor]
+    ) -> Tensor:
+        batches_output = []
+        for batch in batches_input:
+            device = next(layer.parameters()).device
+            batch = batch.to(device)
+            logits = (
+                layer(batch, attention_mask=None, causal_attention_mask=None)[0]
+                .detach()
+                .cpu()
+            )
+            batches_output.append(logits)
+        return batches_output
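The class above subclasses RegMeanAlgorithmPlusPlus (shown next) and is exported from fusion_bench/method/regmean_plusplus/__init__.py. As a usage sketch only (the parameter values below are placeholders, the model pool setup is omitted, and the real defaults live in the new method/regmean_plusplus/clip_regmean_plusplus.yaml config listed in the file table), the constructor and `run` signatures shown in this file imply an interface along these lines:

    from fusion_bench.method.regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus

    # Placeholder values for illustration; dataloader_kwargs is typed as a DictConfig
    # in the actual signature.
    algorithm = RegMeanAlgorithmForCLIPPlusPlus(
        dataloader_kwargs={"batch_size": 32, "num_workers": 4},
        num_regmean_examples=256,
        exclude_param_names_regex=[],
        reduce_non_diagonal_ratio=0.9,
        weight_transpose=True,
    )
    # merged_model = algorithm.run(modelpool)  # modelpool: a CLIP vision model pool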
fusion_bench/method/regmean_plusplus/regmean_plusplus.py (new file)
@@ -0,0 +1,383 @@
+import logging
+import re
+from collections import defaultdict
+from typing import Dict, List, cast
+
+import torch
+from torch import Tensor, nn
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.modelpool import BaseModelPool
+
+log = logging.getLogger(__name__)
+
+
+def get_param_names_to_merge(
+    input_param_names: List[str], exclude_param_names_regex: list
+):
+    """
+    get the names of parameters that need to be merged
+    :param input_param_names: list, names of input parameters
+    :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
+    :return:
+    """
+    param_names_to_merge = []
+    for param_name in input_param_names:
+        exclude = any(
+            [
+                re.match(exclude_pattern, param_name)
+                for exclude_pattern in exclude_param_names_regex
+            ]
+        )
+        if not exclude:
+            param_names_to_merge.append(param_name)
+    return param_names_to_merge
+
+
+def get_modules_to_merge(model: nn.Module, include_module_types: list):
+    """
+    get the model modules that need to be merged, whose type is in include_module_types
+    :param model: nn.Module, input model
+    :param include_module_types: list, module types that want to include
+    :return:
+    """
+    modules_to_merge: Dict[str, nn.Module] = {}
+    for module_name, module in model.named_modules():
+        is_valid_type = not include_module_types or any(
+            [
+                isinstance(module, include_module_type)
+                for include_module_type in include_module_types
+            ]
+        )
+        if is_valid_type:
+            modules_to_merge[module_name] = module
+    return modules_to_merge
+
+
+def reduce_non_diagonal_elements(
+    regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
+):
+    """
+    reduce the non-diagonal elements in regmean_weights
+    :param regmean_weights: Tensor, shape (hidden_dim, hidden_dim), input regmean weights
+    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+    :return:
+    """
+    # diagonal matrix with (1 - reduce_non_diagonal_ratio) as elements
+    diag_weights = torch.diag(
+        torch.ones(regmean_weights.shape[0]) - reduce_non_diagonal_ratio
+    ).to(regmean_weights.device)
+    # matrix with reduce_non_diagonal_ratio as elements
+    non_diag_weights = torch.zeros_like(diag_weights).fill_(reduce_non_diagonal_ratio)
+    # diagonal elements are unchanged, while non-diagonal elements are multiplied by reduce_non_diagonal_ratio
+    return regmean_weights * (diag_weights + non_diag_weights)
+
+
+def regmean_params_merge(
+    param_weight_list: List[Tensor],
+    param_regmean_list: List[Tensor],
+    reduce_non_diagonal_ratio: float = 1.0,
+    weight_transpose: bool = True,
+    module_name: str = "",
+    device="cpu",
+):
+    # two lists with length num_models_to_merge
+    param_multiplied_results, module_regmean_weights_list = [], []
+    for model_idx, module_regmean_weights in enumerate(param_regmean_list):
+        # reduce non-diagonal elements
+        module_regmean_weights = reduce_non_diagonal_elements(
+            regmean_weights=module_regmean_weights,
+            reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+        )
+        module_regmean_weights_list.append(module_regmean_weights)
+
+        model_to_merge_param = param_weight_list[model_idx]
+        # since the weight shape of Linear module is (output_size, input_size), we need to transpose it
+        param_multiplied_results.append(
+            torch.matmul(
+                module_regmean_weights,
+                (
+                    model_to_merge_param.transpose(0, 1)
+                    if weight_transpose
+                    else model_to_merge_param
+                ),
+            )
+        )
+
+    # sum up module_regmean_weights and param_multiplied_results over all individual models
+    sum_module_regmean_weights = sum(module_regmean_weights_list)
+    sum_param_multiplied_results = sum(param_multiplied_results)
+
+    # get the inverse matrix
+    inv_sum_module_regmean_weights = torch.inverse(sum_module_regmean_weights)
+    # merge parameters with regmean
+    merged_param = torch.matmul(
+        inv_sum_module_regmean_weights, sum_param_multiplied_results
+    )
+    # transpose to the original shape of "weight" in Linear module
+    merged_param = merged_param.transpose(0, 1) if weight_transpose else merged_param
+
+    return merged_param
+
+
+def merging_with_regmean_weights(
+    models_to_merge_param_dict: dict,
+    models_to_merge_regmean_weights_list: list,
+    reduce_non_diagonal_ratio: float = 1.0,
+    weight_transpose: bool = True,
+):
+    """
+    merge parameters of different models with computed regmean weights
+    :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+        value is a list of the corresponding parameters of all the models that need to be merged
+    :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+        each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+    :return:
+    """
+    # dict, dictionary of model parameters
+    merged_params = {}
+
+    for param_name, param_value_list in models_to_merge_param_dict.items():
+        merged_by_regmean = False
+        # only perform regmean merging on the "weight" parameter of Linear module
+        if param_name.endswith(".weight"):
+            module_name = param_name[: -len(".weight")]
+            if module_name in models_to_merge_regmean_weights_list[0].keys():
+                # two lists with length num_models_to_merge
+                module_regmean_weights_list = []
+                for model_idx, model_to_merge_regmean_weights in enumerate(
+                    models_to_merge_regmean_weights_list
+                ):
+                    device = param_value_list[model_idx].device
+
+                    # Tensor, shape (hidden_dim, hidden_dim)
+                    module_regmean_weights = model_to_merge_regmean_weights[
+                        module_name
+                    ].to(device)
+                    module_regmean_weights_list.append(module_regmean_weights)
+
+                merged_params[param_name] = regmean_params_merge(
+                    param_weight_list=param_value_list,
+                    param_regmean_list=module_regmean_weights_list,
+                    reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+                    weight_transpose=weight_transpose,
+                    module_name=module_name,
+                    device=device,
+                )
+
+                merged_by_regmean = True
+        # use average merging for parameters whose names are not end with ".weight" or not in Linear module
+        if not merged_by_regmean:
+            merged_params[param_name] = torch.stack(param_value_list, dim=0).mean(dim=0)
+
+    return merged_params
+
+
+class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
+    _include_module_type = [nn.Linear]
+    _config_mapping = {
+        "num_regmean_examples": "num_regmean_examples",
+        "exclude_param_names_regex": "exclude_param_names_regex",
+        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
+        "weight_transpose": "weight_transpose",
+    }
+
+    def __init__(
+        self,
+        *,
+        num_regmean_examples: int,
+        exclude_param_names_regex: list,
+        reduce_non_diagonal_ratio: float,
+        weight_transpose: bool,
+        **kwargs,
+    ):
+        self.num_regmean_examples = num_regmean_examples
+        self.exclude_param_names_regex = exclude_param_names_regex
+        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
+        self.weight_transpose = weight_transpose
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool, **kwargs):
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(modelpool)
+        self.modelpool = modelpool
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        models_to_merge_dict = {
+            name: model.to(device) for name, model in modelpool.named_models()
+        }
+        self.on_regmean_start()
+
+        # initialize the merged models as the pretrained model
+        merged_model = modelpool.load_pretrained_model().to(device)
+        merged_params_dict = {}
+
+        # 1. merge embedding layer
+        merged_embedding_dict = self.merge_embedding_layer(
+            models_to_merge_dict=models_to_merge_dict
+        )
+        merged_model.load_state_dict(merged_embedding_dict, strict=False)
+
+        with torch.no_grad():
+            # 1.1. compute input for the first layer
+            with (
+                self.profile("merging models"),
+                self.profile("computing first layer input"),
+            ):
+                batches_input_dict = defaultdict(list)
+                for name in tqdm(
+                    models_to_merge_dict.keys(), desc="computing input for first layer"
+                ):
+                    dataset = modelpool.load_train_dataset(name)
+
+                    batches_input_dict[name] = self.get_input_for_first_layer(
+                        merged_model, dataset
+                    )
+
+            # 2. iteratively merge layer by layer with regmean algorithm
+            backbone_layers = self.get_layers(merged_model)
+            num_layers = len(backbone_layers)
+
+            models_to_merge_layers_dict = defaultdict(list)
+            for name, model in models_to_merge_dict.items():
+                models_to_merge_layers_dict[name] = self.get_layers(model)
+
+            param_names_to_merge = None
+            for layer_idx, backbone_layer in tqdm(
+                enumerate(backbone_layers), desc="merging layers", total=num_layers
+            ):
+                # dictionary of list, where key is the parameter name,
+                # value is a list of the corresponding parameters of all the models that need to be merged
+                models_to_merge_param_dict = defaultdict(list)
+
+                # list of dictionaries with length len(models_to_merge),
+                # each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged
+                models_to_merge_regmean_weights_list = []
+
+                for name, layers_to_merge in models_to_merge_layers_dict.items():
+                    layer_to_merge = layers_to_merge[layer_idx]
+                    param_dict = layer_to_merge.state_dict()
+
+                    # exclude parameter whose name matches element in exclude_param_names_regex
+                    if param_names_to_merge is None:
+                        param_names_to_merge = get_param_names_to_merge(
+                            input_param_names=list(param_dict.keys()),
+                            exclude_param_names_regex=self.config.get(
+                                "exclude_param_names_regex", []
+                            ),
+                        )
+
+                    for param_name in param_names_to_merge:
+                        models_to_merge_param_dict[param_name].append(
+                            param_dict[param_name]
+                        )
+
+                    linear_modules_to_merge = get_modules_to_merge(
+                        model=layer_to_merge,
+                        include_module_types=self._include_module_type,
+                    )
+                    assert (
+                        len(linear_modules_to_merge) > 0
+                    ), "No linear modules to merge"
+
+                    # 2.1. compute regmean weights for each model
+                    with (
+                        self.profile("merging models"),
+                        self.profile("computing regmean weights"),
+                    ):
+                        regmean_weights = self.get_regmean_weights(
+                            name,
+                            layer_to_merge,
+                            batches_input=batches_input_dict[name],
+                            linear_modules_to_merge=linear_modules_to_merge,
+                        )
+
+                        module_subset = get_param_names_to_merge(
+                            input_param_names=list(param_dict.keys()),
+                            exclude_param_names_regex=self.exclude_param_names_regex,
+                        )
+                        module_subset = [
+                            name.replace(".weight", "").replace(".bias", "")
+                            for name in module_subset
+                        ]
+                        module_subset = list(set(module_subset))
+                        regmean_weights = {
+                            module_name: regmean_weights[module_name]
+                            for module_name in module_subset
+                            if module_name in regmean_weights
+                        }
+
+                        models_to_merge_regmean_weights_list.append(regmean_weights)
+
+                # 2.2. merge parameters with regmean weights
+                with self.profile("merging models"):
+                    # merging with regmean weights
+                    merged_layer_params = merging_with_regmean_weights(
+                        models_to_merge_param_dict=models_to_merge_param_dict,
+                        models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
+                        reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
+                        weight_transpose=self.config.get("weight_transpose", True),
+                    )
+
+                    merged_params_dict = self.update_merged_params_dict(
+                        merged_params_dict=merged_params_dict,
+                        new_merged_params=merged_layer_params,
+                        layer_idx=layer_idx,
+                    )
+
+                # 2.3. compute input for the next layer
+                with (
+                    self.profile("merging models"),
+                    self.profile("forwarding next layer"),
+                ):
+                    if layer_idx < num_layers - 1:
+                        backbone_layer.load_state_dict(
+                            merged_layer_params, strict=False
+                        )
+                        batches_output_dict = defaultdict(list)
+                        for name in models_to_merge_dict.keys():
+                            batches_output_dict[name] = self.layer_batches_forward(
+                                backbone_layer, batches_input_dict[name]
+                            )
+                        batches_input_dict = batches_output_dict
+
+        # 3. load state dict to the merged model
+        merged_model.load_state_dict(merged_params_dict, strict=False)
+
+        self.print_profile_summary()
+        return merged_model
+
+    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
+        """
+        Merge the embedding layer of the model with the merged model.
+        This method should be implemented in subclasses if needed.
+        """
+        raise NotImplementedError()
+
+    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
+        raise NotImplementedError
+
+    def get_layers(self, model: nn.Module):
+        raise NotImplementedError
+
+    def update_merged_params_dict(
+        self, merged_params_dict, new_merged_params, layer_idx
+    ):
+        raise NotImplementedError
+
+    def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]):
+        raise NotImplementedError
+
+    def on_regmean_start(self):
+        pass
+
+    def get_regmean_weights(
+        self,
+        model_name: str,
+        layer: nn.Module,
+        batches_input: List[Tensor],
+        linear_modules_to_merge: Dict[str, nn.Module],
+    ):
+        raise NotImplementedError
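For reference, the closed-form update implemented by regmean_params_merge above is the standard RegMean merge: for each linear layer, with per-model Gram matrices G_i = X_i^T X_i accumulated by the forward hooks and per-model weights W_i, the merged weight is W = ((sum_i G_i)^-1 sum_i G_i W_i^T)^T, where the off-diagonal entries of each G_i are first scaled by reduce_non_diagonal_ratio. A minimal numerical sketch of that formula, independent of the classes above:

    import torch

    torch.manual_seed(0)
    hidden_dim, out_dim, num_models = 8, 4, 3

    # Per-model activations X_i and Linear weights W_i (shape: out_dim x hidden_dim).
    grams, weights = [], []
    for _ in range(num_models):
        x = torch.randn(32, hidden_dim)
        grams.append(x.T @ x)  # G_i = X_i^T X_i, as accumulated by the hooks
        weights.append(torch.randn(out_dim, hidden_dim))

    # RegMean merge: W = ((sum_i G_i)^-1 sum_i G_i W_i^T)^T
    sum_gram = sum(grams)
    sum_products = sum(g @ w.T for g, w in zip(grams, weights))
    merged_weight = (torch.inverse(sum_gram) @ sum_products).T

    print(merged_weight.shape)  # torch.Size([4, 8])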