fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +3 -1
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +102 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- fusion_bench/method/DOGE_TA/__init__.py +0 -2
- /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
|
@@ -13,6 +13,7 @@ from torch import Tensor, nn
|
|
|
13
13
|
from tqdm.autonotebook import tqdm
|
|
14
14
|
|
|
15
15
|
from fusion_bench.method import BaseAlgorithm
|
|
16
|
+
from fusion_bench.mixins import SimpleProfilerMixin
|
|
16
17
|
from fusion_bench.modelpool import BaseModelPool
|
|
17
18
|
|
|
18
19
|
log = logging.getLogger(__name__)
|
|
@@ -279,7 +280,7 @@ def regmean_merging(
|
|
|
279
280
|
return merged_params
|
|
280
281
|
|
|
281
282
|
|
|
282
|
-
class RegMeanAlgorithm(BaseAlgorithm):
|
|
283
|
+
class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
|
|
283
284
|
_include_module_type = [nn.Linear]
|
|
284
285
|
_config_mapping = {
|
|
285
286
|
"num_regmean_examples": "num_regmean_examples",
|
|
@@ -342,24 +343,31 @@ class RegMeanAlgorithm(BaseAlgorithm):
|
|
|
342
343
|
)
|
|
343
344
|
assert len(linear_modules_to_merge) > 0, "No linear modules to merge"
|
|
344
345
|
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
346
|
+
with (
|
|
347
|
+
self.profile("merging models"),
|
|
348
|
+
self.profile("computing regmean weights"),
|
|
349
|
+
):
|
|
350
|
+
regmean_weights = self.get_regmean_weights(
|
|
351
|
+
name,
|
|
352
|
+
model,
|
|
353
|
+
train_dataset=modelpool.load_train_dataset(name),
|
|
354
|
+
linear_modules_to_merge=linear_modules_to_merge,
|
|
355
|
+
)
|
|
356
|
+
models_to_merge_regmean_weights_list.append(regmean_weights)
|
|
357
|
+
|
|
358
|
+
with self.profile("merging models"):
|
|
359
|
+
# merging with regmean weights
|
|
360
|
+
merged_params = merging_with_regmean_weights(
|
|
361
|
+
models_to_merge_param_dict=models_to_merge_param_dict,
|
|
362
|
+
models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
|
|
363
|
+
reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
|
|
364
|
+
weight_transpose=self.config.get("weight_transpose", True),
|
|
365
|
+
)
|
|
352
366
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
models_to_merge_param_dict=models_to_merge_param_dict,
|
|
356
|
-
models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
|
|
357
|
-
reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
|
|
358
|
-
weight_transpose=self.config.get("weight_transpose", True),
|
|
359
|
-
)
|
|
367
|
+
merged_model = modelpool.load_model("_pretrained_")
|
|
368
|
+
merged_model.load_state_dict(merged_params, strict=False)
|
|
360
369
|
|
|
361
|
-
|
|
362
|
-
merged_model.load_state_dict(merged_params, strict=False)
|
|
370
|
+
self.print_profile_summary()
|
|
363
371
|
return merged_model
|
|
364
372
|
|
|
365
373
|
def on_regmean_start(self):
|
|
@@ -442,16 +442,19 @@ class SmileUpscalingAlgorithm(
|
|
|
442
442
|
print_parameters(model)
|
|
443
443
|
return model
|
|
444
444
|
|
|
445
|
-
with self.profile("
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
445
|
+
with self.profile("loading model"):
|
|
446
|
+
# load models and move to GPU if available
|
|
447
|
+
with self.profile("load pretrained model"):
|
|
448
|
+
pretrained_model = modelpool.load_model("_pretrained_")
|
|
449
|
+
with self.profile("load fine-tuned model"):
|
|
450
|
+
finetuned_models = [
|
|
451
|
+
m
|
|
452
|
+
for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
|
|
453
|
+
]
|
|
454
|
+
|
|
455
|
+
if self.config.device == "cuda" and torch.cuda.is_available():
|
|
456
|
+
pretrained_model = pretrained_model.cuda()
|
|
457
|
+
finetuned_models = [m.cuda() for m in finetuned_models]
|
|
455
458
|
|
|
456
459
|
with self.profile("merge model"):
|
|
457
460
|
model = self.merge(pretrained_model, finetuned_models)
|
|
@@ -85,7 +85,14 @@ class CLIPLayerWiseAdaMergingSurgeryAlgorithm(
|
|
|
85
85
|
|
|
86
86
|
if self.config.weights is not None:
|
|
87
87
|
# skip the test-time adaptation
|
|
88
|
+
merge_weight: torch.Tensor = torch.load(self.config.weights)
|
|
89
|
+
module.merge_weight.data = merge_weight.to(
|
|
90
|
+
device=module.merge_weight.device
|
|
91
|
+
)
|
|
88
92
|
merged_model = copy.deepcopy(module.merge_and_unload())
|
|
93
|
+
# setup the zero-shot classification head
|
|
94
|
+
self.on_test_time_adaptation_start()
|
|
95
|
+
|
|
89
96
|
else:
|
|
90
97
|
with self.profile("test-time adaptation"):
|
|
91
98
|
module = self.test_time_adaptation(module)
|
|
@@ -6,7 +6,7 @@ http://arxiv.org/abs/2212.04089
|
|
|
6
6
|
|
|
7
7
|
import logging
|
|
8
8
|
from copy import deepcopy
|
|
9
|
-
from typing import Dict, List, Mapping, TypeVar, Union # noqa: F401
|
|
9
|
+
from typing import Dict, List, Mapping, Optional, TypeVar, Union # noqa: F401
|
|
10
10
|
|
|
11
11
|
import torch
|
|
12
12
|
from torch import nn
|
|
@@ -19,18 +19,18 @@ from fusion_bench.utils.state_dict_arithmetic import (
|
|
|
19
19
|
state_dict_mul,
|
|
20
20
|
state_dict_sub,
|
|
21
21
|
)
|
|
22
|
-
from fusion_bench.utils.type import StateDictType
|
|
22
|
+
from fusion_bench.utils.type import StateDictType, TorchModelType
|
|
23
23
|
|
|
24
24
|
log = logging.getLogger(__name__)
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
@torch.no_grad()
|
|
28
28
|
def task_arithmetic_merge(
|
|
29
|
-
pretrained_model:
|
|
30
|
-
finetuned_models: List[
|
|
29
|
+
pretrained_model: TorchModelType,
|
|
30
|
+
finetuned_models: List[TorchModelType],
|
|
31
31
|
scaling_factor: float,
|
|
32
32
|
inplace: bool = True,
|
|
33
|
-
) ->
|
|
33
|
+
) -> TorchModelType:
|
|
34
34
|
"""
|
|
35
35
|
Merges the task vectors from multiple fine-tuned models into a single pre-trained model.
|
|
36
36
|
|
|
@@ -46,15 +46,17 @@ def task_arithmetic_merge(
|
|
|
46
46
|
"""
|
|
47
47
|
if not inplace:
|
|
48
48
|
pretrained_model = deepcopy(pretrained_model)
|
|
49
|
-
task_vector: StateDictType = None
|
|
49
|
+
task_vector: Optional[StateDictType] = None
|
|
50
50
|
# Calculate the total task vector
|
|
51
51
|
for model in finetuned_models:
|
|
52
52
|
if task_vector is None:
|
|
53
|
+
# calculate the task vector for the first model
|
|
53
54
|
task_vector = state_dict_sub(
|
|
54
55
|
model.state_dict(keep_vars=True),
|
|
55
56
|
pretrained_model.state_dict(keep_vars=True),
|
|
56
57
|
)
|
|
57
58
|
else:
|
|
59
|
+
# calculate the task vector for the remaining models
|
|
58
60
|
task_vector = state_dict_add(
|
|
59
61
|
task_vector,
|
|
60
62
|
state_dict_sub(
|
|
@@ -16,6 +16,7 @@ from torch import Tensor, nn
|
|
|
16
16
|
|
|
17
17
|
from fusion_bench.compat.modelpool import to_modelpool
|
|
18
18
|
from fusion_bench.method import BaseAlgorithm
|
|
19
|
+
from fusion_bench.mixins import SimpleProfilerMixin
|
|
19
20
|
from fusion_bench.modelpool import BaseModelPool
|
|
20
21
|
from fusion_bench.utils.type import StateDictType
|
|
21
22
|
|
|
@@ -24,7 +25,7 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_st
|
|
|
24
25
|
log = logging.getLogger(__name__)
|
|
25
26
|
|
|
26
27
|
|
|
27
|
-
class TiesMergingAlgorithm(BaseAlgorithm):
|
|
28
|
+
class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
|
|
28
29
|
"""
|
|
29
30
|
TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
|
|
30
31
|
|
|
@@ -84,34 +85,38 @@ class TiesMergingAlgorithm(BaseAlgorithm):
|
|
|
84
85
|
scaling_factor = self.scaling_factor
|
|
85
86
|
threshold = self.threshold
|
|
86
87
|
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
merged_check
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
88
|
+
with self.profile("loading models"):
|
|
89
|
+
# Load the pretrained model
|
|
90
|
+
pretrained_model = modelpool.load_model("_pretrained_")
|
|
91
|
+
|
|
92
|
+
# Load the state dicts of the models
|
|
93
|
+
ft_checks: List[StateDictType] = [
|
|
94
|
+
modelpool.load_model(model_name).state_dict(keep_vars=True)
|
|
95
|
+
for model_name in modelpool.model_names
|
|
96
|
+
]
|
|
97
|
+
ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
|
|
98
|
+
|
|
99
|
+
with self.profile("merging models"):
|
|
100
|
+
# Compute the task vectors
|
|
101
|
+
flat_ft: Tensor = torch.vstack(
|
|
102
|
+
[state_dict_to_vector(check, remove_keys) for check in ft_checks]
|
|
103
|
+
)
|
|
104
|
+
flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
|
|
105
|
+
tv_flat_checks = flat_ft - flat_ptm
|
|
106
|
+
|
|
107
|
+
# Perform TIES Merging
|
|
108
|
+
merged_tv = ties_merging(
|
|
109
|
+
tv_flat_checks,
|
|
110
|
+
reset_thresh=threshold,
|
|
111
|
+
merge_func=merge_func,
|
|
112
|
+
)
|
|
113
|
+
merged_check = flat_ptm + scaling_factor * merged_tv
|
|
114
|
+
merged_state_dict = vector_to_state_dict(
|
|
115
|
+
merged_check, ptm_check, remove_keys=remove_keys
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
# Load the merged state dict into the pretrained model
|
|
119
|
+
pretrained_model.load_state_dict(merged_state_dict)
|
|
120
|
+
|
|
121
|
+
self.print_profile_summary()
|
|
117
122
|
return pretrained_model
|
|
@@ -5,7 +5,6 @@ from typing import cast # noqa: F401
|
|
|
5
5
|
import lightning as L
|
|
6
6
|
import lightning.fabric.wrappers
|
|
7
7
|
import torch
|
|
8
|
-
from lightning.pytorch.profilers import SimpleProfiler
|
|
9
8
|
from omegaconf import DictConfig
|
|
10
9
|
from torch import Tensor
|
|
11
10
|
from torch.utils.data import DataLoader
|
|
@@ -13,6 +12,7 @@ from tqdm.autonotebook import tqdm
|
|
|
13
12
|
|
|
14
13
|
from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
|
|
15
14
|
from fusion_bench.compat.modelpool import ModelPool
|
|
15
|
+
from fusion_bench.mixins import SimpleProfilerMixin
|
|
16
16
|
from fusion_bench.models.we_moe import WeightEnsemblingMoE
|
|
17
17
|
from fusion_bench.utils import timeit_context
|
|
18
18
|
from fusion_bench.utils.parameters import print_parameters
|
|
@@ -34,7 +34,10 @@ def entropy_loss(logits: Tensor) -> Tensor:
|
|
|
34
34
|
return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
class WeightEnsemblingMoEAlgorithm(
|
|
37
|
+
class WeightEnsemblingMoEAlgorithm(
|
|
38
|
+
ModelFusionAlgorithm,
|
|
39
|
+
SimpleProfilerMixin,
|
|
40
|
+
):
|
|
38
41
|
"""
|
|
39
42
|
Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).
|
|
40
43
|
|
|
@@ -44,7 +47,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
44
47
|
Attributes:
|
|
45
48
|
_fabric (L.Fabric): The fabric for distributed training.
|
|
46
49
|
modelpool (ModelPool): The pool of models to be fused.
|
|
47
|
-
profiler (SimpleProfiler): The profiler for measuring performance.
|
|
48
50
|
"""
|
|
49
51
|
|
|
50
52
|
_fabric: L.Fabric = None
|
|
@@ -66,9 +68,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
66
68
|
self._fabric.launch()
|
|
67
69
|
else:
|
|
68
70
|
assert "No CUDA device available."
|
|
69
|
-
self.profiler = SimpleProfiler(
|
|
70
|
-
self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
|
|
71
|
-
)
|
|
72
71
|
|
|
73
72
|
@abstractmethod
|
|
74
73
|
def load_checkpoint(self, model, checkpoint):
|
|
@@ -177,9 +176,9 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
177
176
|
for step_idx in pbar:
|
|
178
177
|
if self.config.use_grad_accumulate:
|
|
179
178
|
for task in self.modelpool.model_names:
|
|
180
|
-
with self.
|
|
179
|
+
with self.profile("data time"):
|
|
181
180
|
batch = next(self.get_shuffled_test_loader_iter(task))
|
|
182
|
-
with self.
|
|
181
|
+
with self.profile("forward pass"):
|
|
183
182
|
logits = self.compute_logits(module, batch, task)
|
|
184
183
|
assert (
|
|
185
184
|
logits.dim() == 2
|
|
@@ -187,23 +186,23 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
187
186
|
loss = entropy_loss(logits)
|
|
188
187
|
# .backward() accumulates when .zero_grad() wasn't called
|
|
189
188
|
# this can save memory
|
|
190
|
-
with self.
|
|
189
|
+
with self.profile("backward pass"):
|
|
191
190
|
self._fabric.backward(loss, retain_graph=True)
|
|
192
191
|
else:
|
|
193
192
|
loss = 0
|
|
194
193
|
for task in self.modelpool.model_names:
|
|
195
|
-
with self.
|
|
194
|
+
with self.profile("data time"):
|
|
196
195
|
batch = next(self.get_shuffled_test_loader_iter(task))
|
|
197
|
-
with self.
|
|
196
|
+
with self.profile("forward pass"):
|
|
198
197
|
logits = self.compute_logits(module, batch, task)
|
|
199
198
|
assert (
|
|
200
199
|
logits.dim() == 2
|
|
201
200
|
), f"Expected logits to be 2D, got {logits.dim()}"
|
|
202
201
|
loss = loss + entropy_loss(logits)
|
|
203
|
-
with self.
|
|
202
|
+
with self.profile("backward pass"):
|
|
204
203
|
self._fabric.backward(loss, retain_graph=True)
|
|
205
204
|
|
|
206
|
-
with self.
|
|
205
|
+
with self.profile("optimizer step"):
|
|
207
206
|
optimizer.step()
|
|
208
207
|
optimizer.zero_grad()
|
|
209
208
|
|
|
@@ -232,7 +231,7 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
232
231
|
)
|
|
233
232
|
self.load_checkpoint(moe_model, self.config.checkpoint)
|
|
234
233
|
else:
|
|
235
|
-
with self.
|
|
234
|
+
with self.profile("test-time adaptation"):
|
|
236
235
|
moe_model = self.test_time_adaptation(moe_model)
|
|
237
236
|
if self.config.get("save_checkpoint", False):
|
|
238
237
|
log.info(f"save checkpoint to {self.config.save_checkpoint}")
|
|
@@ -243,5 +242,5 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
|
|
|
243
242
|
|
|
244
243
|
# enable sample-wise adaptation
|
|
245
244
|
moe_model.batch_reduce = False
|
|
246
|
-
|
|
245
|
+
self.print_profile_summary()
|
|
247
246
|
return moe_model
|
fusion_bench/mixins/__init__.py
CHANGED
|
@@ -6,20 +6,23 @@ from typing_extensions import TYPE_CHECKING
|
|
|
6
6
|
from fusion_bench.utils.lazy_imports import LazyImporter
|
|
7
7
|
|
|
8
8
|
_import_structure = {
|
|
9
|
+
"clip_classification": ["CLIPClassificationMixin"],
|
|
10
|
+
"fabric_training": ["FabricTrainingMixin"],
|
|
11
|
+
"hydra_config": ["HydraConfigMixin"],
|
|
9
12
|
"lightning_fabric": ["LightningFabricMixin"],
|
|
13
|
+
"openclip_classification": ["OpenCLIPClassificationMixin"],
|
|
10
14
|
"serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
|
|
11
15
|
"simple_profiler": ["SimpleProfilerMixin"],
|
|
12
|
-
"clip_classification": ["CLIPClassificationMixin"],
|
|
13
|
-
"fabric_training": ["FabricTrainingMixin"],
|
|
14
16
|
}
|
|
15
17
|
|
|
16
18
|
if TYPE_CHECKING:
|
|
17
19
|
from .clip_classification import CLIPClassificationMixin
|
|
18
20
|
from .fabric_training import FabricTrainingMixin
|
|
21
|
+
from .hydra_config import HydraConfigMixin
|
|
19
22
|
from .lightning_fabric import LightningFabricMixin
|
|
23
|
+
from .openclip_classification import OpenCLIPClassificationMixin
|
|
20
24
|
from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
|
|
21
25
|
from .simple_profiler import SimpleProfilerMixin
|
|
22
|
-
|
|
23
26
|
else:
|
|
24
27
|
sys.modules[__name__] = LazyImporter(
|
|
25
28
|
__name__,
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
import hydra.core.global_hydra
|
|
8
|
+
from hydra import compose, initialize
|
|
9
|
+
from omegaconf import DictConfig, OmegaConf
|
|
10
|
+
|
|
11
|
+
from fusion_bench.utils import import_object, instantiate
|
|
12
|
+
from fusion_bench.utils.instantiate import set_print_function_call
|
|
13
|
+
|
|
14
|
+
log = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class HydraConfigMixin:
|
|
18
|
+
"""
|
|
19
|
+
A mixin for classes that need to be instantiated from a config file.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
@classmethod
|
|
23
|
+
def from_config(
|
|
24
|
+
cls,
|
|
25
|
+
config_name: Union[str, Path],
|
|
26
|
+
overrides: Optional[List[str]] = None,
|
|
27
|
+
):
|
|
28
|
+
if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
|
|
29
|
+
raise RuntimeError("Hydra is not initialized.")
|
|
30
|
+
else:
|
|
31
|
+
cfg = compose(config_name=config_name, overrides=overrides)
|
|
32
|
+
|
|
33
|
+
config_groups = config_name.split("/")[:-1]
|
|
34
|
+
for config_group in config_groups:
|
|
35
|
+
cfg = cfg[config_group]
|
|
36
|
+
|
|
37
|
+
if "_target_" in cfg:
|
|
38
|
+
# if the config has a _target_ key, check if it is equal to the class name
|
|
39
|
+
target_cls = import_object(cfg["_target_"])
|
|
40
|
+
if target_cls != cls:
|
|
41
|
+
log.warning(
|
|
42
|
+
f"The _target_ key in the config is {cfg['_target_']}, but the class name is {cls.__name__}."
|
|
43
|
+
)
|
|
44
|
+
with set_print_function_call(False):
|
|
45
|
+
obj = instantiate(cfg)
|
|
46
|
+
else:
|
|
47
|
+
obj = cls(**cfg)
|
|
48
|
+
|
|
49
|
+
return obj
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from fusion_bench.mixins import LightningFabricMixin
|
|
4
|
+
from fusion_bench.models.open_clip import ImageClassifier, ImageEncoder
|
|
5
|
+
|
|
6
|
+
log = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OpenCLIPClassificationMixin(LightningFabricMixin):
|
|
10
|
+
_train_processor = None
|
|
11
|
+
_test_processor = None
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from contextlib import contextmanager
|
|
2
|
-
from typing import Generator
|
|
2
|
+
from typing import Generator, Optional
|
|
3
3
|
|
|
4
4
|
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
5
5
|
from lightning.pytorch.profilers import SimpleProfiler
|
|
@@ -70,7 +70,9 @@ class SimpleProfilerMixin:
|
|
|
70
70
|
self.profiler.stop(action_name)
|
|
71
71
|
|
|
72
72
|
@rank_zero_only
|
|
73
|
-
def print_profile_summary(self):
|
|
73
|
+
def print_profile_summary(self, title: Optional[str] = None):
|
|
74
|
+
if title is not None:
|
|
75
|
+
print(title)
|
|
74
76
|
print(self.profiler.summary())
|
|
75
77
|
|
|
76
78
|
def __del__(self):
|
|
@@ -6,12 +6,13 @@ from fusion_bench.utils.lazy_imports import LazyImporter
|
|
|
6
6
|
|
|
7
7
|
_import_structure = {
|
|
8
8
|
"base_pool": ["BaseModelPool"],
|
|
9
|
+
"causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
|
|
9
10
|
"clip_vision": ["CLIPVisionModelPool"],
|
|
10
11
|
"nyuv2_modelpool": ["NYUv2ModelPool"],
|
|
11
12
|
"huggingface_automodel": ["AutoModelPool"],
|
|
12
|
-
"causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
|
|
13
13
|
"seq2seq_lm": ["Seq2SeqLMPool"],
|
|
14
14
|
"PeftModelForSeq2SeqLM": ["PeftModelForSeq2SeqLMPool"],
|
|
15
|
+
"openclip_vision": ["OpenCLIPVisionModelPool"],
|
|
15
16
|
"huggingface_gpt2_classification": [
|
|
16
17
|
"HuggingFaceGPT2ClassificationPool",
|
|
17
18
|
"GPT2ForSequenceClassificationPool",
|
|
@@ -30,6 +31,7 @@ if TYPE_CHECKING:
|
|
|
30
31
|
HuggingFaceGPT2ClassificationPool,
|
|
31
32
|
)
|
|
32
33
|
from .nyuv2_modelpool import NYUv2ModelPool
|
|
34
|
+
from .openclip_vision import OpenCLIPVisionModelPool
|
|
33
35
|
from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
|
|
34
36
|
from .seq2seq_lm import Seq2SeqLMPool
|
|
35
37
|
from .seq_classification_lm import SeqenceClassificationModelPool
|
|
@@ -7,7 +7,7 @@ from omegaconf import DictConfig
|
|
|
7
7
|
from torch import nn
|
|
8
8
|
from torch.utils.data import Dataset
|
|
9
9
|
|
|
10
|
-
from fusion_bench.mixins import BaseYAMLSerializableModel
|
|
10
|
+
from fusion_bench.mixins import BaseYAMLSerializableModel, HydraConfigMixin
|
|
11
11
|
from fusion_bench.utils import instantiate, timeit_context
|
|
12
12
|
|
|
13
13
|
__all__ = ["BaseModelPool"]
|
|
@@ -15,7 +15,7 @@ __all__ = ["BaseModelPool"]
|
|
|
15
15
|
log = logging.getLogger(__name__)
|
|
16
16
|
|
|
17
17
|
|
|
18
|
-
class BaseModelPool(BaseYAMLSerializableModel):
|
|
18
|
+
class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
|
|
19
19
|
"""
|
|
20
20
|
A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.
|
|
21
21
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .modelpool import OpenCLIPVisionModelPool
|