fusion-bench 0.2.19__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/linear/simple_average_for_llama.py +14 -3
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +192 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +365 -0
- fusion_bench/method/simple_average.py +18 -2
- fusion_bench/modelpool/clip_vision/modelpool.py +45 -12
- fusion_bench/scripts/cli.py +1 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +146 -0
- fusion_bench/utils/state_dict_arithmetic.py +10 -5
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/METADATA +9 -1
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/RECORD +44 -39
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/top_level.txt +0 -0
fusion_bench/method/__init__.py
CHANGED
```diff
@@ -41,6 +41,7 @@ _import_structure = {
         "FisherMergingAlgorithmForGPT2",
     ],
     "regmean": ["RegMeanAlgorithmForCLIP", "RegMeanAlgorithmForGPT2"],
+    "regmean_plusplus": ["RegMeanAlgorithmForCLIPPlusPlus"],
     "adamerging": [
         "CLIPTaskWiseAdaMergingAlgorithm",
         "CLIPLayerWiseAdaMergingAlgorithm",
@@ -195,6 +196,7 @@ if TYPE_CHECKING:
     )
     from .rankone_moe import CLIPRankOneMoEAlgorithm, RankOneMoEAlgorithm
     from .regmean import RegMeanAlgorithmForCLIP, RegMeanAlgorithmForGPT2
+    from .regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus
     from .simple_average import SimpleAverageAlgorithm
     from .slerp import SlerpMergeAlgorithm
     from .smile_upscaling import (
```
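The new algorithm is registered in two places: in `_import_structure`, which drives lazy imports at runtime, and under `TYPE_CHECKING`, so static analyzers still resolve the symbol eagerly. The sketch below shows roughly how such a lazy package `__init__.py` resolves attributes on first access; it is illustrative only, and fusion_bench's actual lazy-import machinery may differ.

```python
# Minimal sketch of a lazy package __init__.py (PEP 562 style), illustrating
# why the new entry appears both in _import_structure and under TYPE_CHECKING.
# Names here mirror the diff above but are not fusion_bench's real helpers.
import importlib
from typing import TYPE_CHECKING

_import_structure = {
    "regmean": ["RegMeanAlgorithmForCLIP", "RegMeanAlgorithmForGPT2"],
    "regmean_plusplus": ["RegMeanAlgorithmForCLIPPlusPlus"],
}

if TYPE_CHECKING:
    # type checkers see the real import; this branch never runs at runtime
    from .regmean_plusplus import RegMeanAlgorithmForCLIPPlusPlus  # noqa: F401


def __getattr__(name: str):
    # resolve the requested symbol to its submodule on first access
    for module_name, symbols in _import_structure.items():
        if name in symbols:
            module = importlib.import_module(f".{module_name}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```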
fusion_bench/method/linear/simple_average_for_llama.py
CHANGED
```diff
@@ -32,12 +32,19 @@ class SimpleAverageForLlama(BaseAlgorithm):
 
     _config_mapping = BaseAlgorithm._config_mapping | {
         "merge_backbone": "merge_backbone",
+        "show_pbar": "show_pbar",
     }
 
-    def __init__(
+    def __init__(
+        self,
+        merge_backbone: bool,
+        model_save_path: Optional[str] = None,
+        show_pbar: bool = False,
+    ):
         super().__init__()
         self.merge_backbone = merge_backbone
         self.model_save_path = model_save_path
+        self.show_pbar = show_pbar
 
     @override
     def run(self, modelpool: CausalLMPool):
@@ -56,10 +63,14 @@
             )
             backbone_modelpool = instantiate(modelpool_config)
             model = modelpool.load_model("_pretrained_")
-            backbone_model = SimpleAverageAlgorithm().run(
+            backbone_model = SimpleAverageAlgorithm(show_pbar=self.show_pbar).run(
+                backbone_modelpool
+            )
             model.model.layers = backbone_model
         else:
-            model = SimpleAverageAlgorithm().run(
+            model = SimpleAverageAlgorithm(show_pbar=self.show_pbar).run(
+                modelpool=modelpool
+            )
 
         if self.model_save_path is not None:
             with timeit_context(f"Saving the model to {self.model_save_path}"):
```
fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py
ADDED
```diff
@@ -0,0 +1,192 @@
+import logging
+from collections import defaultdict
+from typing import Dict, List, cast  # noqa: F401
+
+import torch
+import torch.utils.data
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.nn.modules import Module
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.mixins import CLIPClassificationMixin
+
+from .regmean_plusplus import RegMeanAlgorithmPlusPlus
+
+log = logging.getLogger(__name__)
+
+
+class RegMeanAlgorithmForCLIPPlusPlus(
+    RegMeanAlgorithmPlusPlus,
+    CLIPClassificationMixin,
+):
+    _config_mapping = {
+        "_dataloader_kwargs": "dataloader_kwargs",
+    }
+
+    def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
+        super().__init__(**kwargs)
+        self._dataloader_kwargs = dataloader_kwargs
+
+    def on_regmean_start(self):
+        self.setup_zero_shot_classification_head()
+
+    def compute_logits(self, module, batch, task: str) -> Tensor:
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image
+
+    def get_regmean_weights(
+        self,
+        model_name: str,
+        layer: Module,
+        batches_input: List[Tensor],
+        linear_modules_to_merge: Dict[str, Module],
+    ):
+        layer = self.fabric.setup(layer)
+
+        def compute_regmean_weights(module_name: str):
+            """
+            compute the regmean weights, a hook function to deal with each module's input
+            :param module_name: str, module name
+            :return:
+            """
+
+            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
+                # Tensor, shape (batch_size, sequence_length, hidden_dim)
+                x = cast(Tensor, input[0]).detach()
+                batch_num_actual_examples = x.shape[0]
+                # Tensor, shape (batch_size * sequence_length, hidden_dim)
+                x = x.reshape(-1, x.shape[-1])
+                # Tensor, shape (hidden_dim, hidden_dim)
+                xtx = torch.matmul(x.transpose(0, 1), x)
+                # store the averaged weights in regmean_weights
+                if module_name not in regmean_weights.keys():
+                    regmean_weights[module_name] = xtx / x.shape[0]
+                    num_computed_examples[module_name] = x.shape[0]
+                    num_actual_examples[module_name] = batch_num_actual_examples
+                else:
+                    regmean_weights[module_name] = (
+                        regmean_weights[module_name]
+                        * num_computed_examples[module_name]
+                        + xtx
+                    ) / (num_computed_examples[module_name] + x.shape[0])
+                    num_computed_examples[module_name] += x.shape[0]
+                    num_actual_examples[module_name] += batch_num_actual_examples
+
+            return hook
+
+        handles = []
+        # dictionary, regmean matrices for each linear module inputs
+        regmean_weights = {}
+        # dictionary, number of examples (multiplied the sequence length) used for computing regmean matrices
+        num_computed_examples = {}
+        # dictionary, number of actual examples used for computing regmean matrices
+        num_actual_examples = {}
+
+        for module_name, linear_module_to_merge in linear_modules_to_merge.items():
+            # register a hook in the forward process
+            handle = linear_module_to_merge.register_forward_hook(
+                compute_regmean_weights(module_name=module_name)
+            )
+            handles.append(handle)
+        _ = self.layer_batches_forward(layer, batches_input)
+
+        # remove the added hook
+        for handle in handles:
+            handle.remove()
+
+        for module_name in regmean_weights.keys():
+            regmean_weights[module_name] = regmean_weights[module_name].detach().cpu()
+
+        return regmean_weights
+
+    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
+        models_to_merge_param_dict = defaultdict(list)
+
+        # get the parameters of the embedding layer from each model
+        for model_to_merge in models_to_merge_dict.values():
+            model_to_merge_state_dict = model_to_merge.state_dict()
+
+            param_dict = {}
+            for name, param in model_to_merge_state_dict.items():
+                if name.startswith("vision_model.embeddings") or name.startswith("vision_model.pre_layrnorm"):
+                    param_dict[name] = param
+
+            for param_name in param_dict.keys():
+                models_to_merge_param_dict[param_name].append(
+                    param_dict[param_name]
+                )
+
+        # merge the parameters of the embedding layer
+        merged_params_dict = {}
+        for param_name, param_list in models_to_merge_param_dict.items():
+            merged_params_dict[param_name] = torch.stack(param_list).mean(dim=0)
+
+        return merged_params_dict
+
+
+    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
+        # setup dataloader
+        train_dataset = CLIPDataset(train_dataset, self.clip_processor)
+        train_dataloader = DataLoader(
+            train_dataset, shuffle=True, **self._dataloader_kwargs
+        )
+        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
+        model = self.fabric.setup(model)
+
+        def compute_input(model, batch):
+            images, _ = batch
+
+            images = images.to(model.device)
+            image_embeds = model.vision_model.embeddings(images)
+            image_embeds = model.vision_model.pre_layrnorm(image_embeds)
+            image_embeds = image_embeds.detach().cpu()
+
+            return image_embeds
+
+        num_computed_examples = 0
+        num_regmean_examples = self.num_regmean_examples
+
+        batches_input = []
+        for batch in train_dataloader:
+            if num_computed_examples >= num_regmean_examples:
+                break
+            batches_input.append(compute_input(model, batch))
+            num_computed_examples += batch[0].size(0)
+
+        return batches_input
+
+    def get_layers(self, model: nn.Module):
+        return model.vision_model.encoder.layers
+
+    def update_merged_params_dict(self, merged_params_dict, new_merged_params, layer_idx):
+        for key, value in new_merged_params.items():
+            key = f"vision_model.encoder.layers.{layer_idx}.{key}"
+            merged_params_dict[key] = value
+
+        return merged_params_dict
+
+    def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]) -> Tensor:
+        batches_output = []
+        for batch in batches_input:
+            device = next(layer.parameters()).device
+            batch = batch.to(device)
+            logits = layer(batch, attention_mask=None, causal_attention_mask=None)[0].detach().cpu()
+            batches_output.append(logits)
+        return batches_output
```
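The heart of `get_regmean_weights` above is a forward hook that accumulates the Gram matrix X^T X of each linear module's inputs, averaged over the number of rows seen so far. A self-contained toy version of that accumulation (standalone illustration with made-up names, not fusion_bench API) looks like this:

```python
# Toy version of the Gram-matrix accumulation performed by the forward hooks
# in get_regmean_weights; `gram` and `num_examples` are illustrative names.
import torch
from torch import nn

layer = nn.Linear(4, 3)
gram = {}          # module name -> running average of X^T X
num_examples = {}  # module name -> number of input rows seen so far


def make_hook(name):
    def hook(module, inputs, output):
        # flatten any leading batch/sequence dims into rows of the input matrix
        x = inputs[0].detach().reshape(-1, inputs[0].shape[-1])
        xtx = x.T @ x
        if name not in gram:
            gram[name] = xtx / x.shape[0]
            num_examples[name] = x.shape[0]
        else:
            n = num_examples[name]
            gram[name] = (gram[name] * n + xtx) / (n + x.shape[0])
            num_examples[name] += x.shape[0]
    return hook


handle = layer.register_forward_hook(make_hook("fc"))
for _ in range(3):
    layer(torch.randn(8, 4))  # each forward pass updates the running Gram matrix
handle.remove()
print(gram["fc"].shape)  # torch.Size([4, 4])
```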
fusion_bench/method/regmean_plusplus/regmean_plusplus.py
ADDED
```diff
@@ -0,0 +1,365 @@
+import logging
+import re
+from collections import defaultdict
+from typing import Dict, List, cast
+
+import torch
+from torch import Tensor, nn
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.modelpool import BaseModelPool
+
+log = logging.getLogger(__name__)
+
+
+def get_param_names_to_merge(
+    input_param_names: List[str], exclude_param_names_regex: list
+):
+    """
+    get the names of parameters that need to be merged
+    :param input_param_names: list, names of input parameters
+    :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
+    :return:
+    """
+    param_names_to_merge = []
+    for param_name in input_param_names:
+        exclude = any(
+            [
+                re.match(exclude_pattern, param_name)
+                for exclude_pattern in exclude_param_names_regex
+            ]
+        )
+        if not exclude:
+            param_names_to_merge.append(param_name)
+    return param_names_to_merge
+
+
+def get_modules_to_merge(model: nn.Module, include_module_types: list):
+    """
+    get the model modules that need to be merged, whose type is in include_module_types
+    :param model: nn.Module, input model
+    :param include_module_types: list, module types that want to include
+    :return:
+    """
+    modules_to_merge: Dict[str, nn.Module] = {}
+    for module_name, module in model.named_modules():
+        is_valid_type = not include_module_types or any(
+            [
+                isinstance(module, include_module_type)
+                for include_module_type in include_module_types
+            ]
+        )
+        if is_valid_type:
+            modules_to_merge[module_name] = module
+    return modules_to_merge
+
+
+def reduce_non_diagonal_elements(
+    regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
+):
+    """
+    reduce the non-diagonal elements in regmean_weights
+    :param regmean_weights: Tensor, shape (hidden_dim, hidden_dim), input regmean weights
+    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+    :return:
+    """
+    # diagonal matrix with (1 - reduce_non_diagonal_ratio) as elements
+    diag_weights = torch.diag(
+        torch.ones(regmean_weights.shape[0]) - reduce_non_diagonal_ratio
+    ).to(regmean_weights.device)
+    # matrix with reduce_non_diagonal_ratio as elements
+    non_diag_weights = torch.zeros_like(diag_weights).fill_(reduce_non_diagonal_ratio)
+    # diagonal elements are unchanged, while non-diagonal elements are multiplied by reduce_non_diagonal_ratio
+    return regmean_weights * (diag_weights + non_diag_weights)
+
+
+def regmean_params_merge(
+    param_weight_list: List[Tensor],
+    param_regmean_list: List[Tensor],
+    reduce_non_diagonal_ratio: float = 1.0,
+    weight_transpose: bool = True,
+    module_name: str = "",
+    device = "cpu"
+):
+    # two lists with length num_models_to_merge
+    param_multiplied_results, module_regmean_weights_list = [], []
+    for model_idx, module_regmean_weights in enumerate(
+        param_regmean_list
+    ):
+        # reduce non-diagonal elements
+        module_regmean_weights = reduce_non_diagonal_elements(
+            regmean_weights=module_regmean_weights,
+            reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+        )
+        module_regmean_weights_list.append(module_regmean_weights)
+
+        model_to_merge_param = param_weight_list[model_idx]
+        # since the weight shape of Linear module is (output_size, input_size), we need to transpose it
+        param_multiplied_results.append(
+            torch.matmul(
+                module_regmean_weights,
+                (
+                    model_to_merge_param.transpose(0, 1)
+                    if weight_transpose
+                    else model_to_merge_param
+                ),
+            )
+        )
+
+    # sum up module_regmean_weights and param_multiplied_results over all individual models
+    sum_module_regmean_weights = sum(module_regmean_weights_list)
+    sum_param_multiplied_results = sum(param_multiplied_results)
+
+    # get the inverse matrix
+    inv_sum_module_regmean_weights = torch.inverse(
+        sum_module_regmean_weights
+    )
+    # merge parameters with regmean
+    merged_param = torch.matmul(
+        inv_sum_module_regmean_weights, sum_param_multiplied_results
+    )
+    # transpose to the original shape of "weight" in Linear module
+    merged_param = merged_param.transpose(0, 1) if weight_transpose else merged_param
+
+    return merged_param
+
+
+def merging_with_regmean_weights(
+    models_to_merge_param_dict: dict,
+    models_to_merge_regmean_weights_list: list,
+    reduce_non_diagonal_ratio: float = 1.0,
+    weight_transpose: bool = True,
+):
+    """
+    merge parameters of different models with computed regmean weights
+    :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+    value is a list of the corresponding parameters of all the models that need to be merged
+    :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+    each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+    :return:
+    """
+    # dict, dictionary of model parameters
+    merged_params = {}
+
+    for param_name, param_value_list in models_to_merge_param_dict.items():
+        merged_by_regmean = False
+        # only perform regmean merging on the "weight" parameter of Linear module
+        if param_name.endswith(".weight"):
+            module_name = param_name[: -len(".weight")]
+            if module_name in models_to_merge_regmean_weights_list[0].keys():
+                # two lists with length num_models_to_merge
+                module_regmean_weights_list = []
+                for model_idx, model_to_merge_regmean_weights in enumerate(
+                    models_to_merge_regmean_weights_list
+                ):
+                    device = param_value_list[model_idx].device
+
+                    # Tensor, shape (hidden_dim, hidden_dim)
+                    module_regmean_weights = model_to_merge_regmean_weights[module_name].to(device)
+                    module_regmean_weights_list.append(module_regmean_weights)
+
+                merged_params[param_name] = regmean_params_merge(param_weight_list=param_value_list,
+                                                                 param_regmean_list=module_regmean_weights_list,
+                                                                 reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+                                                                 weight_transpose=weight_transpose,
+                                                                 module_name=module_name,
+                                                                 device=device)
+
+                merged_by_regmean = True
+        # use average merging for parameters whose names are not end with ".weight" or not in Linear module
+        if not merged_by_regmean:
+            merged_params[param_name] = torch.stack(param_value_list, dim=0).mean(dim=0)
+
+    return merged_params
+
+
+class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
+    _include_module_type = [nn.Linear]
+    _config_mapping = {
+        "num_regmean_examples": "num_regmean_examples",
+        "exclude_param_names_regex": "exclude_param_names_regex",
+        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
+        "weight_transpose": "weight_transpose",
+    }
+
+    def __init__(
+        self,
+        *,
+        num_regmean_examples: int,
+        exclude_param_names_regex: list,
+        reduce_non_diagonal_ratio: float,
+        weight_transpose: bool,
+        **kwargs,
+    ):
+        self.num_regmean_examples = num_regmean_examples
+        self.exclude_param_names_regex = exclude_param_names_regex
+        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
+        self.weight_transpose = weight_transpose
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool, **kwargs):
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(modelpool)
+        self.modelpool = modelpool
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        models_to_merge_dict = {name: model.to(device) for name, model in modelpool.named_models()}
+        self.on_regmean_start()
+
+        # initialize the merged models as the pretrained model
+        merged_model = modelpool.load_pretrained_model().to(device)
+        merged_params_dict = {}
+
+        # 1. merge embedding layer
+        merged_embedding_dict = self.merge_embedding_layer(models_to_merge_dict=models_to_merge_dict)
+        merged_model.load_state_dict(merged_embedding_dict, strict=False)
+
+        with torch.no_grad():
+            # 1.1. compute input for the first layer
+            with (
+                self.profile("merging models"),
+                self.profile("computing first layer input"),
+            ):
+                batches_input_dict = defaultdict(list)
+                for name in tqdm(models_to_merge_dict.keys(), desc="computing input for first layer"):
+                    dataset = modelpool.load_train_dataset(name)
+
+                    batches_input_dict[name] = self.get_input_for_first_layer(
+                        merged_model,
+                        dataset
+                    )
+
+            # 2. iteratively merge layer by layer with regmean algorithm
+            backbone_layers = self.get_layers(merged_model)
+            num_layers = len(backbone_layers)
+
+            models_to_merge_layers_dict = defaultdict(list)
+            for name, model in models_to_merge_dict.items():
+                models_to_merge_layers_dict[name] = self.get_layers(model)
+
+            param_names_to_merge = None
+            for layer_idx, backbone_layer in tqdm(enumerate(backbone_layers),
+                                                  desc="merging layers",
+                                                  total=num_layers):
+                # dictionary of list, where key is the parameter name,
+                # value is a list of the corresponding parameters of all the models that need to be merged
+                models_to_merge_param_dict = defaultdict(list)
+
+                # list of dictionaries with length len(models_to_merge),
+                # each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged
+                models_to_merge_regmean_weights_list = []
+
+                for name, layers_to_merge in models_to_merge_layers_dict.items():
+                    layer_to_merge = layers_to_merge[layer_idx]
+                    param_dict = layer_to_merge.state_dict()
+
+                    # exclude parameter whose name matches element in exclude_param_names_regex
+                    if param_names_to_merge is None:
+                        param_names_to_merge = get_param_names_to_merge(
+                            input_param_names=list(param_dict.keys()),
+                            exclude_param_names_regex=self.config.get(
+                                "exclude_param_names_regex", []
+                            ),
+                        )
+
+                    for param_name in param_names_to_merge:
+                        models_to_merge_param_dict[param_name].append(
+                            param_dict[param_name]
+                        )
+
+                    linear_modules_to_merge = get_modules_to_merge(
+                        model=layer_to_merge, include_module_types=self._include_module_type
+                    )
+                    assert len(linear_modules_to_merge) > 0, "No linear modules to merge"
+
+                    # 2.1. compute regmean weights for each model
+                    with (
+                        self.profile("merging models"),
+                        self.profile("computing regmean weights"),
+                    ):
+                        regmean_weights = self.get_regmean_weights(
+                            name,
+                            layer_to_merge,
+                            batches_input=batches_input_dict[name],
+                            linear_modules_to_merge=linear_modules_to_merge,
+                        )
+
+                        module_subset = get_param_names_to_merge(
+                            input_param_names=list(param_dict.keys()),
+                            exclude_param_names_regex=self.exclude_param_names_regex
+                        )
+                        module_subset = [name.replace(".weight", "").replace(".bias", "") for name in module_subset]
+                        module_subset = list(set(module_subset))
+                        regmean_weights = {module_name: regmean_weights[module_name] for module_name in module_subset if module_name in regmean_weights}
+
+                        models_to_merge_regmean_weights_list.append(regmean_weights)
+
+                # 2.2. merge parameters with regmean weights
+                with self.profile("merging models"):
+                    # merging with regmean weights
+                    merged_layer_params = merging_with_regmean_weights(
+                        models_to_merge_param_dict=models_to_merge_param_dict,
+                        models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
+                        reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
+                        weight_transpose=self.config.get("weight_transpose", True),
+                    )
+
+                    merged_params_dict = self.update_merged_params_dict(
+                        merged_params_dict=merged_params_dict,
+                        new_merged_params=merged_layer_params,
+                        layer_idx=layer_idx,
+                    )
+
+                # 2.3. compute input for the next layer
+                with (
+                    self.profile("merging models"),
+                    self.profile("forwarding next layer"),
+                ):
+                    if layer_idx < num_layers - 1:
+                        backbone_layer.load_state_dict(merged_layer_params, strict=False)
+                        batches_output_dict = defaultdict(list)
+                        for name in models_to_merge_dict.keys():
+                            batches_output_dict[name] = self.layer_batches_forward(
+                                backbone_layer,
+                                batches_input_dict[name]
+                            )
+                        batches_input_dict = batches_output_dict
+
+        # 3. load state dict to the merged model
+        merged_model.load_state_dict(merged_params_dict, strict=False)
+
+        self.print_profile_summary()
+        return merged_model
+
+    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
+        """
+        Merge the embedding layer of the model with the merged model.
+        This method should be implemented in subclasses if needed.
+        """
+        raise NotImplementedError()
+
+    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
+        raise NotImplementedError
+
+    def get_layers(self, model: nn.Module):
+        raise NotImplementedError
+
+    def update_merged_params_dict(self, merged_params_dict, new_merged_params, layer_idx):
+        raise NotImplementedError
+
+    def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]):
+        raise NotImplementedError
+
+    def on_regmean_start(self):
+        pass
+
+    def get_regmean_weights(
+        self,
+        model_name: str,
+        layer: nn.Module,
+        batches_input: List[Tensor],
+        linear_modules_to_merge: Dict[str, nn.Module],
+    ):
+        raise NotImplementedError
```
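For reference, `regmean_params_merge` above applies the closed-form RegMean solution: with per-model input Gram matrices G_i = X_i^T X_i and Linear weights W_i stored in the usual (out, in) layout, the merged weight satisfies W^T = (Σ_i G_i)^{-1} Σ_i G_i W_i^T. A minimal torch-only sketch with two toy models (illustrative values, not fusion_bench code):

```python
# Toy illustration of the closed-form merge used by regmean_params_merge:
# W_merged^T = (G1 + G2)^{-1} (G1 W1^T + G2 W2^T), with G_i = X_i^T X_i.
import torch

torch.manual_seed(0)
hidden, out = 4, 3
w1, w2 = torch.randn(out, hidden), torch.randn(out, hidden)  # Linear weights, (out, in)
x1, x2 = torch.randn(16, hidden), torch.randn(16, hidden)    # per-task layer inputs
g1, g2 = x1.T @ x1, x2.T @ x2                                 # input Gram matrices

merged_t = torch.inverse(g1 + g2) @ (g1 @ w1.T + g2 @ w2.T)
merged = merged_t.T  # transpose back to the (out, in) Linear layout

# The merged weight matches each task's activations in a least-squares sense;
# a plain average would instead be (w1 + w2) / 2.
print(merged.shape)  # torch.Size([3, 4])
```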
fusion_bench/method/simple_average.py
CHANGED
```diff
@@ -63,6 +63,18 @@ class SimpleAverageAlgorithm(
     BaseAlgorithm,
     SimpleProfilerMixin,
 ):
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "show_pbar": "show_pbar",
+    }
+
+    def __init__(self, show_pbar: bool = False):
+        """
+        Args:
+            show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
+        """
+        super().__init__()
+        self.show_pbar = show_pbar
+
     @torch.no_grad()
     def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
         """
@@ -100,10 +112,14 @@ class SimpleAverageAlgorithm(
                 forward_model = model
             else:
                 # Add the current model's state dictionary to the accumulated state dictionary
-                sd = state_dict_add(
+                sd = state_dict_add(
+                    sd, model.state_dict(keep_vars=True), show_pbar=self.show_pbar
+                )
         with self.profile("merge weights"):
             # Divide the accumulated state dictionary by the number of models to get the average
-            sd = state_dict_div(
+            sd = state_dict_div(
+                sd, len(modelpool.model_names), show_pbar=self.show_pbar
+            )
 
         if isinstance(forward_model, LazyStateDict):
             # if the model is a LazyStateDict, convert it to an empty module
```