fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,476 @@
+import logging
+from abc import abstractmethod
+from collections import defaultdict
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
+
+import lightning.fabric
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from omegaconf import DictConfig, OmegaConf
+from open_clip.model import ResidualAttentionBlock
+from torch import Tensor, nn
+from tqdm.auto import tqdm
+
+from fusion_bench import BaseAlgorithm
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.method.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins import OpenCLIPClassificationMixin, SimpleProfilerMixin
+from fusion_bench.modelpool import OpenCLIPVisionModelPool
+from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
+from fusion_bench.utils import print_parameters, timeit_context
+from fusion_bench.utils.data import InfiniteDataLoader
+
+from .module import ParetoWeightEnsemblingModule
+from .phn.solvers import EPOSolver
+from .utils import generate_simplex_grid
+
+log = logging.getLogger(__name__)
+
+
+class PWEMoEAlgorithmForOpenCLIP(
+    BaseAlgorithm,
+    SimpleProfilerMixin,
+    OpenCLIPClassificationMixin,
+):
+    modelpool: OpenCLIPVisionModelPool
+
+    #! === Training & Validation Data ===
+    # setup the datasets and loaders by calling `load_datasets`
+    train_datasets: Dict[str, CLIPDataset]
+    train_loaders: Dict[str, torch.utils.data.DataLoader]
+    train_loader_iters: Dict[str, Iterator[Tuple[torch.Tensor, torch.Tensor]]]
+
+    test_datasets: Dict[str, CLIPDataset]
+    test_loaders: Dict[str, torch.utils.data.DataLoader]
+
+    def __init__(
+        self,
+        *,
+        #! === Model Architecture Arguments ===
+        partial: bool,
+        init_lambda: float,
+        router_hidden_layers: int,
+        checkpoint_path: str,
+        #! === Training Arguments ===
+        run_train: bool,
+        num_steps: int,
+        save_interval: int,
+        lr: float,
+        alpha: float,
+        dataloader_kwargs: DictConfig,
+        #! === Evaluation Arguments ===
+        run_eval: bool,
+        num_evaluation_samples: Union[str, int],
+        quick_evaluation: bool,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.partial = partial
+        self.init_lambda = init_lambda
+        self.router_hidden_layers = router_hidden_layers
+        self.lr = lr
+        self.num_steps = num_steps
+        self.save_interval = save_interval
+        self.alpha = alpha
+        self.checkpoint_path = checkpoint_path
+        self._dataloader_kwargs = dataloader_kwargs
+        self.run_train = run_train
+        self.run_eval = run_eval
+        self.num_evaluation_samples = num_evaluation_samples
+        self.quick_evaluation = quick_evaluation
+
+    def run(self, modelpool: OpenCLIPVisionModelPool):
+        self.modelpool = modelpool
+
+        # setup the MoE model
+        model = self.load_model()
+        if self.checkpoint_path is not None:
+            self.fabric.load(self.checkpoint_path, {"model": model})
+
+        # setup dataloaders
+        self.load_datasets()
+
+        if self.run_train:
+            model = self.train()
+        if self.run_eval:
+            self.evaluate(model)
+        return model
+
+    @torch.no_grad()
+    def load_model(self):
+        modelpool = self.modelpool
+
+        # load models and classification heads
+        pretrained_model: ImageEncoder = self.modelpool.load_pretrained_model()
+        log.info("pretrained model statistics:")
+        print_parameters(pretrained_model, print_fn=log.info)
+
+        finetuned_models: Dict[str, ImageEncoder] = {}
+        for model_name in self.modelpool.model_names:
+            finetuned_models[model_name] = modelpool.load_model(model_name)
+
+        classification_heads: Dict[str, ClassificationHead] = {}
+        for model_name in self.modelpool.model_names:
+            classification_heads[model_name] = modelpool.load_classification_head(
+                model_name
+            )
+        self.classification_heads = classification_heads
+
+        self.train_processor = modelpool.train_processor
+        self.test_processor = modelpool.test_processor
+
+        with timeit_context("Building the MoE model"):
+            model = deepcopy(pretrained_model)
+
+            if self.partial:
+                log.info("Weight ensembling only the MLPs")
+                # weight ensembling only the MLPs, merge the remaining layers using task arithmetic
+                model = task_arithmetic_merge(
+                    pretrained_model=model,
+                    finetuned_models=list(finetuned_models.values()),
+                    scaling_factor=self.init_lambda,
+                    inplace=True,
+                )
+
+                # fix all parameters
+                model.requires_grad_(False)
+
+                for layer_idx in tqdm(
+                    range(model.model.visual.transformer.layers), desc="Upscaling MLPs"
+                ):
+                    resblock: ResidualAttentionBlock = (
+                        model.model.visual.transformer.resblocks[layer_idx]
+                    )
+                    resblock.mlp = ParetoWeightEnsemblingModule(
+                        base_model=cast(
+                            ResidualAttentionBlock,
+                            pretrained_model.model.visual.transformer.resblocks[
+                                layer_idx
+                            ],
+                        ).mlp,
+                        expert_models=[
+                            cast(
+                                ResidualAttentionBlock,
+                                m.model.visual.transformer.resblocks[layer_idx],
+                            ).mlp
+                            for m in finetuned_models.values()
+                        ],
+                        init_lambda=self.init_lambda,
+                        fix_base_model_and_experts=True,
+                        router_hidden_layers=self.router_hidden_layers,
+                    )
+            else:
+                log.info("Weight ensembling all the layers")
+                # weight ensembling all the layers, merge the remaining layers using task arithmetic
+                model = task_arithmetic_merge(
+                    pretrained_model=model,
+                    finetuned_models=list(finetuned_models.values()),
+                    scaling_factor=self.init_lambda,
+                    inplace=True,
+                )
+                # fix all parameters
+                model.requires_grad_(False)
+
+                for name in [
+                    "conv1",
+                    "ln_pre",
+                    "ln_post",
+                    # "class_embedding",
+                    # "positional_embedding",
+                ]:
+                    setattr(
+                        model.model.visual,
+                        name,
+                        ParetoWeightEnsemblingModule(
+                            base_model=getattr(pretrained_model.model.visual, name),
+                            expert_models=[
+                                getattr(m.model.visual, name)
+                                for m in finetuned_models.values()
+                            ],
+                            init_lambda=self.init_lambda,
+                            fix_base_model_and_experts=True,
+                            router_hidden_layers=self.router_hidden_layers,
+                        ),
+                    )
+                for layer_idx in tqdm(
+                    range(model.model.visual.transformer.layers),
+                    desc="Upscaling the transformer layers",
+                ):
+                    for name in ["ln_1", "attn", "ln_attn", "ln_2", "mlp"]:
+                        setattr(
+                            model.model.visual.transformer.resblocks[layer_idx],
+                            name,
+                            ParetoWeightEnsemblingModule(
+                                base_model=getattr(
+                                    cast(
+                                        ResidualAttentionBlock,
+                                        pretrained_model.model.visual.transformer.resblocks[
+                                            layer_idx
+                                        ],
+                                    ),
+                                    name,
+                                ),
+                                expert_models=[
+                                    getattr(
+                                        cast(
+                                            ResidualAttentionBlock,
+                                            m.model.visual.transformer.resblocks[
+                                                layer_idx
+                                            ],
+                                        ),
+                                        name,
+                                    )
+                                    for m in finetuned_models.values()
+                                ],
+                                init_lambda=self.init_lambda,
+                                fix_base_model_and_experts=True,
+                                router_hidden_layers=self.router_hidden_layers,
+                            ),
+                        )
+                for name in ["token_embedding", "ln_final"]:
+                    setattr(
+                        model.model,
+                        name,
+                        ParetoWeightEnsemblingModule(
+                            base_model=getattr(pretrained_model.model, name),
+                            expert_models=[
+                                getattr(m.model, name)
+                                for m in finetuned_models.values()
+                            ],
+                            init_lambda=self.init_lambda,
+                            fix_base_model_and_experts=True,
+                            router_hidden_layers=self.router_hidden_layers,
+                        ),
+                    )
+
+        self.model = model
+        print_parameters(model, print_fn=log.info)
+        return model
+
+    def load_datasets(self):
+        modelpool = self.modelpool
+
+        # setup the train datasets and loaders
+        train_datasets = {}
+        train_loaders = {}
+        train_loader_iters = {}
+        for dataset_name in modelpool.train_dataset_names:
+            train_datasets[dataset_name] = modelpool.load_train_dataset(dataset_name)
+            train_datasets[dataset_name] = CLIPDataset(
+                train_datasets[dataset_name], self.train_processor
+            )
+            # sanity check
+            assert isinstance(train_datasets[dataset_name][0], tuple)
+
+            # setup the train loaders
+            train_loaders[dataset_name] = torch.utils.data.DataLoader(
+                train_datasets[dataset_name],
+                shuffle=True,
+                drop_last=True,
+                **self._dataloader_kwargs,
+            )
+            train_loaders[dataset_name] = self.fabric.setup_dataloaders(
+                train_loaders[dataset_name]
+            )
+            train_loaders[dataset_name] = InfiniteDataLoader(
+                train_loaders[dataset_name]
+            )
+
+            # setup the train loader iterators
+            train_loader_iters[dataset_name] = iter(train_loaders[dataset_name])
+
+        self.train_datasets = train_datasets
+        self.train_loaders = train_loaders
+        self.train_loader_iters = train_loader_iters
+
+        # setup the test datasets and loaders
+        test_datasets = {}
+        test_loaders = {}
+        for dataset_name in modelpool.test_dataset_names:
+            test_datasets[dataset_name] = modelpool.load_test_dataset(dataset_name)
+            test_datasets[dataset_name] = CLIPDataset(
+                test_datasets[dataset_name], self.test_processor
+            )
+            test_loaders[dataset_name] = torch.utils.data.DataLoader(
+                test_datasets[dataset_name],
+                shuffle=False,
+                **self._dataloader_kwargs,
+            )
+            test_loaders[dataset_name] = self.fabric.setup_dataloaders(
+                test_loaders[dataset_name]
+            )
+
+        self.test_datasets = test_datasets
+        self.test_loaders = test_loaders
+
+    def compute_loss(self, model: ImageEncoder, ray: Tensor):
+        losses = []
+        for dataset_idx, dataset_name in enumerate(self.train_datasets):
+            batch = next(self.train_loader_iters[dataset_name])
+            x, y = batch
+
+            features = model(x)
+            logits = self.classification_heads[dataset_name](features)
+
+            _loss = F.cross_entropy(logits, y)
+            losses.append(_loss)
+
+        loss = self.aggregate_loss(model, ray, losses)
+        return loss
+
+    @abstractmethod
+    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        pass
+
+    def train(self):
+        # setup the model
+        num_objectives = len(self.modelpool.model_names)
+        model = deepcopy(self.model)
+        self.classification_heads = {
+            t: h.to(self.fabric.device) for t, h in self.classification_heads.items()
+        }
+
+        # set up the optimizer and learning rate scheduler
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=self.lr,
+        )
+        model, optimizer = self.fabric.setup(model, optimizer)
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer=optimizer, T_max=self.num_steps, eta_min=self.lr * 0.1
+        )
+
+        model.train()
+        device = self.fabric.device
+        for step_idx in tqdm(
+            range(1, 1 + self.num_steps), "training", dynamic_ncols=True
+        ):
+            # sample a preference ray
+            ray = torch.from_numpy(
+                np.random.dirichlet((self.alpha,) * num_objectives, 1)
+                .astype(np.float32)
+                .flatten()
+            ).to(device)
+            ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)
+
+            loss = self.compute_loss(model, ray)
+
+            optimizer.zero_grad()
+            self.fabric.backward(loss)
+            optimizer.step()
+
+            lr_scheduler.step()
+
+            self.fabric.log("loss", loss.item(), step=step_idx)
+
+            if step_idx % self.save_interval == 0 or step_idx == self.num_steps:
+                ckpt_dir = Path(self.log_dir) / "checkpoints"
+                ckpt_dir.mkdir(exist_ok=True, parents=True)
+                self.fabric.save(
+                    ckpt_dir / f"model_step={step_idx}.ckpt",
+                    {"model": model},
+                )
+        return model
+
+    def evaluate(self, model):
+        results = defaultdict(list)
+
+        num_objectives = len(self.modelpool.model_names)
+        device = self.fabric.device
+        self.classification_heads = {
+            t: h.to(self.fabric.device) for t, h in self.classification_heads.items()
+        }
+
+        if not lightning.fabric.is_wrapped(model):
+            model = self.fabric.setup_module(model)
+        model.eval()
+
+        if self.num_evaluation_samples == "equal_weight":
+            uniform_grid = np.array(
+                [[1 / num_objectives] * num_objectives], dtype=np.float32
+            )
+        else:
+            uniform_grid = generate_simplex_grid(
+                num_objectives, self.num_evaluation_samples
+            )
+        for ray_idx, ray in tqdm(enumerate(uniform_grid), "evaluating samples"):
+            results["ray_idx"].append(ray_idx)
+            # sample a preference ray
+            for i in range(len(ray)):
+                results[f"ray_{i}"].append(ray[i])
+            ray = torch.from_numpy(ray).to(device)
+            ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)
+
+            accs = []
+            for dataset_idx, dataset_name in enumerate(
+                tqdm(
+                    self.modelpool.test_dataset_names,
+                    "evaluating datasets",
+                    leave=False,
+                )
+            ):
+                test_loader = self.test_loaders[dataset_name]
+                TOTAL_CORRECT = 0
+                TOTAL_COUNT = 0
+                for batch_idx, batch in enumerate(
+                    pbar := tqdm(
+                        test_loader,
+                        f"evaluate {dataset_name}",
+                        leave=False,
+                    )
+                ):
+                    x, y = batch
+
+                    features = model(x)
+                    logits = self.classification_heads[dataset_name](features)
+                    preds = logits.argmax(-1)
+
+                    correct = (preds == y).sum().item()
+                    TOTAL_CORRECT += correct
+                    TOTAL_COUNT += len(y)
+                    acc = TOTAL_CORRECT / TOTAL_COUNT
+                    pbar.set_postfix_str(f"acc={acc:.2f}")
+
+                    if self.quick_evaluation and batch_idx > 20:
+                        break
+                results[dataset_name].append(acc)
+                accs.append(acc)
+
+            # compute the average accuracy
+            if "average" not in self.modelpool.test_dataset_names:
+                results["average"].append(np.mean(accs))
+
+            (df := pd.DataFrame(results)).to_csv(
+                Path(self.log_dir) / "result.csv", index=False
+            )
+            log.info(df)
+
+
+class PWEMoELinearScalarizationForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
+    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        loss = 0
+        for r, l in zip(ray, losses):
+            loss += r * l
+        return loss


+class PWEMoEExactParetoOptimalForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
+    epo_solver: Optional[EPOSolver] = None
+
+    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        if self.epo_solver is None:
+            num_objectives = len(self.modelpool.model_names)
+            self.epo_solver = EPOSolver(n_tasks=num_objectives, n_params=None)
+        epo_solver = self.epo_solver
+
+        losses = torch.stack(losses)
+        loss = epo_solver.get_weighted_loss(
+            losses,
+            ray,
+            tuple(filter(lambda p: p.requires_grad, model.parameters())),
+        )
+        return loss
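The hunk above is the new PWE-MoE algorithm for OpenCLIP vision encoders (per the file list, `fusion_bench/method/pwe_moe/openclip_pwe_moe.py`, +476 lines). As a rough orientation, here is a hedged construction sketch that only uses the constructor arguments shown in the hunk; all concrete values (`init_lambda=0.3`, step counts, dataloader settings) and the `modelpool` object are illustrative assumptions, not values taken from the release.

```python
# Sketch only: instantiate the new PWE-MoE algorithm for OpenCLIP.
# All numeric values and the `modelpool` below are assumptions for illustration.
from omegaconf import DictConfig

from fusion_bench.method.pwe_moe.openclip_pwe_moe import (
    PWEMoELinearScalarizationForOpenCLIP,
)

algorithm = PWEMoELinearScalarizationForOpenCLIP(
    # model architecture arguments
    partial=True,               # upscale only the MLPs; other layers merged via task arithmetic
    init_lambda=0.3,            # assumed scaling factor
    router_hidden_layers=2,
    checkpoint_path=None,       # or a path to a saved "model_step=*.ckpt"
    # training arguments
    run_train=True,
    num_steps=2000,             # assumed
    save_interval=500,          # assumed
    lr=1e-3,                    # assumed
    alpha=0.2,                  # Dirichlet concentration for sampling preference rays
    dataloader_kwargs=DictConfig({"batch_size": 16, "num_workers": 4}),
    # evaluation arguments
    run_eval=True,
    num_evaluation_samples="equal_weight",  # or an int grid size for generate_simplex_grid
    quick_evaluation=True,
)
# merged = algorithm.run(modelpool)  # modelpool: an OpenCLIPVisionModelPool instance
```

The linear-scalarization subclass weights the per-task cross-entropy losses by the sampled preference ray, while `PWEMoEExactParetoOptimalForOpenCLIP` instead aggregates them through `EPOSolver.get_weighted_loss`, as shown at the end of the hunk.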
@@ -13,6 +13,7 @@ from torch import Tensor, nn
 from tqdm.autonotebook import tqdm

 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool

 log = logging.getLogger(__name__)
@@ -279,7 +280,7 @@ def regmean_merging(
     return merged_params


-class RegMeanAlgorithm(BaseAlgorithm):
+class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     _include_module_type = [nn.Linear]
     _config_mapping = {
         "num_regmean_examples": "num_regmean_examples",
@@ -342,24 +343,31 @@ class RegMeanAlgorithm(BaseAlgorithm):
             )
             assert len(linear_modules_to_merge) > 0, "No linear modules to merge"

-
-
-
-
-
-
-
+            with (
+                self.profile("merging models"),
+                self.profile("computing regmean weights"),
+            ):
+                regmean_weights = self.get_regmean_weights(
+                    name,
+                    model,
+                    train_dataset=modelpool.load_train_dataset(name),
+                    linear_modules_to_merge=linear_modules_to_merge,
+                )
+                models_to_merge_regmean_weights_list.append(regmean_weights)
+
+        with self.profile("merging models"):
+            # merging with regmean weights
+            merged_params = merging_with_regmean_weights(
+                models_to_merge_param_dict=models_to_merge_param_dict,
+                models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
+                reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
+                weight_transpose=self.config.get("weight_transpose", True),
+            )

-
-
-            models_to_merge_param_dict=models_to_merge_param_dict,
-            models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
-            reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
-            weight_transpose=self.config.get("weight_transpose", True),
-        )
+        merged_model = modelpool.load_model("_pretrained_")
+        merged_model.load_state_dict(merged_params, strict=False)

-
-        merged_model.load_state_dict(merged_params, strict=False)
+        self.print_profile_summary()
         return merged_model

     def on_regmean_start(self):
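The hunks above make `RegMeanAlgorithm` inherit `SimpleProfilerMixin` and wrap its expensive stages in profiling contexts. In isolation, the pattern looks roughly like the sketch below; `ProfiledMergeSketch` and its placeholder "merge" are hypothetical stand-ins, not code from the package, and only `self.profile(...)` and `self.print_profile_summary()` are taken from the diff itself.

```python
# Hedged sketch of the SimpleProfilerMixin usage pattern adopted above.
# `ProfiledMergeSketch` and its trivial "merge" step are hypothetical.
from fusion_bench import BaseAlgorithm
from fusion_bench.mixins import SimpleProfilerMixin


class ProfiledMergeSketch(BaseAlgorithm, SimpleProfilerMixin):
    def run(self, modelpool):
        models = []
        for name in modelpool.model_names:
            with self.profile("loading models"):
                models.append(modelpool.load_model(name))
        with self.profile("merging models"):
            # placeholder "merge": simply keep the first model's weights
            merged_model = models[0]
            merged_model.load_state_dict(models[0].state_dict())
        # print accumulated timings per profiled region, as RegMeanAlgorithm now does
        self.print_profile_summary()
        return merged_model
```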
@@ -442,16 +442,19 @@ class SmileUpscalingAlgorithm(
             print_parameters(model)
             return model

-        with self.profile("
-
-
-
-
-
-
-
-
-
+        with self.profile("loading model"):
+            # load models and move to GPU if available
+            with self.profile("load pretrained model"):
+                pretrained_model = modelpool.load_model("_pretrained_")
+            with self.profile("load fine-tuned model"):
+                finetuned_models = [
+                    m
+                    for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+                ]
+
+            if self.config.device == "cuda" and torch.cuda.is_available():
+                pretrained_model = pretrained_model.cuda()
+                finetuned_models = [m.cuda() for m in finetuned_models]

         with self.profile("merge model"):
             model = self.merge(pretrained_model, finetuned_models)
@@ -85,7 +85,14 @@ class CLIPLayerWiseAdaMergingSurgeryAlgorithm(

         if self.config.weights is not None:
             # skip the test-time adaptation
+            merge_weight: torch.Tensor = torch.load(self.config.weights)
+            module.merge_weight.data = merge_weight.to(
+                device=module.merge_weight.device
+            )
             merged_model = copy.deepcopy(module.merge_and_unload())
+            # setup the zero-shot classification head
+            self.on_test_time_adaptation_start()
+
         else:
             with self.profile("test-time adaptation"):
                 module = self.test_time_adaptation(module)
@@ -6,7 +6,7 @@ http://arxiv.org/abs/2212.04089

 import logging
 from copy import deepcopy
-from typing import Dict, List, Mapping, TypeVar, Union  # noqa: F401
+from typing import Dict, List, Mapping, Optional, TypeVar, Union  # noqa: F401

 import torch
 from torch import nn
@@ -19,18 +19,18 @@ from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_mul,
     state_dict_sub,
 )
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.type import StateDictType, TorchModelType

 log = logging.getLogger(__name__)


 @torch.no_grad()
 def task_arithmetic_merge(
-    pretrained_model:
-    finetuned_models: List[
+    pretrained_model: TorchModelType,
+    finetuned_models: List[TorchModelType],
     scaling_factor: float,
     inplace: bool = True,
-) ->
+) -> TorchModelType:
     """
     Merges the task vectors from multiple fine-tuned models into a single pre-trained model.

@@ -46,15 +46,17 @@ def task_arithmetic_merge(
     """
     if not inplace:
         pretrained_model = deepcopy(pretrained_model)
-    task_vector: StateDictType = None
+    task_vector: Optional[StateDictType] = None
     # Calculate the total task vector
     for model in finetuned_models:
         if task_vector is None:
+            # calculate the task vector for the first model
            task_vector = state_dict_sub(
                 model.state_dict(keep_vars=True),
                 pretrained_model.state_dict(keep_vars=True),
             )
         else:
+            # calculate the task vector for the remaining models
             task_vector = state_dict_add(
                 task_vector,
                 state_dict_sub(
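The two hunks above retype `task_arithmetic_merge` with `TorchModelType` and annotate the task-vector accumulation loop. A minimal call sketch follows, under the assumption that plain `nn.Module` instances satisfy `TorchModelType`; the toy linear layers and the scaling factor are illustrative, and the formula in the comment paraphrases the function's docstring.

```python
# Hedged sketch of calling task_arithmetic_merge with toy models and assumed values.
from torch import nn

from fusion_bench.method.task_arithmetic import task_arithmetic_merge

pretrained = nn.Linear(4, 4)
finetuned = [nn.Linear(4, 4) for _ in range(2)]  # stand-ins for task-specific models

# merged parameters = pretrained + scaling_factor * sum_i (finetuned_i - pretrained)
merged = task_arithmetic_merge(
    pretrained_model=pretrained,
    finetuned_models=finetuned,
    scaling_factor=0.3,  # assumed value
    inplace=False,       # deepcopy the pretrained model instead of mutating it
)
print(type(merged).__name__)  # Linear, now carrying the merged weights
```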