fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- fusion_bench/__init__.py +22 -2
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +6 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +24 -5
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +5 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +17 -13
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +12 -16
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +15 -45
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +275 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +7 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +160 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +46 -61
- fusion_bench/scripts/cli.py +38 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +7 -1
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/pwe_moe/clip_pwe_moe.py

```diff
@@ -16,14 +16,18 @@ from transformers import CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPEncoderLayer
 from typing_extensions import override
 
-from fusion_bench
+from fusion_bench import (
+    BaseAlgorithm,
+    auto_register_config,
+    print_parameters,
+    timeit_context,
+)
+from fusion_bench.dataset import CLIPDataset
 from fusion_bench.method.task_arithmetic import task_arithmetic_merge
 from fusion_bench.mixins.clip_classification import CLIPClassificationMixin
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import CLIPVisionModelPool
-from fusion_bench.utils import timeit_context
 from fusion_bench.utils.data import InfiniteDataLoader
-from fusion_bench.utils.parameters import print_parameters
 
 from .module import ParetoWeightEnsemblingModule
 from .utils import generate_simplex_grid
@@ -31,27 +35,13 @@ from .utils import generate_simplex_grid
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class PWEMoEAlgorithmForCLIP(
     BaseAlgorithm,
     SimpleProfilerMixin,
     CLIPClassificationMixin,
 ):
     modelpool: CLIPVisionModelPool = None
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "upscale_mlp": "upscale_mlp",
-        "upscale_attn": "upscale_attn",
-        "init_lambda": "init_lambda",
-        "router_hidden_layers": "router_hidden_layers",
-        "lr": "lr",
-        "num_steps": "num_steps",
-        "save_interval": "save_interval",
-        "alpha": "alpha",
-        "checkpoint_path": "checkpoint_path",
-        "eval_grid": "eval_grid",
-        "eval_grid_n": "eval_grid_n",
-        "eval_grid_m": "eval_grid_m",
-        "_dataloader_kwargs": "dataloader_kwargs",
-    }
 
     def __init__(
         self,
@@ -72,19 +62,6 @@ class PWEMoEAlgorithmForCLIP(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.upscale_mlp = upscale_mlp
-        self.upscale_attn = upscale_attn
-        self.init_lambda = init_lambda
-        self.router_hidden_layers = router_hidden_layers
-        self.lr = lr
-        self.num_steps = num_steps
-        self.save_interval = save_interval
-        self.alpha = alpha
-        self.checkpoint_path = checkpoint_path
-        self.eval_grid = eval_grid
-        self.eval_grid_n = eval_grid_n
-        self.eval_grid_m = eval_grid_m
-        self._dataloader_kwargs = dataloader_kwargs
 
     @override
     def run(self, modelpool: CLIPVisionModelPool):
@@ -193,13 +170,14 @@ class PWEMoEAlgorithmForCLIP(
         Loads the datasets specified in the configuration.
         """
         train_datasets = {
-            dataset_name:
-                dataset_name,
+            dataset_name: CLIPDataset(
+                self.modelpool.load_train_dataset(dataset_name),
+                processor=self.clip_processor,
             )
             for dataset_name in self.modelpool.model_names
         }
         train_loaders = {
-            dataset_name: DataLoader(dataset, shuffle=True, **self.
+            dataset_name: DataLoader(dataset, shuffle=True, **self.dataloader_kwargs)
             for dataset_name, dataset in train_datasets.items()
         }
         train_loaders = {
```
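A recurring change in this release is visible above: hand-written `_config_mapping` dictionaries and the matching `self.x = x` assignments in `__init__` are dropped in favor of the `@auto_register_config` decorator now exported from `fusion_bench` and `fusion_bench.mixins`. The sketch below only illustrates how such a decorator can work, by inspecting the `__init__` signature and registering each explicit parameter as both a config entry and an instance attribute; the `_sketch` suffix and all internals are assumptions, not fusion-bench's actual implementation.

```python
# Illustrative sketch only; fusion_bench's real `auto_register_config` may differ.
# It shows how a class decorator can replace the removed `_config_mapping = {...}`
# and `self.x = x` boilerplate by inspecting the __init__ signature.
import functools
import inspect


def auto_register_config_sketch(cls):
    """Register every explicit __init__ parameter as a config field and attribute."""
    original_init = cls.__init__
    signature = inspect.signature(original_init)
    params = [
        name
        for name, p in signature.parameters.items()
        if name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    # extend the parent's mapping, mirroring `BaseAlgorithm._config_mapping | {...}`
    cls._config_mapping = {
        **getattr(cls, "_config_mapping", {}),
        **{name: name for name in params},
    }

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name in params:
            setattr(self, name, bound.arguments.get(name))  # replaces `self.x = x`
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls
```

With a decorator of this shape, `PWEMoEAlgorithmForCLIP.__init__` only needs to declare its keyword arguments and call `super().__init__(**kwargs)`, which matches the slimmer constructors in the diffs below.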
fusion_bench/method/randes/modelsoup.py

```diff
@@ -5,9 +5,7 @@ import torch
 
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.parameters import count_parameters
-from fusion_bench.utils.state_dict_arithmetic import (
-    state_dict_mul,
-)
+from fusion_bench.utils.state_dict_arithmetic import state_dict_mul
 
 from .base_algorithm import SuperposedAlgorithmBase, compare_models
 
```
fusion_bench/method/regmean/clip_regmean.py

```diff
@@ -27,7 +27,7 @@ class RegMeanAlgorithmForCLIP(
 
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self.
+        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()
@@ -60,7 +60,7 @@ class RegMeanAlgorithmForCLIP(
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
         train_dataloader = DataLoader(
-            train_dataset, shuffle=True, **self.
+            train_dataset, shuffle=True, **self.dataloader_kwargs
         )
         train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
         model = self.fabric.setup(model)
```
fusion_bench/method/regmean/gpt2_regmean.py

```diff
@@ -15,7 +15,7 @@ from transformers import GPT2ForSequenceClassification, GPT2Model
 from transformers.data import default_data_collator
 from transformers.models.gpt2.modeling_gpt2 import Conv1D
 
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, auto_register_config
 from fusion_bench.utils import timeit_context
 
 from .regmean import RegMeanAlgorithm
@@ -23,22 +23,15 @@ from .regmean import RegMeanAlgorithm
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class RegMeanAlgorithmForGPT2(
-    RegMeanAlgorithm,
     LightningFabricMixin,
+    RegMeanAlgorithm,
 ):
     _include_module_type = [Conv1D]
     classifiers = {}
-    _config_mapping = RegMeanAlgorithm._config_mapping | {
-        "cache_dir": "cache_dir",
-        "batch_size": "batch_size",
-        "num_workers": "num_workers",
-    }
 
     def __init__(self, cache_dir: str, batch_size: int, num_workers: int, **kwargs):
-        self.cache_dir = cache_dir
-        self.batch_size = batch_size
-        self.num_workers = num_workers
         super().__init__(**kwargs)
 
     def on_regmean_start(self):
```
fusion_bench/method/regmean/regmean.py

```diff
@@ -13,7 +13,7 @@ from torch import Tensor, nn
 from tqdm.autonotebook import tqdm
 
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 
 log = logging.getLogger(__name__)
@@ -280,14 +280,9 @@ def regmean_merging(
     return merged_params
 
 
+@auto_register_config
 class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     _include_module_type = [nn.Linear]
-    _config_mapping = {
-        "num_regmean_examples": "num_regmean_examples",
-        "exclude_param_names_regex": "exclude_param_names_regex",
-        "reduce_non_diagonal_ratio": "reduce_non_diagonal_ratio",
-        "weight_transpose": "weight_transpose",
-    }
 
     def __init__(
         self,
@@ -298,10 +293,6 @@ class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         weight_transpose: bool,
         **kwargs,
     ):
-        self.num_regmean_examples = num_regmean_examples
-        self.exclude_param_names_regex = exclude_param_names_regex
-        self.reduce_non_diagonal_ratio = reduce_non_diagonal_ratio
-        self.weight_transpose = weight_transpose
         super().__init__(**kwargs)
 
     def run(self, modelpool: BaseModelPool, **kwargs):
```
fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py

```diff
@@ -28,7 +28,7 @@ class RegMeanAlgorithmForCLIPPlusPlus(
 
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self.
+        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()
@@ -125,27 +125,26 @@ class RegMeanAlgorithmForCLIPPlusPlus(
 
             param_dict = {}
             for name, param in model_to_merge_state_dict.items():
-                if name.startswith("vision_model.embeddings") or name.startswith(
+                if name.startswith("vision_model.embeddings") or name.startswith(
+                    "vision_model.pre_layrnorm"
+                ):
                     param_dict[name] = param
 
             for param_name in param_dict.keys():
-                models_to_merge_param_dict[param_name].append(
-                    param_dict[param_name]
-                )
+                models_to_merge_param_dict[param_name].append(param_dict[param_name])
 
         # merge the parameters of the embedding layer
         merged_params_dict = {}
         for param_name, param_list in models_to_merge_param_dict.items():
             merged_params_dict[param_name] = torch.stack(param_list).mean(dim=0)
-
+
         return merged_params_dict
-
-
+
     def get_input_for_first_layer(self, model: nn.Module, train_dataset):
         # setup dataloader
         train_dataset = CLIPDataset(train_dataset, self.clip_processor)
         train_dataloader = DataLoader(
-            train_dataset, shuffle=True, **self.
+            train_dataset, shuffle=True, **self.dataloader_kwargs
        )
         train_dataloader = self.fabric.setup_dataloaders(train_dataloader)
         model = self.fabric.setup(model)
@@ -157,9 +156,9 @@ class RegMeanAlgorithmForCLIPPlusPlus(
             image_embeds = model.vision_model.embeddings(images)
             image_embeds = model.vision_model.pre_layrnorm(image_embeds)
             image_embeds = image_embeds.detach().cpu()
-
+
             return image_embeds
-
+
         num_computed_examples = 0
         num_regmean_examples = self.num_regmean_examples
 
@@ -169,24 +168,32 @@ class RegMeanAlgorithmForCLIPPlusPlus(
                 break
             batches_input.append(compute_input(model, batch))
             num_computed_examples += batch[0].size(0)
-
+
         return batches_input
 
     def get_layers(self, model: nn.Module):
         return model.vision_model.encoder.layers
-
-    def update_merged_params_dict(
+
+    def update_merged_params_dict(
+        self, merged_params_dict, new_merged_params, layer_idx
+    ):
         for key, value in new_merged_params.items():
             key = f"vision_model.encoder.layers.{layer_idx}.{key}"
             merged_params_dict[key] = value
 
         return merged_params_dict
-
-    def layer_batches_forward(
+
+    def layer_batches_forward(
+        self, layer: nn.Module, batches_input: List[Tensor]
+    ) -> Tensor:
         batches_output = []
         for batch in batches_input:
             device = next(layer.parameters()).device
             batch = batch.to(device)
-            logits =
+            logits = (
+                layer(batch, attention_mask=None, causal_attention_mask=None)[0]
+                .detach()
+                .cpu()
+            )
             batches_output.append(logits)
         return batches_output
```
fusion_bench/method/regmean_plusplus/regmean_plusplus.py

```diff
@@ -81,13 +81,11 @@ def regmean_params_merge(
     reduce_non_diagonal_ratio: float = 1.0,
     weight_transpose: bool = True,
     module_name: str = "",
-    device
+    device="cpu",
 ):
     # two lists with length num_models_to_merge
     param_multiplied_results, module_regmean_weights_list = [], []
-    for model_idx, module_regmean_weights in enumerate(
-        param_regmean_list
-    ):
+    for model_idx, module_regmean_weights in enumerate(param_regmean_list):
         # reduce non-diagonal elements
         module_regmean_weights = reduce_non_diagonal_elements(
             regmean_weights=module_regmean_weights,
@@ -113,9 +111,7 @@ def regmean_params_merge(
     sum_param_multiplied_results = sum(param_multiplied_results)
 
     # get the inverse matrix
-    inv_sum_module_regmean_weights = torch.inverse(
-        sum_module_regmean_weights
-    )
+    inv_sum_module_regmean_weights = torch.inverse(sum_module_regmean_weights)
     # merge parameters with regmean
     merged_param = torch.matmul(
         inv_sum_module_regmean_weights, sum_param_multiplied_results
@@ -158,15 +154,19 @@ def merging_with_regmean_weights(
             device = param_value_list[model_idx].device
 
             # Tensor, shape (hidden_dim, hidden_dim)
-            module_regmean_weights = model_to_merge_regmean_weights[
+            module_regmean_weights = model_to_merge_regmean_weights[
+                module_name
+            ].to(device)
             module_regmean_weights_list.append(module_regmean_weights)
 
-        merged_params[param_name] = regmean_params_merge(
-
-
-
-
-
+        merged_params[param_name] = regmean_params_merge(
+            param_weight_list=param_value_list,
+            param_regmean_list=module_regmean_weights_list,
+            reduce_non_diagonal_ratio=reduce_non_diagonal_ratio,
+            weight_transpose=weight_transpose,
+            module_name=module_name,
+            device=device,
+        )
 
         merged_by_regmean = True
         # use average merging for parameters whose names are not end with ".weight" or not in Linear module
@@ -205,7 +205,9 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
             modelpool = BaseModelPool(modelpool)
         self.modelpool = modelpool
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
-        models_to_merge_dict = {
+        models_to_merge_dict = {
+            name: model.to(device) for name, model in modelpool.named_models()
+        }
         self.on_regmean_start()
 
         # initialize the merged models as the pretrained model
@@ -213,7 +215,9 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
         merged_params_dict = {}
 
         # 1. merge embedding layer
-        merged_embedding_dict = self.merge_embedding_layer(
+        merged_embedding_dict = self.merge_embedding_layer(
+            models_to_merge_dict=models_to_merge_dict
+        )
         merged_model.load_state_dict(merged_embedding_dict, strict=False)
 
         with torch.no_grad():
@@ -223,12 +227,13 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
                 self.profile("computing first layer input"),
             ):
                 batches_input_dict = defaultdict(list)
-                for name in tqdm(
+                for name in tqdm(
+                    models_to_merge_dict.keys(), desc="computing input for first layer"
+                ):
                     dataset = modelpool.load_train_dataset(name)
-
+
                     batches_input_dict[name] = self.get_input_for_first_layer(
-                        merged_model,
-                        dataset
+                        merged_model, dataset
                     )
 
         # 2. iteratively merge layer by layer with regmean algorithm
@@ -240,9 +245,9 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
             models_to_merge_layers_dict[name] = self.get_layers(model)
 
         param_names_to_merge = None
-        for layer_idx, backbone_layer in tqdm(
-
-
+        for layer_idx, backbone_layer in tqdm(
+            enumerate(backbone_layers), desc="merging layers", total=num_layers
+        ):
             # dictionary of list, where key is the parameter name,
             # value is a list of the corresponding parameters of all the models that need to be merged
             models_to_merge_param_dict = defaultdict(list)
@@ -263,16 +268,19 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
                         "exclude_param_names_regex", []
                     ),
                 )
-
+
            for param_name in param_names_to_merge:
                 models_to_merge_param_dict[param_name].append(
                     param_dict[param_name]
                 )
 
             linear_modules_to_merge = get_modules_to_merge(
-                model=layer_to_merge,
+                model=layer_to_merge,
+                include_module_types=self._include_module_type,
             )
-            assert
+            assert (
+                len(linear_modules_to_merge) > 0
+            ), "No linear modules to merge"
 
             # 2.1. compute regmean weights for each model
             with (
@@ -288,12 +296,19 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
 
                 module_subset = get_param_names_to_merge(
                     input_param_names=list(param_dict.keys()),
-                    exclude_param_names_regex=self.exclude_param_names_regex
+                    exclude_param_names_regex=self.exclude_param_names_regex,
                 )
-                module_subset = [
+                module_subset = [
+                    name.replace(".weight", "").replace(".bias", "")
+                    for name in module_subset
+                ]
                 module_subset = list(set(module_subset))
-                regmean_weights = {
-
+                regmean_weights = {
+                    module_name: regmean_weights[module_name]
+                    for module_name in module_subset
+                    if module_name in regmean_weights
+                }
+
                 models_to_merge_regmean_weights_list.append(regmean_weights)
 
                 # 2.2. merge parameters with regmean weights
@@ -318,21 +333,22 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
                 self.profile("forwarding next layer"),
             ):
                 if layer_idx < num_layers - 1:
-                    backbone_layer.load_state_dict(
+                    backbone_layer.load_state_dict(
+                        merged_layer_params, strict=False
+                    )
                     batches_output_dict = defaultdict(list)
                     for name in models_to_merge_dict.keys():
                         batches_output_dict[name] = self.layer_batches_forward(
-                            backbone_layer,
-                            batches_input_dict[name]
+                            backbone_layer, batches_input_dict[name]
                         )
                     batches_input_dict = batches_output_dict
-
+
         # 3. load state dict to the merged model
         merged_model.load_state_dict(merged_params_dict, strict=False)
 
         self.print_profile_summary()
         return merged_model
-
+
     def merge_embedding_layer(self, models_to_merge_dict: Dict[str, nn.Module]):
         """
         Merge the embedding layer of the model with the merged model.
@@ -345,10 +361,12 @@ class RegMeanAlgorithmPlusPlus(BaseAlgorithm, SimpleProfilerMixin):
 
     def get_layers(self, model: nn.Module):
         raise NotImplementedError
-
-    def update_merged_params_dict(
+
+    def update_merged_params_dict(
+        self, merged_params_dict, new_merged_params, layer_idx
+    ):
         raise NotImplementedError
-
+
     def layer_batches_forward(self, layer: nn.Module, batches_input: List[Tensor]):
         raise NotImplementedError
 
```
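The `regmean_params_merge` changes above keep the core RegMean computation intact: each linear layer is merged as W* = (Σᵢ Gᵢ)⁻¹ Σᵢ Gᵢ Wᵢ, where Gᵢ is the Gram matrix of layer inputs for model i, optionally with its off-diagonal entries shrunk (`reduce_non_diagonal_ratio`). The toy sketch below reproduces that closed form; the function and variable names are illustrative, and the shrinkage step is one plausible reading of `reduce_non_diagonal_elements`, not the exact fusion-bench code.

```python
# Toy numerical sketch of the RegMean-style merge performed by
# `regmean_params_merge`: merged_W = inv(sum_i G_i) @ sum_i (G_i @ W_i),
# with G_i = X_i^T X_i computed from calibration inputs of model i.
import torch


def regmean_merge_sketch(weights, grams, reduce_non_diagonal_ratio=1.0):
    """weights: list of (in_dim, out_dim) tensors; grams: list of (in_dim, in_dim) tensors."""
    scaled_grams = []
    for g in grams:
        # shrink off-diagonal entries toward zero (one plausible form of
        # the diff's `reduce_non_diagonal_elements`)
        diag = torch.diag(torch.diagonal(g))
        scaled_grams.append(diag + reduce_non_diagonal_ratio * (g - diag))
    sum_gram = sum(scaled_grams)
    sum_gram_times_w = sum(g @ w for g, w in zip(scaled_grams, weights))
    # equivalent to torch.inverse(sum_gram) @ sum_gram_times_w, as in the diff
    return torch.linalg.solve(sum_gram, sum_gram_times_w)


# two toy 4x3 linear layers with random calibration inputs
torch.manual_seed(0)
w1, w2 = torch.randn(4, 3), torch.randn(4, 3)
x1, x2 = torch.randn(16, 4), torch.randn(16, 4)
merged = regmean_merge_sketch([w1, w2], [x1.T @ x1, x2.T @ x2])
print(merged.shape)  # torch.Size([4, 3])
```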
fusion_bench/method/simple_average.py

```diff
@@ -6,7 +6,7 @@ import torch
 from torch import nn
 
 from fusion_bench.method.base_algorithm import BaseAlgorithm
-from fusion_bench.mixins
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
@@ -59,24 +59,20 @@ def simple_average(
     return state_dict_avg(modules)
 
 
+@auto_register_config
 class SimpleAverageAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
-
-        "show_pbar": "show_pbar",
-    }
-
-    def __init__(self, show_pbar: bool = False):
+    def __init__(self, show_pbar: bool = False, **kwargs):
         """
         Args:
             show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
         """
-        super().__init__()
-        self.show_pbar = show_pbar
+        super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
         """
         Fuse the models in the given model pool using simple averaging.
 
@@ -124,13 +120,13 @@ class SimpleAverageAlgorithm(
         if isinstance(forward_model, LazyStateDict):
             # if the model is a LazyStateDict, convert it to an empty module
             forward_model = forward_model.meta_module.to_empty(
-                device=
-                "cpu"
-                if forward_model._torch_dtype is None
-                else forward_model._torch_dtype
-            )
+                device=forward_model._device
             )
-        forward_model.load_state_dict(sd)
+        result = forward_model.load_state_dict(sd, strict=False)
+        if result.unexpected_keys:
+            raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
+        if result.missing_keys:
+            log.warning(f"Missing keys in state dict: {result.missing_keys}")
         # print profile report and log the merged models
         self.print_profile_summary()
         log.info(f"merged {len(merged_model_names)} models:")
```
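`SimpleAverageAlgorithm` builds on `simple_average` / `state_dict_avg`, which is plain entry-wise averaging of state dicts. A minimal sketch of that operation follows; the helper name and signature are illustrative, not the fusion-bench API.

```python
# Minimal sketch of entry-wise state-dict averaging, the operation behind
# `state_dict_avg` / SimpleAverageAlgorithm. Helper name is illustrative.
from typing import Dict, List

import torch
from torch import Tensor


def average_state_dicts(state_dicts: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
    keys = state_dicts[0].keys()
    return {
        key: torch.stack([sd[key].float() for sd in state_dicts]).mean(dim=0)
        for key in keys
    }


# toy usage: average two identically shaped "models"
a = {"w": torch.ones(2, 2), "b": torch.zeros(2)}
b = {"w": torch.zeros(2, 2), "b": torch.ones(2)}
print(average_state_dicts([a, b])["w"])  # every entry equals 0.5
```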
fusion_bench/method/slerp/slerp.py

```diff
@@ -1,10 +1,13 @@
 import logging
+from typing import Any, Dict
 
 import torch
+from torch import nn
 from typing_extensions import override
 
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils.type import StateDictType
 
 from .slerp_utils import slerp
 
@@ -18,7 +21,7 @@ def slerp_on_state_dicts(
     *,
     DOT_THRESHOLD: float = 0.9995,
     epsilon: float = 1e-8,
-):
+) -> StateDictType:
     """
     Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
 
@@ -72,7 +75,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
         super().__init__()
 
     @override
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
         Run the SlerpMergeAlgorithm on the given model pool.
 
```
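`slerp_on_state_dicts` applies spherical linear interpolation tensor by tensor, with `DOT_THRESHOLD` deciding when nearly parallel tensors fall back to plain linear interpolation and `epsilon` guarding the normalization. The sketch below follows the common slerp recipe under those two parameters; the actual implementation lives in `slerp_utils.slerp` within the same package and may differ in details, so the function name and body here are assumptions.

```python
# Sketch of the per-tensor slerp rule behind `slerp_on_state_dicts`: interpolate
# along the great circle between flattened tensors, falling back to linear
# interpolation when they are nearly parallel (|cos| > DOT_THRESHOLD).
import torch


def slerp_sketch(
    t: float,
    a: torch.Tensor,
    b: torch.Tensor,
    DOT_THRESHOLD: float = 0.9995,
    epsilon: float = 1e-8,
) -> torch.Tensor:
    a_flat, b_flat = a.flatten().float(), b.flatten().float()
    a_unit = a_flat / (a_flat.norm() + epsilon)
    b_unit = b_flat / (b_flat.norm() + epsilon)
    dot = torch.clamp(torch.dot(a_unit, b_unit), -1.0, 1.0)
    if dot.abs() > DOT_THRESHOLD:
        # nearly parallel: plain linear interpolation
        out = (1.0 - t) * a_flat + t * b_flat
    else:
        theta = torch.arccos(dot)  # angle between the two flattened tensors
        sin_theta = torch.sin(theta)
        out = (
            torch.sin((1.0 - t) * theta) / sin_theta * a_flat
            + torch.sin(t * theta) / sin_theta * b_flat
        )
    return out.reshape(a.shape).to(a.dtype)


merged = slerp_sketch(0.5, torch.randn(4, 4), torch.randn(4, 4))
```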