fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0

fusion_bench/method/regmean/clip_regmean.py

@@ -27,7 +27,7 @@ class RegMeanAlgorithmForCLIP(
 
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
        super().__init__(**kwargs)
-        self._dataloader_kwargs = dataloader_kwargs
+        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()

@@ -60,7 +60,7 @@ class RegMeanAlgorithmForCLIP(
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
         train_dataloader = DataLoader(
-            train_dataset, shuffle=True, **self._dataloader_kwargs
+            train_dataset, shuffle=True, **self.dataloader_kwargs
         )
         train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
         model = self.fabric.setup(model)
fusion_bench/method/regmean/gpt2_regmean.py

@@ -15,7 +15,7 @@ from transformers import GPT2ForSequenceClassification, GPT2Model
 from transformers.data import default_data_collator
 from transformers.models.gpt2.modeling_gpt2 import Conv1D
 
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, auto_register_config
 from fusion_bench.utils import timeit_context
 
 from .regmean import RegMeanAlgorithm

@@ -23,22 +23,15 @@ from .regmean import RegMeanAlgorithm
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class RegMeanAlgorithmForGPT2(
-    RegMeanAlgorithm,
     LightningFabricMixin,
+    RegMeanAlgorithm,
 ):
     _include_module_type = [Conv1D]
     classifiers = {}
-    _config_mapping = RegMeanAlgorithm._config_mapping | {
-        "cache_dir": "cache_dir",
-        "batch_size": "batch_size",
-        "num_workers": "num_workers",
-    }
 
     def __init__(self, cache_dir: str, batch_size: int, num_workers: int, **kwargs):
-        self.cache_dir = cache_dir
-        self.batch_size = batch_size
-        self.num_workers = num_workers
         super().__init__(**kwargs)
 
     def on_regmean_start(self):
fusion_bench/method/regmean/regmean.py

@@ -13,7 +13,7 @@ from torch import Tensor, nn
 from tqdm.autonotebook import tqdm
 
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 
 log = logging.getLogger(__name__)

@@ -280,14 +280,9 @@ def regmean_merging(
     return merged_params
 
 
+@auto_register_config
 class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     _include_module_type = [nn.Linear]
-    _config_mapping = {
-        "num_regmean_examples": "num_regmean_examples",
-        "exclude_param_names_regex": "exclude_param_names_regex",
-        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
-        "weight_transpose": "weight_transpose",
-    }
 
     def __init__(
         self,

@@ -298,10 +293,6 @@ class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         weight_transpose: bool,
         **kwargs,
     ):
-        self.num_regmean_examples = num_regmean_examples
-        self.exclude_param_names_regex = exclude_param_names_regex
-        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
-        self.weight_transpose = weight_transpose
         super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool, **kwargs):
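
The recurring refactor across the three files above is the new @auto_register_config decorator (imported from fusion_bench.mixins), which replaces the hand-maintained _config_mapping dictionaries and the repetitive self.x = x assignments in __init__. The decorator's actual implementation is not part of this diff excerpt; the following is a minimal sketch, assuming it simply harvests the named __init__ parameters, records them in _config_mapping, and assigns them as instance attributes:

import inspect
from functools import wraps


def auto_register_config(cls):
    """Hypothetical reconstruction of @auto_register_config (assumed behavior,
    not the package's actual code): register every named __init__ parameter in
    _config_mapping and store it on the instance automatically."""
    original_init = cls.__init__
    signature = inspect.signature(original_init)
    param_names = [
        name
        for name, param in signature.parameters.items()
        if name != "self"
        and param.kind
        not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]
    # merge with any mapping inherited from a parent class, mirroring the
    # deleted `RegMeanAlgorithm._config_mapping | {...}` pattern
    cls._config_mapping = {
        **getattr(cls, "_config_mapping", {}),
        **{name: name for name in param_names},
    }

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        # bind arguments to the original signature and set each named
        # parameter as an attribute before the original __init__ runs
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name in param_names:
            if name in bound.arguments:
                setattr(self, name, bound.arguments[name])
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls

Under this assumption, constructing RegMeanAlgorithm(num_regmean_examples=256, exclude_param_names_regex=[], reduce_non_diagonal_ratio=0.9, weight_transpose=True) (hyperparameter values illustrative) would both set the four attributes and record them for config serialization, which is exactly the boilerplate the hunks delete.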
fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py (new file)

@@ -0,0 +1,199 @@
+import logging
+from collections import defaultdict
+from typing import Dict, List, cast  # noqa: F401
+
+import torch
+import torch.utils.data
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.nn.modules import Module
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.mixins import CLIPClassificationMixin
+
+from .regmean_plusplus import RegMeanAlgorithmPlusPlus
+
+log = logging.getLogger(__name__)
+
+
+class RegMeanAlgorithmForCLIPPlusPlus(
+    RegMeanAlgorithmPlusPlus,
+    CLIPClassificationMixin,
+):
+    _config_mapping = {
+        "_dataloader_kwargs": "dataloader_kwargs",
+    }
+
+    def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
+        super().__init__(**kwargs)
+        self.dataloader_kwargs = dataloader_kwargs
+
+    def on_regmean_start(self):
+        self.setup_zero_shot_classification_head()
+
+    def compute_logits(self, module, batch, task: str) -> Tensor:
+        images, _ = batch
+        text_embeds = self.zeroshot_weights[task]
+
+        image_embeds = module(images)[1]
+        image_embeds = self.visual_projection(image_embeds)
+
+        # normalize embeddings
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+
+        # cosine similarity
+        logits_per_text = (
+            torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale_exp
+        )
+        logits_per_image = logits_per_text.t()
+
+        return logits_per_image
+
+    def get_regmean_weights(
+        self,
+        model_name: str,
+        layer: Module,
+        batches_input: List[Tensor],
+        linear_modules_to_merge: Dict[str, Module],
+    ):
+        layer = self.fabric.setup(layer)
+
+        def compute_regmean_weights(module_name: str):
+            """
+            compute the regmean weights, a hook function to deal with each module's input
+            :param module_name: str, module name
+            :return:
+            """
+
+            def hook(module: nn.Module, input: tuple, output: torch.Tensor):
+                # Tensor, shape (batch_size, sequence_length, hidden_dim)
+                x = cast(Tensor, input[0]).detach()
+                batch_num_actual_examples = x.shape[0]
+                # Tensor, shape (batch_size * sequence_length, hidden_dim)
+                x = x.reshape(-1, x.shape[-1])
+                # Tensor, shape (hidden_dim, hidden_dim)
+                xtx = torch.matmul(x.transpose(0, 1), x)
+                # store the averaged weights in regmean_weights
+                if module_name not in regmean_weights.keys():
+                    regmean_weights[module_name] = xtx / x.shape[0]
+                    num_computed_examples[module_name] = x.shape[0]
+                    num_actual_examples[module_name] = batch_num_actual_examples
+                else:
+                    regmean_weights[module_name] = (
+                        regmean_weights[module_name]
+                        * num_computed_examples[module_name]
+                        + xtx
+                    ) / (num_computed_examples[module_name] + x.shape[0])
+                    num_computed_examples[module_name] += x.shape[0]
+                    num_actual_examples[module_name] += batch_num_actual_examples
+
+            return hook
+
+        handles = []
+        # dictionary, regmean matrices for each linear module inputs
+        regmean_weights = {}
+        # dictionary, number of examples (multiplied the sequence length) used for computing regmean matrices
+        num_computed_examples = {}
+        # dictionary, number of actual examples used for computing regmean matrices
+        num_actual_examples = {}
+
+        for module_name, linear_module_to_merge in linear_modules_to_merge.items():
+            # register a hook in the forward process
+            handle = linear_module_to_merge.register_forward_hook(
+                compute_regmean_weights(module_name=module_name)
+            )
+            handles.append(handle)
+        _ = self.layer_batches_forward(layer, batches_input)
+
+        # remove the added hook
+        for handle in handles:
+            handle.remove()
+
+        for module_name in regmean_weights.keys():
+            regmean_weights[module_name] = regmean_weights[module_name].detach().cpu()
+
+        return regmean_weights
+
+    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
+        models_to_merge_param_dict = defaultdict(list)
+
+        # get the parameters of the embedding layer from each model
+        for model_to_merge in models_to_merge_dict.values():
+            model_to_merge_state_dict = model_to_merge.state_dict()
+
+            param_dict = {}
+            for name, param in model_to_merge_state_dict.items():
+                if name.startswith("vision_model.embeddings") or name.startswith(
+                    "vision_model.pre_layrnorm"
+                ):
+                    param_dict[name] = param
+
+            for param_name in param_dict.keys():
+                models_to_merge_param_dict[param_name].append(param_dict[param_name])
+
+        # merge the parameters of the embedding layer
+        merged_params_dict = {}
+        for param_name, param_list in models_to_merge_param_dict.items():
+            merged_params_dict[param_name] = torch.stack(param_list).mean(dim=0)
+
+        return merged_params_dict
+
+    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
+        # setup dataloader
+        train_dataset = CLIPDataset(train_dataset, self.clip_processor)
+        train_dataloader = DataLoader(
+            train_dataset, shuffle=True, **self.dataloader_kwargs
+        )
+        train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
+        model = self.fabric.setup(model)
+
+        def compute_input(model, batch):
+            images, _ = batch
+
+            images = images.to(model.device)
+            image_embeds = model.vision_model.embeddings(images)
+            image_embeds = model.vision_model.pre_layrnorm(image_embeds)
+            image_embeds = image_embeds.detach().cpu()
+
+            return image_embeds
+
+        num_computed_examples = 0
+        num_regmean_examples = self.num_regmean_examples
+
+        batches_input = []
+        for batch in train_dataloader:
+            if num_computed_examples >= num_regmean_examples:
+                break
+            batches_input.append(compute_input(model, batch))
+            num_computed_examples += batch[0].size(0)
+
+        return batches_input
+
+    def get_layers(self, model: nn.Module):
+        return model.vision_model.encoder.layers
+
+    def update_merged_params_dict(
+        self, merged_params_dict, new_merged_params, layer_idx
+    ):
+        for key, value in new_merged_params.items():
+            key = f"vision_model.encoder.layers.{layer_idx}.{key}"
+            merged_params_dict[key] = value
+
+        return merged_params_dict
+
+    def layer_batches_forward(
+        self, layer: nn.Module, batches_input: List[Tensor]
+    ) -> Tensor:
+        batches_output = []
+        for batch in batches_input:
+            device = next(layer.parameters()).device
+            batch = batch.to(device)
+            logits = (
+                layer(batch, attention_mask=None, causal_attention_mask=None)[0]
+                .detach()
+                .cpu()
+            )
+            batches_output.append(logits)
+        return batches_output
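
In get_regmean_weights above, a forward hook on each linear module accumulates a running average of that module's input Gram matrix x^T x over the calibration batches. A self-contained sketch of the same hook pattern, with hypothetical toy shapes (not part of the package):

import torch
from torch import nn

layer = nn.Linear(8, 4)
gram = torch.zeros(8, 8)  # running average of x^T x
seen = 0  # number of row vectors accumulated so far


def hook(module, inputs, output):
    global gram, seen
    # flatten (batch, seq, hidden) -> (batch * seq, hidden), as in the hunk above
    x = inputs[0].detach().reshape(-1, inputs[0].shape[-1])
    gram = (gram * seen + x.T @ x) / (seen + x.shape[0])
    seen += x.shape[0]


handle = layer.register_forward_hook(hook)
for _ in range(3):
    layer(torch.randn(5, 8))
handle.remove()
print(gram.shape)  # torch.Size([8, 8])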
fusion_bench/method/regmean_plusplus/regmean_plusplus.py (new file)

@@ -0,0 +1,383 @@
+import logging
+import re
+from collections import defaultdict
+from typing import Dict, List, cast
+
+import torch
+from torch import Tensor, nn
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.modelpool import BaseModelPool
+
+log = logging.getLogger(__name__)
+
+
+def get_param_names_to_merge(
+    input_param_names: List[str], exclude_param_names_regex: list
+):
+    """
+    get the names of parameters that need to be merged
+    :param input_param_names: list, names of input parameters
+    :param exclude_param_names_regex: list, regular expression of names of parameters that need to be excluded
+    :return:
+    """
+    param_names_to_merge = []
+    for param_name in input_param_names:
+        exclude = any(
+            [
+                re.match(exclude_pattern, param_name)
+                for exclude_pattern in exclude_param_names_regex
+            ]
+        )
+        if not exclude:
+            param_names_to_merge.append(param_name)
+    return param_names_to_merge
+
+
+def get_modules_to_merge(model: nn.Module, include_module_types: list):
+    """
+    get the model modules that need to be merged, whose type is in include_module_types
+    :param model: nn.Module, input model
+    :param include_module_types: list, module types that want to include
+    :return:
+    """
+    modules_to_merge: Dict[str, nn.Module] = {}
+    for module_name, module in model.named_modules():
+        is_valid_type = not include_module_types or any(
+            [
+                isinstance(module, include_module_type)
+                for include_module_type in include_module_types
+            ]
+        )
+        if is_valid_type:
+            modules_to_merge[module_name] = module
+    return modules_to_merge
+
+
+def reduce_non_diagonal_elements(
+    regmean_weights: torch.Tensor, reduce_non_diagonal_ratio: float
+):
+    """
+    reduce the non-diagonal elements in regmean_weights
+    :param regmean_weights: Tensor, shape (hidden_dim, hidden_dim), input regmean weights
+    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+    :return:
+    """
+    # diagonal matrix with (1 - reduce_non_diagonal_ratio) as elements
+    diag_weights = torch.diag(
+        torch.ones(regmean_weights.shape[0]) - reduce_non_diagonal_ratio
+    ).to(regmean_weights.device)
+    # matrix with reduce_non_diagonal_ratio as elements
+    non_diag_weights = torch.zeros_like(diag_weights).fill_(reduce_non_diagonal_ratio)
+    # diagonal elements are unchanged, while non-diagonal elements are multiplied by reduce_non_diagonal_ratio
+    return regmean_weights * (diag_weights + non_diag_weights)
+
+
+def regmean_params_merge(
+    param_weight_list: List[Tensor],
+    param_regmean_list: List[Tensor],
+    reduce_non_diagonal_ratio: float = 1.0,
+    weight_transpose: bool = True,
+    module_name: str = "",
+    device="cpu",
+):
+    # two lists with length num_models_to_merge
+    param_multiplied_results, module_regmean_weights_list = [], []
+    for model_idx, module_regmean_weights in enumerate(param_regmean_list):
+        # reduce non-diagonal elements
+        module_regmean_weights = reduce_non_diagonal_elements(
+            regmean_weights=module_regmean_weights,
+            reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+        )
+        module_regmean_weights_list.append(module_regmean_weights)
+
+        model_to_merge_param = param_weight_list[model_idx]
+        # since the weight shape of Linear module is (output_size, input_size), we need to transpose it
+        param_multiplied_results.append(
+            torch.matmul(
+                module_regmean_weights,
+                (
+                    model_to_merge_param.transpose(0, 1)
+                    if weight_transpose
+                    else model_to_merge_param
+                ),
+            )
+        )
+
+    # sum up module_regmean_weights and param_multiplied_results over all individual models
+    sum_module_regmean_weights = sum(module_regmean_weights_list)
+    sum_param_multiplied_results = sum(param_multiplied_results)
+
+    # get the inverse matrix
+    inv_sum_module_regmean_weights = torch.inverse(sum_module_regmean_weights)
+    # merge parameters with regmean
+    merged_param = torch.matmul(
+        inv_sum_module_regmean_weights, sum_param_multiplied_results
+    )
+    # transpose to the original shape of "weight" in Linear module
+    merged_param = merged_param.transpose(0, 1) if weight_transpose else merged_param
+
+    return merged_param
+
+
+def merging_with_regmean_weights(
+    models_to_merge_param_dict: dict,
+    models_to_merge_regmean_weights_list: list,
+    reduce_non_diagonal_ratio: float = 1.0,
+    weight_transpose: bool = True,
+):
+    """
+    merge parameters of different models with computed regmean weights
+    :param models_to_merge_param_dict: dict, dictionary of list, where key is the parameter name,
+        value is a list of the corresponding parameters of all the models that need to be merged
+    :param models_to_merge_regmean_weights_list: list, list of dictionaries with length len(models_to_merge),
+        each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged, key is module name
+    :param reduce_non_diagonal_ratio: float, reduce non-diagonal elements in regmean weights by multiplying this scalar
+    :return:
+    """
+    # dict, dictionary of model parameters
+    merged_params = {}
+
+    for param_name, param_value_list in models_to_merge_param_dict.items():
+        merged_by_regmean = False
+        # only perform regmean merging on the "weight" parameter of Linear module
+        if param_name.endswith(".weight"):
+            module_name = param_name[: -len(".weight")]
+            if module_name in models_to_merge_regmean_weights_list[0].keys():
+                # two lists with length num_models_to_merge
+                module_regmean_weights_list = []
+                for model_idx, model_to_merge_regmean_weights in enumerate(
+                    models_to_merge_regmean_weights_list
+                ):
+                    device = param_value_list[model_idx].device
+
+                    # Tensor, shape (hidden_dim, hidden_dim)
+                    module_regmean_weights = model_to_merge_regmean_weights[
+                        module_name
+                    ].to(device)
+                    module_regmean_weights_list.append(module_regmean_weights)
+
+                merged_params[param_name] = regmean_params_merge(
+                    param_weight_list=param_value_list,
+                    param_regmean_list=module_regmean_weights_list,
+                    reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+                    weight_transpose=weight_transpose,
+                    module_name=module_name,
+                    device=device,
+                )
+
+                merged_by_regmean = True
+        # use average merging for parameters whose names are not end with ".weight" or not in Linear module
+        if not merged_by_regmean:
+            merged_params[param_name] = torch.stack(param_value_list, dim=0).mean(dim=0)
+
+    return merged_params
+
+
+class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
+    _include_module_type = [nn.Linear]
+    _config_mapping = {
+        "num_regmean_examples": "num_regmean_examples",
+        "exclude_param_names_regex": "exclude_param_names_regex",
+        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
+        "weight_transpose": "weight_transpose",
+    }
+
+    def __init__(
+        self,
+        *,
+        num_regmean_examples: int,
+        exclude_param_names_regex: list,
+        reduce_non_diagonal_ratio: float,
+        weight_transpose: bool,
+        **kwargs,
+    ):
+        self.num_regmean_examples = num_regmean_examples
+        self.exclude_param_names_regex = exclude_param_names_regex
+        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
+        self.weight_transpose = weight_transpose
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool, **kwargs):
+        if not isinstance(modelpool, BaseModelPool):
+            modelpool = BaseModelPool(modelpool)
+        self.modelpool = modelpool
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        models_to_merge_dict = {
+            name: model.to(device) for name, model in modelpool.named_models()
+        }
+        self.on_regmean_start()
+
+        # initialize the merged models as the pretrained model
+        merged_model = modelpool.load_pretrained_model().to(device)
+        merged_params_dict = {}
+
+        # 1. merge embedding layer
+        merged_embedding_dict = self.merge_embedding_layer(
+            models_to_merge_dict=models_to_merge_dict
+        )
+        merged_model.load_state_dict(merged_embedding_dict, strict=False)
+
+        with torch.no_grad():
+            # 1.1. compute input for the first layer
+            with (
+                self.profile("merging models"),
+                self.profile("computing first layer input"),
+            ):
+                batches_input_dict = defaultdict(list)
+                for name in tqdm(
+                    models_to_merge_dict.keys(), desc="computing input for first layer"
+                ):
+                    dataset = modelpool.load_train_dataset(name)
+
+                    batches_input_dict[name] = self.get_input_for_first_layer(
+                        merged_model, dataset
+                    )
+
+            # 2. iteratively merge layer by layer with regmean algorithm
+            backbone_layers = self.get_layers(merged_model)
+            num_layers = len(backbone_layers)
+
+            models_to_merge_layers_dict = defaultdict(list)
+            for name, model in models_to_merge_dict.items():
+                models_to_merge_layers_dict[name] = self.get_layers(model)
+
+            param_names_to_merge = None
+            for layer_idx, backbone_layer in tqdm(
+                enumerate(backbone_layers), desc="merging layers", total=num_layers
+            ):
+                # dictionary of list, where key is the parameter name,
+                # value is a list of the corresponding parameters of all the models that need to be merged
+                models_to_merge_param_dict = defaultdict(list)
+
+                # list of dictionaries with length len(models_to_merge),
+                # each dictionary records the regmean weights (matrix) of parameters for each model that needs to be merged
+                models_to_merge_regmean_weights_list = []
+
+                for name, layers_to_merge in models_to_merge_layers_dict.items():
+                    layer_to_merge = layers_to_merge[layer_idx]
+                    param_dict = layer_to_merge.state_dict()
+
+                    # exclude parameter whose name matches element in exclude_param_names_regex
+                    if param_names_to_merge is None:
+                        param_names_to_merge = get_param_names_to_merge(
+                            input_param_names=list(param_dict.keys()),
+                            exclude_param_names_regex=self.config.get(
+                                "exclude_param_names_regex", []
+                            ),
+                        )
+
+                    for param_name in param_names_to_merge:
+                        models_to_merge_param_dict[param_name].append(
+                            param_dict[param_name]
+                        )
+
+                    linear_modules_to_merge = get_modules_to_merge(
+                        model=layer_to_merge,
+                        include_module_types=self._include_module_type,
+                    )
+                    assert (
+                        len(linear_modules_to_merge) > 0
+                    ), "No linear modules to merge"
+
+                    # 2.1. compute regmean weights for each model
+                    with (
+                        self.profile("merging models"),
+                        self.profile("computing regmean weights"),
+                    ):
+                        regmean_weights = self.get_regmean_weights(
+                            name,
+                            layer_to_merge,
+                            batches_input=batches_input_dict[name],
+                            linear_modules_to_merge=linear_modules_to_merge,
+                        )
+
+                        module_subset = get_param_names_to_merge(
+                            input_param_names=list(param_dict.keys()),
+                            exclude_param_names_regex=self.exclude_param_names_regex,
+                        )
+                        module_subset = [
+                            name.replace(".weight", "").replace(".bias", "")
+                            for name in module_subset
+                        ]
+                        module_subset = list(set(module_subset))
+                        regmean_weights = {
+                            module_name: regmean_weights[module_name]
+                            for module_name in module_subset
+                            if module_name in regmean_weights
+                        }
+
+                        models_to_merge_regmean_weights_list.append(regmean_weights)
+
+                # 2.2. merge parameters with regmean weights
+                with self.profile("merging models"):
+                    # merging with regmean weights
+                    merged_layer_params = merging_with_regmean_weights(
+                        models_to_merge_param_dict=models_to_merge_param_dict,
+                        models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
+                        reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
+                        weight_transpose=self.config.get("weight_transpose", True),
+                    )
+
+                    merged_params_dict = self.update_merged_params_dict(
+                        merged_params_dict=merged_params_dict,
+                        new_merged_params=merged_layer_params,
+                        layer_idx=layer_idx,
+                    )
+
+                # 2.3. compute input for the next layer
+                with (
+                    self.profile("merging models"),
+                    self.profile("forwarding next layer"),
+                ):
+                    if layer_idx < num_layers - 1:
+                        backbone_layer.load_state_dict(
+                            merged_layer_params, strict=False
+                        )
+                        batches_output_dict = defaultdict(list)
+                        for name in models_to_merge_dict.keys():
+                            batches_output_dict[name] = self.layer_batches_forward(
+                                backbone_layer, batches_input_dict[name]
+                            )
+                        batches_input_dict = batches_output_dict
+
+        # 3. load state dict to the merged model
+        merged_model.load_state_dict(merged_params_dict, strict=False)
+
+        self.print_profile_summary()
+        return merged_model
+
+    def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
+        """
+        Merge the embedding layer of the model with the merged model.
+        This method should be implemented in subclasses if needed.
+        """
+        raise NotImplementedError()
+
+    def get_input_for_first_layer(self, model: nn.Module, train_dataset):
+        raise NotImplementedError
+
+    def get_layers(self, model: nn.Module):
+        raise NotImplementedError
+
+    def update_merged_params_dict(
+        self, merged_params_dict, new_merged_params, layer_idx
+    ):
+        raise NotImplementedError
+
+    def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]):
+        raise NotImplementedError
+
+    def on_regmean_start(self):
+        pass
+
+    def get_regmean_weights(
+        self,
+        model_name: str,
+        layer: nn.Module,
+        batches_input: List[Tensor],
+        linear_modules_to_merge: Dict[str, nn.Module],
+    ):
+        raise NotImplementedError