fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +3 -1
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +102 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- fusion_bench/method/DOGE_TA/__init__.py +0 -2
- /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench/method/gossip/task_wise_gossip.py ADDED
@@ -0,0 +1,265 @@
+import copy
+import gc
+import logging
+from abc import abstractmethod
+from typing import List, Mapping, Union  # noqa: F401
+
+import lightning as L
+import numpy as np
+import torch
+import torch.nn as nn
+from omegaconf import DictConfig
+from torch import Tensor
+from torch.utils.data import DataLoader
+from tqdm.autonotebook import tqdm
+
+from fusion_bench.compat.method import ModelFusionAlgorithm
+from fusion_bench.compat.modelpool import ModelPool
+from fusion_bench.models.wrappers.task_wise_fusion import (
+    TaskWiseMergedModel,
+    get_task_wise_weights,
+)
+
+log = logging.getLogger(__name__)
+
+
+# obtain the current GPU memory usage
+def print_memory_usage(desc):
+    print(desc)
+    allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+    cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+    print(f"Allocated Memory: {allocated:.2f} MB")
+    print(f"Cached Memory: {cached:.2f} MB")
+
+
+def entropy_loss(logits: Tensor) -> Tensor:
+    """
+    Compute the entropy loss of a set of logits.
+
+    Args:
+        logits (Tensor): The logits to compute the entropy loss of.
+
+    Returns:
+        Tensor: The entropy loss of the logits.
+    """
+    probs = torch.softmax(logits, dim=-1)
+    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
+
+
+class ModelScheduler:
+    """
+    Manage the storage of models, schedule the order in which models are loaded
+    to the GPU, and transfer data between the CPU and GPU.
+    """
+
+    def __init__(
+        self,
+        modelpool: ModelPool,
+        config: DictConfig,
+    ):
+        self.pretrained_model = modelpool.load_model("_pretrained_")
+        self.finetuned_models = [
+            modelpool.load_model(name) for name in modelpool.model_names
+        ]
+        self.num_finetuned_models = len(self.finetuned_models)
+        self.new_finetuned_models = copy.deepcopy(self.finetuned_models)
+        self.finetuned_model_names = [name for name in modelpool.model_names]
+
+        self.config = config
+
+    @torch.no_grad()  # not sure whether to use this
+    def __call__(self, model_id):
+        """
+        Return the models and relevant data for each step.
+        """
+        # TODO: use a mixing matrix to determine which models to use in step idx
+
+        pretrained_model = copy.deepcopy(self.finetuned_models[model_id])
+        finetuned_models = [
+            copy.deepcopy(
+                self.finetuned_models[(model_id + 1) % self.num_finetuned_models]
+            ),
+            copy.deepcopy(
+                self.finetuned_models[(model_id - 1) % self.num_finetuned_models]
+            ),
+        ]
+
+        if self.config.weights is None:
+            task_wise_weight = get_task_wise_weights(
+                num_models=len(finetuned_models),
+                init_values=self.config.init_values,
+            )
+        else:
+            pass
+
+        module = TaskWiseMergedModel(
+            task_wise_weight=task_wise_weight,
+            pretrained_model=pretrained_model,
+            finetuned_models=finetuned_models,
+            clamp_weights=self.config.clamp_weights,
+            tie_weights=self.config.tie_weights,
+            strict=self.config.strict,
+        )
+        return module
+
+    def store_model(self, new_finetuned_model_dict, model_id):
+        """
+        Store the new fine-tuned model after every turn of adamerging.
+        """
+        self.new_finetuned_models[model_id].load_state_dict(new_finetuned_model_dict)
+
+    def update_models(self):
+        self.finetuned_models = copy.deepcopy(self.new_finetuned_models)
+
+    def get_final_models(self):
+        # need a check
+        final_models = [
+            {"name": name, "model": model}
+            for name, model in zip(self.finetuned_model_names, self.finetuned_models)
+        ]
+        num_finetuned_models = len(self.finetuned_models)
+
+        state_dict = self.pretrained_model.state_dict(keep_vars=True)
+        for name in state_dict.keys():
+            state_dict[name].data.zero_()
+        for model in self.finetuned_models:
+            for name, param in model.named_parameters():
+                state_dict[name] = state_dict[name] + 1 / num_finetuned_models * param
+
+        self.pretrained_model.load_state_dict(state_dict)
+        final_models += [{"name": "average model", "model": self.pretrained_model}]
+
+        return final_models
+
+
+class TaskWiseGossipAlgorithm(ModelFusionAlgorithm):
+    _fabric: L.Fabric = None
+
+    def __init__(self, algorithm_config: DictConfig):
+        super().__init__(algorithm_config)
+
+        if self._fabric is None and torch.cuda.is_available():
+            self._fabric = L.Fabric(devices=self.config.get("devices", 1))
+            self._fabric.launch()
+
+        self.optimizer = None  # we want to reuse it in Gossip using a single GPU
+
+    def free_gpu_memory(self, module: TaskWiseMergedModel):
+        module.pretrained_model.to("cpu")
+        for model in module.task_vectors:
+            model.to("cpu")
+        del module
+        gc.collect()
+        torch.cuda.empty_cache()
+        print_memory_usage(
+            "finish local adamerging, after freeing memory, the memory usage of GPU is:"
+        )
+
+    def run(self, modelpool: ModelPool):
+        log.info("Fusing models using task-wise adaptive merging with gossip.")
+        self.modelpool = modelpool
+        self.num_finetuned_models = len(modelpool.model_names)
+
+        model_scheduler = ModelScheduler(self.modelpool, self.config)
+
+        pbar = tqdm(
+            range(self.config.gossip_max_steps), "Gossip merging", dynamic_ncols=True
+        )
+        for step_idx in pbar:
+            log.info(f"step: {step_idx}")
+            for model_id in tqdm(
+                range(self.num_finetuned_models), "local adamerging", dynamic_ncols=True
+            ):
+                # log.info(f"adamerging model: {model_scheduler.finetuned_model_names[model_id]}")
+                module = model_scheduler(model_id)
+                module = self.test_time_adaptation(module)
+                # if self.config.get("save_merging_weights", False):
+                #     torch.save(module.merge_weight, self.config.save_merging_weights)
+                print_memory_usage(
+                    "local adamerging almost done, the memory usage of GPU is:"
+                )
+                model_scheduler.store_model(module.merge_weights(), model_id)
+                print_memory_usage(
+                    "local adamerging almost done, the memory usage of GPU is:"
+                )
+                self.free_gpu_memory(
+                    module
+                )  # simulate distributed GPU memory usage as much as possible
+
+            model_scheduler.update_models()
+
+        return model_scheduler.get_final_models()
+
+    def on_test_time_adaptation_start(self):
+        pass
+
+    @abstractmethod
+    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
+        pass
+
+    @abstractmethod
+    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
+        """
+        Compute the logits for the given batch and task.
+
+        Args:
+            module (nn.Module): The model module.
+            batch (tuple): A batch of input data.
+            task (str): The name of the task.
+
+        Returns:
+            Tensor: The classification logits for the batch.
+        """
+        pass
+
+    def test_time_adaptation(self, module: TaskWiseMergedModel):
+        self.on_test_time_adaptation_start()
+
+        # configure optimizer
+        if self.config.optimizer == "adam":
+            self.optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+        if self._fabric is not None:
+            module, self.optimizer = self._fabric.setup(module, self.optimizer)
+        print_memory_usage(
+            "load model and optimizer to GPU, the memory usage of GPU is:"
+        )
+        module.train()
+        module.merge_weights()
+
+        if self.config.get("fast_dev_run", False):
+            log.info("Running fast_dev_run, only one step")
+            pbar = tqdm(
+                range(1),
+                "AdaMerging Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        else:
+            pbar = tqdm(
+                range(self.config.max_steps),
+                "AdaMerging Test-time adaptation",
+                dynamic_ncols=True,
+            )
+        for step_idx in pbar:
+            for task in self.modelpool.model_names:
+                batch = next(self.get_shuffled_test_loader_iter(task))
+                logits = self.compute_logits(module, batch, task)
+                assert (
+                    logits.dim() == 2
+                ), f"Expected logits to be 2D, got {logits.dim()}"
+                loss = entropy_loss(logits)
+                # .backward() accumulates when .zero_grad() wasn't called;
+                # this can save memory
+                self._fabric.backward(loss, retain_graph=True)
+
+            # print_memory_usage('model + dataset: ')
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+            module.merge_weights()
+
+        del self.optimizer
+        gc.collect()
+        torch.cuda.empty_cache()
+        return module
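The base class leaves `get_shuffled_test_loader_iter` and `compute_logits` abstract; the CLIP-specific versions presumably live in `clip_task_wise_gossip.py`, also added in this release. As a rough sketch of what a concrete subclass has to provide (the `test_datasets` attribute, batch size, and batch layout here are illustrative assumptions, not code from this diff):

```python
import itertools

from torch import Tensor, nn
from torch.utils.data import DataLoader


class MyTaskWiseGossip(TaskWiseGossipAlgorithm):
    def get_shuffled_test_loader_iter(self, task: str):
        # Hypothetical: assumes self.test_datasets was populated elsewhere.
        # Cycling forever keeps `next(...)` in test_time_adaptation from
        # raising StopIteration after one pass over the data.
        loader = DataLoader(self.test_datasets[task], batch_size=16, shuffle=True)
        return itertools.cycle(loader)

    def compute_logits(self, module: nn.Module, batch, task: str) -> Tensor:
        images, _ = batch  # labels unused: the entropy loss is unsupervised
        return module(images)
```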
fusion_bench/method/gossip/utils.py ADDED
@@ -0,0 +1,74 @@
+import copy
+from collections import OrderedDict
+
+import torch
+from torch import nn
+
+
+def get_memory_usage(desc):
+    """
+    Obtain the current GPU memory usage.
+
+    Returns:
+        str: A string containing the allocated and cached memory in MB.
+    """
+    allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+    cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+    return (
+        f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+    )
+
+
+# Model conversion utils
+
+
+def state_dict_to_vector(state_dict, remove_keys=[]):
+    """
+    Convert a state dictionary to a vector.
+
+    Args:
+        state_dict (dict): The state dictionary to convert.
+        remove_keys (list, optional): List of keys to remove from the state dictionary. Defaults to [].
+
+    Returns:
+        torch.Tensor: The converted vector.
+    """
+    shared_state_dict = copy.deepcopy(state_dict)
+    for key in remove_keys:
+        if key in shared_state_dict:
+            del shared_state_dict[key]
+    sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
+    return nn.utils.parameters_to_vector(
+        [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
+    )
+
+
+def vector_to_state_dict(vector, state_dict, remove_keys=[]):
+    """
+    Convert a vector to a state dictionary.
+
+    Args:
+        vector (torch.Tensor): The vector to convert.
+        state_dict (dict): The reference state dictionary to define the order of the vector.
+        remove_keys (list, optional): List of keys to remove from the reference state dictionary. Defaults to [].
+
+    Returns:
+        dict: The converted state dictionary.
+    """
+    # create a reference dict to define the order of the vector
+    reference_dict = copy.deepcopy(state_dict)
+    for key in remove_keys:
+        if key in reference_dict:
+            del reference_dict[key]
+    sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
+
+    # create a shared state dict using the reference dict
+    nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
+
+    # add back the encoder and decoder embedding weights.
+    if "transformer.shared.weight" in sorted_reference_dict:
+        for key in remove_keys:
+            sorted_reference_dict[key] = sorted_reference_dict[
+                "transformer.shared.weight"
+            ]
+    return sorted_reference_dict
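The two helpers are inverses over the non-removed keys, which is what makes merging in flattened parameter space convenient. A quick round-trip sketch (an illustration, not part of the diff):

```python
import torch
from torch import nn

model_a, model_b = nn.Linear(4, 2), nn.Linear(4, 2)

vec_a = state_dict_to_vector(model_a.state_dict())
vec_b = state_dict_to_vector(model_b.state_dict())

# merge in flattened space, then map the result back onto a state dict
merged_vec = 0.5 * (vec_a + vec_b)
merged_sd = vector_to_state_dict(merged_vec, model_a.state_dict())
model_a.load_state_dict(merged_sd)
```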
fusion_bench/method/isotropic_merging/__init__.py CHANGED
@@ -3,7 +3,7 @@ This module contains the implementation of the Isotropic Merging in Common Subsp
 Modified from the original implementation: https://github.com/danielm1405/iso-merging
 
 Reference:
-    - Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
+    - Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. 2025.
         https://arxiv.org/abs/2502.04959
 """
 
fusion_bench/method/opcm/opcm.py CHANGED
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPVisionModel
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.json import load_from_json, save_to_json
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
 class OPCMForCLIP(
     BaseAlgorithm,
     LightningFabricMixin,
+    SimpleProfilerMixin,
 ):
     def __init__(
         self,
@@ -64,7 +65,8 @@ class OPCMForCLIP(
         L.seed_everything(self.seed)
         accelerator = self.fabric.device
 
-        pretrained_model = modelpool.load_pretrained_model()
+        with self.profile("loading model"):
+            pretrained_model = modelpool.load_pretrained_model()
 
         model_names = modelpool.model_names
         if self.shuffle_order:
@@ -83,15 +85,17 @@ class OPCMForCLIP(
         )
 
         # get the average model
-        merged_model = modelpool.load_model(model_names[0])
+        with self.profile("loading model"):
+            merged_model = modelpool.load_model(model_names[0])
 
         if self.evaluate_on_every_step:
-            self.taskpool._is_setup = False
-            self.taskpool._test_datasets = DictConfig(
-                {model_names[0]: self._test_datasets[model_names[0]]}
-            )
-            report = self.taskpool.evaluate(deepcopy(merged_model))
-            save_to_json(report, Path(self.log_dir) / "report_0.json")
+            with self.profile("evaluating model"):
+                self.taskpool._is_setup = False
+                self.taskpool._test_datasets = DictConfig(
+                    {model_names[0]: self._test_datasets[model_names[0]]}
+                )
+                report = self.taskpool.evaluate(deepcopy(merged_model))
+                save_to_json(report, Path(self.log_dir) / "report_0.json")
 
         self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
         self.all_task_vector_norm = [self.avg_task_vector_norm]
@@ -113,90 +117,104 @@ class OPCMForCLIP(
             enumerate(model_names[1:]), desc="Processing models"
         ):
             model_idx += 1
-            task_model = modelpool.load_model(model_name)
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)
 
+            with self.profile("merging model"):
+                self.all_task_vector_norm.append(
+                    get_task_vector_norm(task_model, pretrained_model)
+                )
+                self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
+                self.fabric.log(
+                    "model/task_vector_norm",
+                    self.all_task_vector_norm[-1],
+                    step=model_idx,
+                )
+                self.fabric.log(
+                    "model/avg_task_vector_norm",
+                    self.avg_task_vector_norm,
+                    step=model_idx,
+                )
 
-                        accelerator=accelerator,
-                    )
-                    if module.bias is not None:
-                        module.bias.data = self.merge_other_parameters(
-                            module.bias,
-                            pretrained_model.get_submodule(module_name).bias,
-                            task_model.get_submodule(module_name).bias,
-                            param_name=".".join([module_name, "bias"]),
+                self.lambda_t = 1  # temporary value
+
+                for module_name, module in tqdm(
+                    list(merged_model.named_modules()),
+                    desc=f"Processing {model_name}",
+                    leave=False,
+                ):
+                    if not is_leaf_module(module):
+                        continue
+
+                    if isinstance(module, nn.Linear):
+                        module.weight.data = self.merge_linear_weights(
+                            module.weight,
+                            pretrained_model.get_submodule(module_name).weight,
+                            task_model.get_submodule(module_name).weight,
+                            param_name=".".join([module_name, "weight"]),
+                            alpha=self.alpha,
                             accelerator=accelerator,
                         )
+                        if module.bias is not None:
+                            module.bias.data = self.merge_other_parameters(
+                                module.bias,
+                                pretrained_model.get_submodule(module_name).bias,
+                                task_model.get_submodule(module_name).bias,
+                                param_name=".".join([module_name, "bias"]),
+                                accelerator=accelerator,
+                            )
+                    else:
+                        for param_name, param in module.named_parameters():
+                            param.data = self.merge_other_parameters(
+                                merged_W=param,
+                                pretrained_W=pretrained_model.get_submodule(
+                                    module_name
+                                ).get_parameter(param_name),
+                                task_W=task_model.get_submodule(
+                                    module_name
+                                ).get_parameter(param_name),
+                                param_name=".".join([module_name, param_name]),
+                                accelerator=accelerator,
+                            )
+
+                task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
+                self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
+                for param_name, param in merged_model.named_parameters():
+                    param.data = pretrained_model.get_parameter(param_name) + (
+                        param - pretrained_model.get_parameter(param_name)
+                    ) * (self.avg_task_vector_norm / task_vector_norm)
+                self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
+                self.fabric.log(
+                    "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
+                )
+                self.previous_lambda_t = self.lambda_t
+                self.lambda_t = None
 
+            self.fabric.log(
+                "model/merged_task_vector_norm",
+                get_task_vector_norm(merged_model, pretrained_model),
+                step=model_idx,
+            )
 
             if self.save_on_every_step:
-                self.save_merged_model(merged_model, model_idx)
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)
 
             if self.evaluate_on_every_step:
-                self.taskpool._is_setup = False
-                self.taskpool._test_datasets = DictConfig(
-                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
-                )
-                report = self.taskpool.evaluate(deepcopy(merged_model))
-                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+                with self.profile("evaluating model"):
+                    self.taskpool._is_setup = False
+                    self.taskpool._test_datasets = DictConfig(
+                        {
+                            n: self._test_datasets[n]
+                            for n in model_names[: model_idx + 1]
+                        }
+                    )
+                    report = self.taskpool.evaluate(deepcopy(merged_model))
+                    save_to_json(
+                        report, Path(self.log_dir) / f"report_{model_idx}.json"
+                    )
 
+        self.print_profile_summary()
         return merged_model
 
     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
@@ -227,7 +245,7 @@ class OPCMForCLIP(
         split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()
 
         projected_task_tv = u.T @ task_tv @ v
-        projected_task_tv.
+        projected_task_tv.diagonal().fill_(0)
 
         projected_task_tv[:split_rank, :split_rank] = 0
 
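Aside from the profiling hooks, the functional core of the new loop is the norm rescaling: after each merge step, the merged task vector `merged - pretrained` is scaled back to the running average task-vector norm. A toy numeric check of that step (a standalone illustration, not code from the diff):

```python
import torch

pretrained = torch.zeros(4)
merged = torch.tensor([3.0, 0.0, 4.0, 0.0])  # task vector norm = 5.0
avg_task_vector_norm = 2.5

# same form as param.data = pretrained + (param - pretrained) * (avg / norm)
task_vector = merged - pretrained
rescaled = pretrained + task_vector * (avg_task_vector_norm / task_vector.norm())
assert torch.isclose(rescaled.norm(), torch.tensor(avg_task_vector_norm))
```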
fusion_bench/method/opcm/task_arithmetic.py CHANGED
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
 from transformers import CLIPVisionModel
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils.json import load_from_json, save_to_json
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
@@ -24,7 +24,11 @@ if TYPE_CHECKING:
     from torch.utils.tensorboard import SummaryWriter
 
 
-class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
+class ContinualTaskArithmeticForCLIP(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
     def __init__(
         self,
         scaling_factor: float,
@@ -79,32 +83,42 @@ class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
         for model_idx, model_name in tqdm(
             enumerate(model_names), desc="Processing models"
         ):
-            task_model = modelpool.load_model(model_name)
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)
 
-            for param_name, param in task_model.named_parameters():
-                if not param.requires_grad:
-                    continue
+            with self.profile("merging model"):
+                for param_name, param in task_model.named_parameters():
+                    if not param.requires_grad:
+                        continue
 
-                task_param = param
-                merged_param = merged_model.get_parameter(param_name)
-                pretrained_param = pretrained_model.get_parameter(param_name)
+                    task_param = param
+                    merged_param = merged_model.get_parameter(param_name)
+                    pretrained_param = pretrained_model.get_parameter(param_name)
 
-                new_param = merged_param + self.scaling_factor * (
-                    task_param - pretrained_param
-                )
-                merged_model.get_parameter(param_name).data = new_param
+                    new_param = merged_param + self.scaling_factor * (
+                        task_param - pretrained_param
+                    )
+                    merged_model.get_parameter(param_name).data = new_param
 
             if self.save_on_every_step:
-                self.save_merged_model(merged_model, model_idx)
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)
 
             if self.evaluate_on_every_step:
-                self.taskpool._is_setup = False
-                self.taskpool._test_datasets = DictConfig(
-                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
-                )
-                report = self.taskpool.evaluate(deepcopy(merged_model))
-                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+                with self.profile("evaluating model"):
+                    self.taskpool._is_setup = False
+                    self.taskpool._test_datasets = DictConfig(
+                        {
+                            n: self._test_datasets[n]
+                            for n in model_names[: model_idx + 1]
+                        }
+                    )
+                    report = self.taskpool.evaluate(deepcopy(merged_model))
+                    save_to_json(
+                        report, Path(self.log_dir) / f"report_{model_idx}.json"
+                    )
+
+        self.print_profile_summary()
         return merged_model
 
     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
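The merging step itself is plain task arithmetic applied incrementally: each incoming model adds its scaled task vector to the running merge. In toy form (illustrative numbers, not from the diff):

```python
import torch

pretrained = torch.tensor([1.0, 1.0])
merged = pretrained.clone()
scaling_factor = 0.3

# two mock "fine-tuned" parameter vectors standing in for task models
for task_param in (torch.tensor([2.0, 1.0]), torch.tensor([1.0, 3.0])):
    merged = merged + scaling_factor * (task_param - pretrained)

print(merged)  # tensor([1.3000, 1.6000]) == pretrained + 0.3 * sum of task vectors
```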