fusion-bench 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/constants/banner.py +12 -0
  3. fusion_bench/method/__init__.py +2 -0
  4. fusion_bench/method/linear/simple_average_for_llama.py +30 -5
  5. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  6. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +192 -0
  7. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +365 -0
  8. fusion_bench/method/simple_average.py +29 -3
  9. fusion_bench/modelpool/causal_lm/causal_lm.py +37 -6
  10. fusion_bench/modelpool/clip_vision/modelpool.py +45 -12
  11. fusion_bench/scripts/cli.py +1 -1
  12. fusion_bench/tasks/clip_classification/imagenet.py +1008 -2004
  13. fusion_bench/utils/lazy_state_dict.py +75 -3
  14. fusion_bench/utils/misc.py +66 -2
  15. fusion_bench/utils/modelscope.py +146 -0
  16. fusion_bench/utils/state_dict_arithmetic.py +10 -5
  17. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/METADATA +9 -1
  18. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/RECORD +50 -43
  19. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  20. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  21. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  22. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  23. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  24. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  25. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  26. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  27. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  28. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  29. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  30. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  31. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  32. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  33. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  34. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  35. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  36. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  37. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  38. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  39. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -0
  40. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  41. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  42. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  43. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  44. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  45. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  46. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +11 -0
  47. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/WHEEL +0 -0
  48. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/entry_points.txt +0 -0
  49. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/licenses/LICENSE +0 -0
  50. {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py CHANGED
@@ -1,3 +1,9 @@
+ # ███████╗██╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ ██████╗ ███████╗███╗ ██╗ ██████╗██╗ ██╗
+ # ██╔════╝██║ ██║██╔════╝██║██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝████╗ ██║██╔════╝██║ ██║
+ # █████╗ ██║ ██║███████╗██║██║ ██║██╔██╗ ██║█████╗██████╔╝█████╗ ██╔██╗ ██║██║ ███████║
+ # ██╔══╝ ██║ ██║╚════██║██║██║ ██║██║╚██╗██║╚════╝██╔══██╗██╔══╝ ██║╚██╗██║██║ ██╔══██║
+ # ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
+ # ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
  # flake8: noqa: F401
  from . import (
      constants,
fusion_bench/constants/banner.py ADDED
@@ -0,0 +1,12 @@
+ FUSION_BENCH_BANNER = (
+     ""
+     + "███████╗██╗ ██╗███████╗██╗ ██████╗ ███╗ ██╗ ██████╗ ███████╗███╗ ██╗ ██████╗██╗ ██╗\n"
+     + "██╔════╝██║ ██║██╔════╝██║██╔═══██╗████╗ ██║ ██╔══██╗██╔════╝████╗ ██║██╔════╝██║ ██║\n"
+     + "█████╗ ██║ ██║███████╗██║██║ ██║██╔██╗ ██║█████╗██████╔╝█████╗ ██╔██╗ ██║██║ ███████║\n"
+     + "██╔══╝ ██║ ██║╚════██║██║██║ ██║██║╚██╗██║╚════╝██╔══██╗██╔══╝ ██║╚██╗██║██║ ██╔══██║\n"
+     + "██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║\n"
+     + "╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝\n"
+ )
+
+ if __name__ == "__main__":
+     print(FUSION_BENCH_BANNER)
fusion_bench/method/__init__.py CHANGED
@@ -41,6 +41,7 @@ _import_structure = {
          "FisherMergingAlgorithmForGPT2",
      ],
      "regmean": ["RegMeanAlgorithmForCLIP", "RegMeanAlgorithmForGPT2"],
+     "regmean_plusplus": ["RegMeanAlgorithmForCLIPPlusPlus"],
      "adamerging": [
          "CLIPTaskWiseAdaMergingAlgorithm",
          "CLIPLayerWiseAdaMergingAlgorithm",
@@ -195,6 +196,7 @@ if TYPE_CHECKING:
      )
      from .rankone_moe import CLIPRankOneMoEAlgorithm, RankOneMoEAlgorithm
      from .regmean import RegMeanAlgorithmForCLIP, RegMeanAlgorithmForGPT2
+     from .regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus
      from .simple_average import SimpleAverageAlgorithm
      from .slerp import SlerpMergeAlgorithm
      from .smile_upscaling import (
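
For context on how these two small edits cooperate: fusion_bench.method exposes algorithms through a lazy-import table, so registering "regmean_plusplus" in _import_structure plus the matching TYPE_CHECKING import is all that is needed to make the new classes importable without paying the import cost up front. A minimal sketch of this pattern (PEP 562 module __getattr__; this is illustrative, not fusion_bench's actual machinery):

import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "regmean_plusplus": ["RegMeanAlgorithmForCLIPPlusPlus"],
}

if TYPE_CHECKING:
    # static type checkers see the real import; at runtime it is deferred
    from .regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus


def __getattr__(name):
    # import the owning submodule only when the attribute is first accessed
    for module_name, exported in _import_structure.items():
        if name in exported:
            module = importlib.import_module(f".{module_name}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")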
fusion_bench/method/linear/simple_average_for_llama.py CHANGED
@@ -1,4 +1,5 @@
- from typing import Optional
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Optional

  from typing_extensions import override

@@ -6,6 +7,11 @@ from fusion_bench import timeit_context
  from fusion_bench.method.base_algorithm import BaseAlgorithm
  from fusion_bench.method.simple_average import SimpleAverageAlgorithm
  from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+ from fusion_bench.utils.pylogger import getRankZeroLogger
+ from omegaconf import flag_override
+ from fusion_bench.utils import instantiate
+
+ log = getRankZeroLogger(__name__)


  class SimpleAverageForLlama(BaseAlgorithm):
@@ -26,12 +32,19 @@ class SimpleAverageForLlama(BaseAlgorithm):

      _config_mapping = BaseAlgorithm._config_mapping | {
          "merge_backbone": "merge_backbone",
+         "show_pbar": "show_pbar",
      }

-     def __init__(self, merge_backbone: bool, model_save_path: Optional[str] = None):
+     def __init__(
+         self,
+         merge_backbone: bool,
+         model_save_path: Optional[str] = None,
+         show_pbar: bool = False,
+     ):
          super().__init__()
          self.merge_backbone = merge_backbone
          self.model_save_path = model_save_path
+         self.show_pbar = show_pbar

      @override
      def run(self, modelpool: CausalLMPool):
@@ -40,12 +53,24 @@ class SimpleAverageForLlama(BaseAlgorithm):

          if self.merge_backbone:
              assert modelpool.has_pretrained
-             backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
+             log.info(
+                 "Merging backbone of the model pool, use CausalLMBackbonePool instead of CausalLMPool."
+             )
+             modelpool_config = deepcopy(modelpool.config)
+             with flag_override(modelpool_config, "allow_objects", True):
+                 modelpool_config._target_ = (
+                     "fusion_bench.modelpool.causal_lm.CausalLMBackbonePool"
+                 )
+             backbone_modelpool = instantiate(modelpool_config)
              model = modelpool.load_model("_pretrained_")
-             backbone_model = SimpleAverageAlgorithm().run(backbone_modelpool)
+             backbone_model = SimpleAverageAlgorithm(show_pbar=self.show_pbar).run(
+                 backbone_modelpool
+             )
              model.model.layers = backbone_model
          else:
-             model = SimpleAverageAlgorithm().run()
+             model = SimpleAverageAlgorithm(show_pbar=self.show_pbar).run(
+                 modelpool=modelpool
+             )

          if self.model_save_path is not None:
              with timeit_context(f"Saving the model to {self.model_save_path}"):
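
The behavioral change in this file: instead of constructing CausalLMBackbonePool(**modelpool.config) directly, the algorithm now clones the pool config, rewrites its _target_, and lets instantiate() rebuild the pool, so any extra Hydra-style fields in the config survive the round trip. A minimal sketch of that config rewrite, assuming an OmegaConf DictConfig with a _target_ key (the helper function name here is hypothetical):

from copy import deepcopy

from omegaconf import DictConfig, flag_override


def to_backbone_pool_config(config: DictConfig) -> DictConfig:
    """Return a copy of a CausalLMPool config retargeted at the backbone pool."""
    config = deepcopy(config)
    # allow_objects lets us write a dotted-path string into a struct config
    with flag_override(config, "allow_objects", True):
        config._target_ = "fusion_bench.modelpool.causal_lm.CausalLMBackbonePool"
    return config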
fusion_bench/method/regmean_plusplus/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # flake8: noqa F401
+ from .clip_regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus
+ from .regmean_plusplus import RegMeanAlgorithmPlusPlus
fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py ADDED
@@ -0,0 +1,192 @@
+ import logging
+ from collections import defaultdict
+ from typing import Dict, List, cast  # noqa: F401
+
+ import torch
+ import torch.utils.data
+ from omegaconf import DictConfig
+ from torch import Tensor, nn
+ from torch.nn.modules import Module
+ from torch.utils.data import DataLoader
+ from tqdm.autonotebook import tqdm
+
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
+ from fusion_bench.mixins import CLIPClassificationMixin
+
+ from .regmean_plusplus import RegMeanAlgorithmPlusPlus
+
+ log = logging.getLogger(__name__)
+
+
+ class RegMeanAlgorithmForCLIPPlusPlus(
+     RegMeanAlgorithmPlusPlus,
+     CLIPClassificationMixin,
+ ):
+     _config_mapping = {
+         "_dataloader_kwargs": "dataloader_kwargs",
+     }
+
+     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
+         super().__init__(**kwargs)
+         self._dataloader_kwargs = dataloader_kwargs
+
+     def on_regmean_start(self):
+         self.setup_zero_shot_classification_head()
+
+     def compute_logits(self, module, batch, task: str) -> Tensor:
+         images, _ = batch
+         text_embeds = self.zeroshot_weights[task]
+
+         image_embeds = module(images)[1]
+         image_embeds = self.visual_projection(image_embeds)
+
+         # normalize embeddings
+         image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+         # cosine similarity
+         logits_per_text = (
+             torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+         )
+         logits_per_image = logits_per_text.t()
+
+         return logits_per_image
+
+     def get_regmean_weights(
+         self,
+         model_name: str,
+         layer: Module,
+         batches_input: List[Tensor],
+         linear_modules_to_merge: Dict[str, Module],
+     ):
+         layer = self.fabric.setup(layer)
+
+         def compute_regmean_weights(module_name: str):
+             """
+             compute the regmean weights, a hook function to deal with each module's input
+             :param module_name: str, module name
+             :return:
+             """
+
+             def hook(module: nn.Module, input: tuple, output: torch.Tensor):
+                 # Tensor, shape (batch_size, sequence_length, hidden_dim)
+                 x = cast(Tensor, input[0]).detach()
+                 batch_num_actual_examples = x.shape[0]
+                 # Tensor, shape (batch_size * sequence_length, hidden_dim)
+                 x = x.reshape(-1, x.shape[-1])
+                 # Tensor, shape (hidden_dim, hidden_dim)
+                 xtx = torch.matmul(x.transpose(0, 1), x)
+                 # store the averaged weights in regmean_weights
+                 if module_name not in regmean_weights.keys():
+                     regmean_weights[module_name] = xtx / x.shape[0]
+                     num_computed_examples[module_name] = x.shape[0]
+                     num_actual_examples[module_name] = batch_num_actual_examples
+                 else:
+                     regmean_weights[module_name] = (
+                         regmean_weights[module_name]
+                         * num_computed_examples[module_name]
+                         + xtx
+                     ) / (num_computed_examples[module_name] + x.shape[0])
+                     num_computed_examples[module_name] += x.shape[0]
+                     num_actual_examples[module_name] += batch_num_actual_examples
+
+             return hook
+
+         handles = []
+         # dictionary, regmean matrices for each linear module inputs
+         regmean_weights = {}
+         # dictionary, number of examples (multiplied the sequence length) used for computing regmean matrices
+         num_computed_examples = {}
+         # dictionary, number of actual examples used for computing regmean matrices
+         num_actual_examples = {}
+
+         for module_name, linear_module_to_merge in linear_modules_to_merge.items():
+             # register a hook in the forward process
+             handle = linear_module_to_merge.register_forward_hook(
+                 compute_regmean_weights(module_name=module_name)
+             )
+             handles.append(handle)
+         _ = self.layer_batches_forward(layer, batches_input)
+
+         # remove the added hook
+         for handle in handles:
+             handle.remove()
+
+         for module_name in regmean_weights.keys():
+             regmean_weights[module_name] = regmean_weights[module_name].detach().cpu()
+
+         return regmean_weights
+
+     def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
+         models_to_merge_param_dict = defaultdict(list)
+
+         # get the parameters of the embedding layer from each model
+         for model_to_merge in models_to_merge_dict.values():
+             model_to_merge_state_dict = model_to_merge.state_dict()
+
+             param_dict = {}
+             for name, param in model_to_merge_state_dict.items():
+                 if name.startswith("vision_model.embeddings") or name.startswith(
+                     "vision_model.pre_layrnorm"
+                 ):
+                     param_dict[name] = param
+
+             for param_name in param_dict.keys():
+                 models_to_merge_param_dict[param_name].append(param_dict[param_name])
+
+         # merge the parameters of the embedding layer
+         merged_params_dict = {}
+         for param_name, param_list in models_to_merge_param_dict.items():
+             merged_params_dict[param_name] = torch.stack(param_list).mean(dim=0)
+
+         return merged_params_dict
+
+     def get_input_for_first_layer(self, model: nn.Module, train_dataset):
+         # setup dataloader
+         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
+         train_dataloader = DataLoader(
+             train_dataset, shuffle=True, **self._dataloader_kwargs
+         )
+         train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
+         model = self.fabric.setup(model)
+
+         def compute_input(model, batch):
+             images, _ = batch
+
+             images = images.to(model.device)
+             image_embeds = model.vision_model.embeddings(images)
+             image_embeds = model.vision_model.pre_layrnorm(image_embeds)
+             image_embeds = image_embeds.detach().cpu()
+
+             return image_embeds
+
+         num_computed_examples = 0
+         num_regmean_examples = self.num_regmean_examples
+
+         batches_input = []
+         for batch in train_dataloader:
+             if num_computed_examples >= num_regmean_examples:
+                 break
+             batches_input.append(compute_input(model, batch))
+             num_computed_examples += batch[0].size(0)
+
+         return batches_input
+
+     def get_layers(self, model: nn.Module):
+         return model.vision_model.encoder.layers
+
+     def update_merged_params_dict(
+         self, merged_params_dict, new_merged_params, layer_idx
+     ):
+         for key, value in new_merged_params.items():
+             key = f"vision_model.encoder.layers.{layer_idx}.{key}"
+             merged_params_dict[key] = value
+
+         return merged_params_dict
+
+     def layer_batches_forward(
+         self, layer: nn.Module, batches_input: List[Tensor]
+     ) -> Tensor:
+         batches_output = []
+         for batch in batches_input:
+             device = next(layer.parameters()).device
+             batch = batch.to(device)
+             logits = (
+                 layer(batch, attention_mask=None, causal_attention_mask=None)[0]
+                 .detach()
+                 .cpu()
+             )
+             batches_output.append(logits)
+         return batches_output
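
The heart of get_regmean_weights above is a forward hook that keeps a running average of the input Gram matrix XᵀX for each linear module. A self-contained toy version of that accumulation (plain PyTorch with random data; not the fusion_bench API):

import torch
from torch import nn

layer = nn.Linear(4, 4)
gram = torch.zeros(4, 4)
num_rows = 0


def hook(module, inputs, output):
    global gram, num_rows
    # flatten batch (and any sequence) dims: (N, hidden_dim)
    x = inputs[0].detach().reshape(-1, inputs[0].shape[-1])
    # running average of XᵀX over all rows seen so far
    gram = (gram * num_rows + x.T @ x) / (num_rows + x.shape[0])
    num_rows += x.shape[0]


handle = layer.register_forward_hook(hook)
with torch.no_grad():
    for _ in range(3):
        layer(torch.randn(8, 4))
handle.remove()
print(gram.shape)  # torch.Size([4, 4]) -- the per-module "regmean weights"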
fusion_bench/method/regmean_plusplus/regmean_plusplus.py ADDED
@@ -0,0 +1,365 @@
+ import logging
+ import re
+ from collections import defaultdict
+ from typing import Dict, List, cast
+
+ import torch
+ from torch import Tensor, nn
+ from tqdm.autonotebook import tqdm
+
+ from fusion_bench.method import BaseAlgorithm
+ from fusion_bench.mixins import SimpleProfilerMixin
+ from fusion_bench.modelpool import BaseModelPool
+
+ log = logging.getLogger(__name__)
+
+
+ def get_param_names_to_merge(
+     input_param_names: List[str], exclude_param_names_regex: list
+ ):
+     """
+     get the names of parameters that need to be merged
+     :param input_param_names: list, names of input parameters
+     :param exclude_param_names_regex: list, regular expressions matching names of parameters to exclude
+     :return:
+     """
+     param_names_to_merge = []
+     for param_name in input_param_names:
+         exclude = any(
+             [
+                 re.match(exclude_pattern, param_name)
+                 for exclude_pattern in exclude_param_names_regex
+             ]
+         )
+         if not exclude:
+             param_names_to_merge.append(param_name)
+     return param_names_to_merge
+
+
+ def get_modules_to_merge(model: nn.Module, include_module_types: list):
+     """
+     get the model modules that need to be merged, whose type is in include_module_types
+     :param model: nn.Module, input model
+     :param include_module_types: list, module types to include
+     :return:
+     """
+     modules_to_merge: Dict[str, nn.Module] = {}
+     for module_name, module in model.named_modules():
+         is_valid_type = not include_module_types or any(
+             [
+                 isinstance(module, include_module_type)
+                 for include_module_type in include_module_types
+             ]
+         )
+         if is_valid_type:
+             modules_to_merge[module_name] = module
+     return modules_to_merge
+
+
+ def reduce_non_diagonal_elements(
+     regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
+ ):
+     """
+     reduce the non-diagonal elements in regmean_weights
+     :param regmean_weights: Tensor, shape (hidden_dim, hidden_dim), input regmean weights
+     :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+     :return:
+     """
+     # diagonal matrix with (1 - reduce_non_diagonal_ratio) as elements
+     diag_weights = torch.diag(
+         torch.ones(regmean_weights.shape[0]) - reduce_non_diagonal_ratio
+     ).to(regmean_weights.device)
+     # matrix with reduce_non_diagonal_ratio as elements
+     non_diag_weights = torch.zeros_like(diag_weights).fill_(reduce_non_diagonal_ratio)
+     # diagonal elements are unchanged, while non-diagonal elements are multiplied by reduce_non_diagonal_ratio
+     return regmean_weights * (diag_weights + non_diag_weights)
+
+
+ def regmean_params_merge(
+     param_weight_list: List[Tensor],
+     param_regmean_list: List[Tensor],
+     reduce_non_diagonal_ratio: float = 1.0,
+     weight_transpose: bool = True,
+     module_name: str = "",
+     device="cpu",
+ ):
+     # two lists with length num_models_to_merge
+     param_multiplied_results, module_regmean_weights_list = [], []
+     for model_idx, module_regmean_weights in enumerate(param_regmean_list):
+         # reduce non-diagonal elements
+         module_regmean_weights = reduce_non_diagonal_elements(
+             regmean_weights=module_regmean_weights,
+             reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+         )
+         module_regmean_weights_list.append(module_regmean_weights)
+
+         model_to_merge_param = param_weight_list[model_idx]
+         # since the weight shape of Linear module is (output_size, input_size), we need to transpose it
+         param_multiplied_results.append(
+             torch.matmul(
+                 module_regmean_weights,
+                 (
+                     model_to_merge_param.transpose(0, 1)
+                     if weight_transpose
+                     else model_to_merge_param
+                 ),
+             )
+         )
+
+     # sum up module_regmean_weights and param_multiplied_results over all individual models
+     sum_module_regmean_weights = sum(module_regmean_weights_list)
+     sum_param_multiplied_results = sum(param_multiplied_results)
+
+     # get the inverse matrix
+     inv_sum_module_regmean_weights = torch.inverse(sum_module_regmean_weights)
+     # merge parameters with regmean
+     merged_param = torch.matmul(
+         inv_sum_module_regmean_weights, sum_param_multiplied_results
+     )
+     # transpose to the original shape of "weight" in Linear module
+     merged_param = merged_param.transpose(0, 1) if weight_transpose else merged_param
+
+     return merged_param
+
+
+ def merging_with_regmean_weights(
+     models_to_merge_param_dict: dict,
+     models_to_merge_regmean_weights_list: list,
+     reduce_non_diagonal_ratio: float = 1.0,
+     weight_transpose: bool = True,
+ ):
+     """
+     merge parameters of different models with computed regmean weights
+     :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+         value is a list of the corresponding parameters of all the models that need to be merged
+     :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+         each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+     :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+     :return:
+     """
+     # dict, dictionary of model parameters
+     merged_params = {}
+
+     for param_name, param_value_list in models_to_merge_param_dict.items():
+         merged_by_regmean = False
+         # only perform regmean merging on the "weight" parameter of Linear module
+         if param_name.endswith(".weight"):
+             module_name = param_name[: -len(".weight")]
+             if module_name in models_to_merge_regmean_weights_list[0].keys():
+                 # list with length num_models_to_merge
+                 module_regmean_weights_list = []
+                 for model_idx, model_to_merge_regmean_weights in enumerate(
+                     models_to_merge_regmean_weights_list
+                 ):
+                     device = param_value_list[model_idx].device
+
+                     # Tensor, shape (hidden_dim, hidden_dim)
+                     module_regmean_weights = model_to_merge_regmean_weights[
+                         module_name
+                     ].to(device)
+                     module_regmean_weights_list.append(module_regmean_weights)
+
+                 merged_params[param_name] = regmean_params_merge(
+                     param_weight_list=param_value_list,
+                     param_regmean_list=module_regmean_weights_list,
+                     reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+                     weight_transpose=weight_transpose,
+                     module_name=module_name,
+                     device=device,
+                 )
+
+                 merged_by_regmean = True
+         # use average merging for parameters whose names do not end with ".weight" or are not in a Linear module
+         if not merged_by_regmean:
+             merged_params[param_name] = torch.stack(param_value_list, dim=0).mean(
+                 dim=0
+             )
+
+     return merged_params
+
+
+ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
+     _include_module_type = [nn.Linear]
+     _config_mapping = {
+         "num_regmean_examples": "num_regmean_examples",
+         "exclude_param_names_regex": "exclude_param_names_regex",
+         "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
+         "weight_transpose": "weight_transpose",
+     }
+
+     def __init__(
+         self,
+         *,
+         num_regmean_examples: int,
+         exclude_param_names_regex: list,
+         reduce_non_diagonal_ratio: float,
+         weight_transpose: bool,
+         **kwargs,
+     ):
+         self.num_regmean_examples = num_regmean_examples
+         self.exclude_param_names_regex = exclude_param_names_regex
+         self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
+         self.weight_transpose = weight_transpose
+         super().__init__(**kwargs)
+
+     def run(self, modelpool: BaseModelPool, **kwargs):
+         if not isinstance(modelpool, BaseModelPool):
+             modelpool = BaseModelPool(modelpool)
+         self.modelpool = modelpool
+         device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         models_to_merge_dict = {
+             name: model.to(device) for name, model in modelpool.named_models()
+         }
+         self.on_regmean_start()
+
+         # initialize the merged model as the pretrained model
+         merged_model = modelpool.load_pretrained_model().to(device)
+         merged_params_dict = {}
+
+         # 1. merge embedding layer
+         merged_embedding_dict = self.merge_embedding_layer(
+             models_to_merge_dict=models_to_merge_dict
+         )
+         merged_model.load_state_dict(merged_embedding_dict, strict=False)
+
+         with torch.no_grad():
+             # 1.1. compute input for the first layer
+             with (
+                 self.profile("merging models"),
+                 self.profile("computing first layer input"),
+             ):
+                 batches_input_dict = defaultdict(list)
+                 for name in tqdm(
+                     models_to_merge_dict.keys(),
+                     desc="computing input for first layer",
+                 ):
+                     dataset = modelpool.load_train_dataset(name)
+
+                     batches_input_dict[name] = self.get_input_for_first_layer(
+                         merged_model, dataset
+                     )
+
+             # 2. iteratively merge layer by layer with regmean algorithm
+             backbone_layers = self.get_layers(merged_model)
+             num_layers = len(backbone_layers)
+
+             models_to_merge_layers_dict = defaultdict(list)
+             for name, model in models_to_merge_dict.items():
+                 models_to_merge_layers_dict[name] = self.get_layers(model)
+
+             param_names_to_merge = None
+             for layer_idx, backbone_layer in tqdm(
+                 enumerate(backbone_layers),
+                 desc="merging layers",
+                 total=num_layers,
+             ):
+                 # dictionary of list, where key is the parameter name,
+                 # value is a list of the corresponding parameters of all the models that need to be merged
+                 models_to_merge_param_dict = defaultdict(list)
+
+                 # list of dictionaries with length len(models_to_merge),
+                 # each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged
+                 models_to_merge_regmean_weights_list = []
+
+                 for name, layers_to_merge in models_to_merge_layers_dict.items():
+                     layer_to_merge = layers_to_merge[layer_idx]
+                     param_dict = layer_to_merge.state_dict()
+
+                     # exclude parameters whose names match an element in exclude_param_names_regex
+                     if param_names_to_merge is None:
+                         param_names_to_merge = get_param_names_to_merge(
+                             input_param_names=list(param_dict.keys()),
+                             exclude_param_names_regex=self.config.get(
+                                 "exclude_param_names_regex", []
+                             ),
+                         )
+
+                     for param_name in param_names_to_merge:
+                         models_to_merge_param_dict[param_name].append(
+                             param_dict[param_name]
+                         )
+
+                     linear_modules_to_merge = get_modules_to_merge(
+                         model=layer_to_merge,
+                         include_module_types=self._include_module_type,
+                     )
+                     assert len(linear_modules_to_merge) > 0, "No linear modules to merge"
+
+                     # 2.1. compute regmean weights for each model
+                     with (
+                         self.profile("merging models"),
+                         self.profile("computing regmean weights"),
+                     ):
+                         regmean_weights = self.get_regmean_weights(
+                             name,
+                             layer_to_merge,
+                             batches_input=batches_input_dict[name],
+                             linear_modules_to_merge=linear_modules_to_merge,
+                         )
+
+                         module_subset = get_param_names_to_merge(
+                             input_param_names=list(param_dict.keys()),
+                             exclude_param_names_regex=self.exclude_param_names_regex,
+                         )
+                         module_subset = [
+                             name.replace(".weight", "").replace(".bias", "")
+                             for name in module_subset
+                         ]
+                         module_subset = list(set(module_subset))
+                         regmean_weights = {
+                             module_name: regmean_weights[module_name]
+                             for module_name in module_subset
+                             if module_name in regmean_weights
+                         }
+
+                         models_to_merge_regmean_weights_list.append(regmean_weights)
+
+                 # 2.2. merge parameters with regmean weights
+                 with self.profile("merging models"):
+                     merged_layer_params = merging_with_regmean_weights(
+                         models_to_merge_param_dict=models_to_merge_param_dict,
+                         models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
+                         reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
+                         weight_transpose=self.config.get("weight_transpose", True),
+                     )
+
+                 merged_params_dict = self.update_merged_params_dict(
+                     merged_params_dict=merged_params_dict,
+                     new_merged_params=merged_layer_params,
+                     layer_idx=layer_idx,
+                 )
+
+                 # 2.3. compute input for the next layer
+                 with (
+                     self.profile("merging models"),
+                     self.profile("forwarding next layer"),
+                 ):
+                     if layer_idx < num_layers - 1:
+                         backbone_layer.load_state_dict(merged_layer_params, strict=False)
+                         batches_output_dict = defaultdict(list)
+                         for name in models_to_merge_dict.keys():
+                             batches_output_dict[name] = self.layer_batches_forward(
+                                 backbone_layer, batches_input_dict[name]
+                             )
+                         batches_input_dict = batches_output_dict
+
+         # 3. load state dict to the merged model
+         merged_model.load_state_dict(merged_params_dict, strict=False)
+
+         self.print_profile_summary()
+         return merged_model
+
+     def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
+         """
+         Merge the embedding layer of the model with the merged model.
+         This method should be implemented in subclasses if needed.
+         """
+         raise NotImplementedError()
+
+     def get_input_for_first_layer(self, model: nn.Module, train_dataset):
+         raise NotImplementedError
+
+     def get_layers(self, model: nn.Module):
+         raise NotImplementedError
+
+     def update_merged_params_dict(
+         self, merged_params_dict, new_merged_params, layer_idx
+     ):
+         raise NotImplementedError
+
+     def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]):
+         raise NotImplementedError
+
+     def on_regmean_start(self):
+         pass
+
+     def get_regmean_weights(
+         self,
+         model_name: str,
+         layer: nn.Module,
+         batches_input: List[Tensor],
+         linear_modules_to_merge: Dict[str, nn.Module],
+     ):
+         raise NotImplementedError
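
For reference, the closed-form merge that regmean_params_merge implements is W* = [(Σᵢ Gᵢ)⁻¹ Σᵢ Gᵢ Wᵢᵀ]ᵀ, where Gᵢ is the input Gram matrix XᵢᵀXᵢ of model i and Wᵢ follows the nn.Linear (out_features, in_features) convention; when all Gᵢ are equal the merge reduces to plain weight averaging. A toy numeric check of both facts (random data, and torch.linalg.solve in place of an explicit inverse; this sketch is not fusion_bench code):

import torch

torch.manual_seed(0)
hidden = 4
weights = [torch.randn(hidden, hidden) for _ in range(3)]  # W_i, (out, in)
grams = []
for _ in range(3):
    x = torch.randn(16, hidden)        # model-specific calibration inputs
    grams.append(x.T @ x)              # G_i = X_i^T X_i

lhs = sum(grams)                                         # sum of G_i
rhs = sum(g @ w.T for g, w in zip(grams, weights))       # sum of G_i W_i^T
merged = torch.linalg.solve(lhs, rhs).T                  # W*, shape (out, in)

# with identical Grams, RegMean degenerates to the plain average of weights
same_gram = [grams[0]] * 3
rhs_same = sum(g @ w.T for g, w in zip(same_gram, weights))
avg = torch.stack(weights).mean(dim=0)
assert torch.allclose(
    torch.linalg.solve(sum(same_gram), rhs_same).T, avg, atol=1e-5
)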