fusion-bench 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
- fusion_bench/method/base_algorithm.py +1 -0
- fusion_bench/method/dawe/dawe_for_clip.py +1 -1
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/pwe_moe/module.py +2 -7
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/simple_average.py +3 -2
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/method/task_singular_vector/TSVM.py +238 -25
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
- fusion_bench/mixins/hydra_config.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +25 -1
- fusion_bench/mixins/serialization.py +18 -2
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/models/parameter_dict.py +6 -1
- fusion_bench/programs/fabric_fusion_program.py +14 -5
- fusion_bench/taskpool/base_pool.py +1 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/taskpool/dummy.py +6 -4
- fusion_bench/utils/__init__.py +2 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/pylogger.py +28 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/METADATA +8 -2
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/RECORD +104 -44
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/fabric_model_fusion.yaml +2 -2
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/top_level.txt +0 -0
fusion_bench/method/__init__.py
CHANGED
|
@@ -67,6 +67,7 @@ _import_structure = {
|
|
|
67
67
|
"CLIPTaskWiseGossipAlgorithm",
|
|
68
68
|
"FlanT5LayerWiseGossipAlgorithm",
|
|
69
69
|
],
|
|
70
|
+
"fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
|
|
70
71
|
# plug-and-play model merging methods
|
|
71
72
|
"concrete_subspace": [
|
|
72
73
|
"ConcreteTaskArithmeticAlgorithmForCLIP",
|
|
@@ -103,6 +104,7 @@ _import_structure = {
|
|
|
103
104
|
"RandomPruningForLlama",
|
|
104
105
|
"MagnitudePruningForLlama",
|
|
105
106
|
"WandaPruningForLlama",
|
|
107
|
+
"SparseGPTPruningForLlama",
|
|
106
108
|
],
|
|
107
109
|
"sparselo": [
|
|
108
110
|
"IterativeSparseLoForLlama",
|
|
@@ -141,6 +143,7 @@ if TYPE_CHECKING:
|
|
|
141
143
|
WeightedEnsembleAlgorithm,
|
|
142
144
|
)
|
|
143
145
|
from .fisher_merging import FisherMergingForCLIPVisionModel
|
|
146
|
+
from .fw_merging import FrankWolfeHardAlgorithm, FrankWolfeSoftAlgorithm
|
|
144
147
|
from .gossip import (
|
|
145
148
|
CLIPLayerWiseGossipAlgorithm,
|
|
146
149
|
CLIPTaskWiseGossipAlgorithm,
|
|
@@ -172,6 +175,7 @@ if TYPE_CHECKING:
|
|
|
172
175
|
MagnitudeDiffPruningAlgorithm,
|
|
173
176
|
MagnitudePruningForLlama,
|
|
174
177
|
RandomPruningForLlama,
|
|
178
|
+
SparseGPTPruningForLlama,
|
|
175
179
|
WandaPruningForLlama,
|
|
176
180
|
)
|
|
177
181
|
from .pwe_moe import (
|
|
@@ -29,7 +29,7 @@ from fusion_bench.models.wrappers.layer_wise_fusion import (
|
|
|
29
29
|
get_layer_wise_weights,
|
|
30
30
|
)
|
|
31
31
|
from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
|
|
32
|
-
from fusion_bench.utils.
|
|
32
|
+
from fusion_bench.utils.instantiate_utils import instantiate
|
|
33
33
|
|
|
34
34
|
from .entropy_loss import entropy_loss
|
|
35
35
|
from .min_norm_solvers import MinNormSolver
|
|
@@ -29,7 +29,7 @@ from fusion_bench.models.wrappers.layer_wise_fusion import (
|
|
|
29
29
|
get_layer_wise_weights,
|
|
30
30
|
)
|
|
31
31
|
from fusion_bench.utils.data import InfiniteDataLoader, load_tensor_from_file
|
|
32
|
-
from fusion_bench.utils.
|
|
32
|
+
from fusion_bench.utils.instantiate_utils import instantiate
|
|
33
33
|
|
|
34
34
|
from .entropy_loss import entropy_loss
|
|
35
35
|
from .min_norm_solvers import MinNormSolver
|
|
@@ -23,7 +23,7 @@ from fusion_bench.mixins import CLIPClassificationMixin
|
|
|
23
23
|
from fusion_bench.modelpool import CLIPVisionModelPool
|
|
24
24
|
from fusion_bench.utils import timeit_context
|
|
25
25
|
from fusion_bench.utils.data import InfiniteDataLoader
|
|
26
|
-
from fusion_bench.utils.
|
|
26
|
+
from fusion_bench.utils.instantiate_utils import instantiate
|
|
27
27
|
|
|
28
28
|
from .warppers.dawe_model import DataAdaptiveWeightEnsemblingCLIPVisionModel
|
|
29
29
|
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from typing import Optional
|
|
3
3
|
|
|
4
|
+
from transformers import PreTrainedModel
|
|
4
5
|
from typing_extensions import override
|
|
5
6
|
|
|
6
|
-
from fusion_bench.modelpool.causal_lm.causal_lm import
|
|
7
|
+
from fusion_bench.modelpool.causal_lm.causal_lm import CausalLMPool
|
|
7
8
|
from fusion_bench.utils import timeit_context
|
|
8
9
|
|
|
9
10
|
from .depth_upscaling import DepthUpscalingAlgorithm
|
|
@@ -46,7 +47,7 @@ class DepthUpscalingForLlama(DepthUpscalingAlgorithm):
|
|
|
46
47
|
if self.model_save_path is not None:
|
|
47
48
|
tokenizer = modelpool.load_tokenizer()
|
|
48
49
|
|
|
49
|
-
model:
|
|
50
|
+
model: PreTrainedModel = modelpool.load_pretrained_or_first_model()
|
|
50
51
|
model.model.layers = super().run(model.model.layers)
|
|
51
52
|
model.config.num_hidden_layers = len(model.model.layers)
|
|
52
53
|
|
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This script contains the implementation of the hard Frank-Wolfe (FW) merging method, which uses Task Arithmetic or TIES merging as its inner merge function.
|
|
3
|
+
|
|
4
|
+
http://arxiv.org/abs/2212.04089
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
from abc import abstractmethod
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from copy import deepcopy
|
|
13
|
+
from functools import partial
|
|
14
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, TypeVar, Union
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
18
|
+
from omegaconf import DictConfig
|
|
19
|
+
from torch import Tensor, nn
|
|
20
|
+
from torch.utils.data import DataLoader
|
|
21
|
+
from tqdm.autonotebook import tqdm
|
|
22
|
+
|
|
23
|
+
from fusion_bench.compat.method import ModelFusionAlgorithm
|
|
24
|
+
from fusion_bench.compat.modelpool import HuggingFaceClipVisionPool, ModelPool
|
|
25
|
+
from fusion_bench.dataset.clip_dataset import CLIPDataset
|
|
26
|
+
from fusion_bench.mixins import CLIPClassificationMixin
|
|
27
|
+
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
|
|
28
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
29
|
+
from fusion_bench.models.wrappers.layer_wise_fusion import (
|
|
30
|
+
LayerWiseMergedModel,
|
|
31
|
+
get_layer_wise_weights,
|
|
32
|
+
)
|
|
33
|
+
from fusion_bench.utils.data import load_tensor_from_file
|
|
34
|
+
from fusion_bench.utils.type import TorchModelType
|
|
35
|
+
|
|
36
|
+
from .utils import *
|
|
37
|
+
|
|
38
|
+
if TYPE_CHECKING:
|
|
39
|
+
from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram
|
|
40
|
+
|
|
41
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
42
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
43
|
+
from fusion_bench.utils import instantiate
|
|
44
|
+
from fusion_bench.utils.data import InfiniteDataLoader
|
|
45
|
+
from fusion_bench.utils.state_dict_arithmetic import (
|
|
46
|
+
state_dict_add,
|
|
47
|
+
state_dict_mul,
|
|
48
|
+
state_dict_sub,
|
|
49
|
+
)
|
|
50
|
+
from fusion_bench.utils.type import StateDictType
|
|
51
|
+
|
|
52
|
+
log = logging.getLogger(__name__)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@torch.no_grad()
def task_arithmetic_merge(
    pretrained_model: nn.Module,
    finetuned_models: List[Dict[str, Tensor]],
    scaling_factor: float,
    inplace: bool = True,
) -> nn.Module:
    """
    Add the scaled sum of task vectors to a pre-trained model.

    A task vector is obtained for every fine-tuned model by calling
    ``state_dict_sub(finetuned, pretrained)``. The task vectors are summed
    with ``state_dict_add``, multiplied by ``scaling_factor`` with
    ``state_dict_mul`` and the result is added back onto the pre-trained
    state dict, which is then loaded into the model.

    Args:
        pretrained_model (nn.Module): The model that receives the merged
            task vector.
        finetuned_models (List[Dict[str, Tensor]] | List[nn.Module]): The
            fine-tuned checkpoints. If the first element is an
            ``nn.Module``, every element is replaced by a deep copy of its
            ``state_dict(keep_vars=True)``.
        scaling_factor (float): Multiplier applied to the summed task vector.
        inplace (bool, optional): If True, ``pretrained_model`` itself is
            updated. If False, a deep copy is updated instead.
            Defaults to True.

    Returns:
        nn.Module: The (possibly copied) pre-trained model carrying the
        merged weights.
    """
    base_model = pretrained_model if inplace else deepcopy(pretrained_model)

    # Accept modules as well as plain state dicts.
    if isinstance(finetuned_models[0], nn.Module):
        finetuned_models = [
            deepcopy(module.state_dict(keep_vars=True)) for module in finetuned_models
        ]

    # Accumulate the sum of (finetuned - pretrained) over all checkpoints.
    task_vector = None
    for finetuned_state in finetuned_models:
        delta = state_dict_sub(
            finetuned_state,
            base_model.state_dict(keep_vars=True),
        )
        if task_vector is None:
            task_vector = delta
        else:
            task_vector = state_dict_add(task_vector, delta)

    # Scale the summed task vector, then apply it to the base weights.
    task_vector = state_dict_mul(task_vector, scaling_factor)
    merged_state = state_dict_add(base_model.state_dict(keep_vars=True), task_vector)
    base_model.load_state_dict(merged_state)
    return base_model
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@torch.no_grad()
def ties_merge(
    pretrained_model: nn.Module,
    finetuned_models: List[Dict[str, Tensor]],
    scaling_factor: float,
    threshold: float,
) -> nn.Module:
    """
    Merge fine-tuned checkpoints into ``pretrained_model`` with TIES merging.

    Every state dict is flattened with ``state_dict_to_vector``; the
    flattened pre-trained weights are subtracted from the stacked fine-tuned
    weights to obtain task vectors, which are combined by ``ties_merging``
    (``merge_func="sum"``). The combined vector is scaled, added to the
    flattened pre-trained weights, converted back with
    ``vector_to_state_dict`` and loaded into ``pretrained_model``.

    Args:
        pretrained_model (nn.Module): The model to update. It is modified
            in place.
        finetuned_models (List[Dict[str, Tensor]] | List[nn.Module]): The
            fine-tuned checkpoints. If the first element is an
            ``nn.Module``, every element is replaced by a deep copy of its
            ``state_dict(keep_vars=True)``.
        scaling_factor (float): Multiplier applied to the merged task vector.
        threshold (float): Forwarded to ``ties_merging`` as ``reset_thresh``.

    Returns:
        nn.Module: ``pretrained_model`` with the merged weights loaded.
    """
    # No keys are excluded; the same list is handed to every helper call.
    excluded_keys = []

    if isinstance(finetuned_models[0], nn.Module):
        finetuned_models = [
            deepcopy(module.state_dict(keep_vars=True)) for module in finetuned_models
        ]

    base_state = pretrained_model.state_dict(keep_vars=True)

    # Flatten everything and form one task vector per row.
    finetuned_vectors = [
        state_dict_to_vector(state, excluded_keys) for state in finetuned_models
    ]
    base_vector = state_dict_to_vector(base_state, excluded_keys)
    task_vectors = torch.vstack(finetuned_vectors) - base_vector

    # Combine the task vectors with TIES and apply them to the base weights.
    merged_task_vector = ties_merging(
        task_vectors,
        reset_thresh=threshold,
        merge_func="sum",
    )
    merged_vector = base_vector + scaling_factor * merged_task_vector
    merged_state = vector_to_state_dict(
        merged_vector, base_state, remove_keys=excluded_keys
    )

    pretrained_model.load_state_dict(merged_state)
    return pretrained_model
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def entropy_loss(logits: Tensor, pred=None, eps: float = 1e-8) -> Tensor:
    """
    Return the mean Shannon entropy of the softmax distribution of ``logits``.

    Args:
        logits (Tensor): A 2-dimensional tensor; softmax is taken over the
            last dimension and the entropy is averaged over the first.
        pred: Unused. Accepted so the function can be called with the same
            ``(logits, target)`` arguments as a supervised loss.
        eps (float): Added to the probabilities before the logarithm to
            avoid log(0). Default is 1e-8.

    Returns:
        Tensor: A scalar tensor holding the mean entropy.
    """
    # Only 2-dimensional logits are supported.
    assert (
        logits.dim() == 2
    ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"

    probabilities = torch.softmax(logits, dim=-1)
    log_probabilities = torch.log(probabilities + eps)

    # Sum p * log(p) over the last dimension, average over rows, then negate.
    summed = torch.sum(probabilities * log_probabilities, dim=-1)
    return -summed.mean()
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class FrankWolfeHardAlgorithm(
    CLIPClassificationMixin,
    ModelFusionAlgorithm,
    SimpleProfilerMixin,
):
    """
    Frank-Wolfe ("hard") model merging for CLIP vision models.

    Starting from an initial merged model, every iteration computes the
    gradient of a loss on the merged model, selects the checkpoint (or, per
    parameter, the checkpoint tensor) whose cosine alignment with that
    gradient is smallest, and merges the selection into the current model
    with the configured merge function using a decaying step size.
    """

    def __init__(
        self,
        merge_fn: str,
        step_size: float,
        max_iters: int,
        dataset_size: int,
        tasks: Union[List[str], None] = None,
        granularity: str = "task",
        max_num_models: int = 100,
        loss_fn: str = "cross_entropy",
        init_weight: str = "",
        scaling_factor: float = 1.0,
        threshold: int = 20,
        **kwargs,
    ):
        """
        Initializes the Frank-Wolfe hard merging algorithm.

        Args:
            merge_fn (str): Inner merge function, ``"task_arithmetic"`` or
                ``"ties"``.
            step_size (float): Base step size; iteration ``k`` (0-based)
                uses ``2 / (k + 2) * step_size``.
            max_iters (int): Number of Frank-Wolfe iterations.
            dataset_size (int): Number of batches (of size 1) drawn per task
                in every iteration.
            tasks (List[str], optional): Tasks used to compute the loss.
                When empty or omitted, all model names of the model pool
                are used.
            granularity (str): ``"task"`` selects one whole checkpoint per
                iteration; ``"layer"`` selects per parameter.
            max_num_models (int): Upper bound on the number of models loaded
                from the model pool.
            loss_fn (str): ``"cross_entropy"`` or ``"entropy"``.
            init_weight (str): ``""`` initializes with the merge function,
                ``"base"`` with the pre-trained model; any other string is
                treated as a path to layer-wise weights.
            scaling_factor (float): The factor by which the task vectors
                will be scaled before merging.
            threshold (int): Passed to ``ties_merge`` when
                ``merge_fn == "ties"``.
            **kwargs: Forwarded to the parent initializer.

        Raises:
            ValueError: If ``merge_fn`` is not supported.
        """
        self.merger = merge_fn
        if merge_fn == "task_arithmetic":
            self.merge_fn = task_arithmetic_merge
        elif merge_fn == "ties":
            self.merge_fn = partial(ties_merge, threshold=threshold)
        else:
            raise ValueError(f"Unsupported merge_fn: {merge_fn}")
        self.scaling_factor = scaling_factor

        self.init_weight = init_weight
        self.step_size = step_size
        self.max_iters = max_iters
        self.granularity = granularity
        self.loss_fn = loss_fn
        # A fresh list per instance instead of a shared mutable default.
        self.tasks = tasks if tasks is not None else []
        self.dataset_size = dataset_size
        self.max_num_models = max_num_models
        super().__init__(**kwargs)

    def on_frank_wolfe_iteration_start(self):
        """Set up the zero-shot classification head before the iterations."""
        self.setup_zero_shot_classification_head()

    def get_shuffled_loader_iter(self, task: str):
        """
        Return the infinite, shuffled, batch-size-1 data iterator for ``task``.

        The iterator is created on first use and stored on the instance, so
        repeated calls for the same task continue the same iterator. The
        store lives on the instance rather than in a ``functools.cache``
        keyed on ``self``, which would keep the instance alive for the
        lifetime of the process.

        Raises:
            ValueError: If ``self.loss_fn`` is not supported.
        """
        iterators = self.__dict__.setdefault("_shuffled_loader_iters", {})
        if task not in iterators:
            iterators[task] = self._build_shuffled_loader_iter(task)
        return iterators[task]

    def _build_shuffled_loader_iter(self, task: str):
        """Create a new data iterator for ``task`` according to ``self.loss_fn``."""
        if self.loss_fn == "cross_entropy":
            # get dataloader kwargs
            dataloader_kwargs = self._dataloader_kwargs.copy()
            dataloader_kwargs["shuffle"] = True
            dataloader_kwargs["batch_size"] = 1

            # get the train dataset
            clip_dataset = CLIPDataset(
                self.modelpool.load_train_dataset(task), self.clip_processor
            )
            # create the dataloader
            loader = DataLoader(clip_dataset, **dataloader_kwargs)
            loader = self.fabric.setup_dataloaders(loader)
            return iter(InfiniteDataLoader(loader))
        elif self.loss_fn == "entropy":
            return super().get_shuffled_test_loader_iter(
                task,
                batch_size=1,
            )
        else:
            raise ValueError(f"Unsupported loss function: {self.loss_fn}")

    def frank_wolfe_iteration(self, merged_model):
        """
        Accumulate loss gradients of ``merged_model`` over all tasks.

        For every task, ``self.dataset_size`` batches are drawn and the loss,
        divided by ``dataset_size * len(self.modelpool.model_names)``, is
        back-propagated. The model is put in train mode for the pass and
        back in eval mode afterwards; gradients are cleared before returning.

        Args:
            merged_model: The current merged model.

        Returns:
            dict: Parameter name -> CPU copy of the accumulated gradient.

        Raises:
            ValueError: If ``self.loss_fn`` is not supported.
        """
        merged_model.train()
        # enable and zero the gradients
        for param in merged_model.parameters():
            param.requires_grad = True
            param.grad = None

        if self.loss_fn == "cross_entropy":
            loss_fn = nn.CrossEntropyLoss()
        elif self.loss_fn == "entropy":
            loss_fn = entropy_loss
        else:
            # Fail here instead of leaving `loss_fn` unbound.
            raise ValueError(f"Unsupported loss function: {self.loss_fn}")

        avg_loss = defaultdict(list)
        tasks = self.tasks if self.tasks else self.modelpool.model_names
        # Normalization constant shared by every per-batch loss.
        normalizer = self.dataset_size * len(self.modelpool.model_names)
        for task in tasks:
            log.info(f"Processing task {task}")
            for _ in range(self.dataset_size):
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(merged_model, batch[0], task)
                    loss = loss_fn(logits, batch[1]) / normalizer
                with self.profile("backward pass"):
                    loss.backward()
                avg_loss[task].append(loss.item())

        # calculate the loss
        avg_loss = {
            task: sum(losses) / len(losses) for task, losses in avg_loss.items()
        }
        log.info(
            f"Average Loss: {avg_loss}, Total Loss: {sum(avg_loss.values()) / len(avg_loss)}"
        )

        gradients = {
            name: param.grad.clone().to("cpu")
            for name, param in merged_model.named_parameters()
            if param.requires_grad
        }
        for param in merged_model.parameters():
            param.grad = None
        merged_model.eval()

        return gradients

    @staticmethod
    def _gradient_alignment(grad, ckpt):
        """Return the cosine similarity between ``grad`` and ``ckpt`` (flattened)."""
        return torch.dot(grad.flatten(), ckpt.flatten()) / (
            torch.norm(grad) * torch.norm(ckpt)
        )

    def frank_wolfe_selection(
        self, gradients, checkpoints, model_to_merge_names=None, type="task"
    ):
        """
        Select the checkpoint (or per-parameter tensors) least aligned with the gradient.

        Args:
            gradients (dict): Parameter name -> gradient tensor.
            checkpoints (dict): Model name -> model. Every model is moved to
                CPU as a side effect.
            model_to_merge_names: Names excluded from selection. For
                ``type == "task"`` a container of model names; for
                ``type == "layer"`` a mapping from parameter name to a
                container of model names. ``None`` excludes nothing.
            type (str): ``"task"`` or ``"layer"``.

        Returns:
            tuple: ``(min_model, min_model_name, min_inner_product, log_dict)``.
            In "task" mode these are the selected state dict, its name, its
            summed alignment and the per-model summed alignments. In "layer"
            mode they are a dict of selected tensors, a dict of the selected
            model name per parameter, the sum of the selected alignments and
            the number of parameters selected from each model.
        """
        assert type in [
            "task",
            "layer",
        ], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']"
        excluded = {} if model_to_merge_names is None else model_to_merge_names
        min_inner_product = float("inf")
        min_model = None
        min_model_name = None
        log_dict = {}
        if type == "task":
            for model_name, model_to_merge in checkpoints.items():
                model_to_merge = model_to_merge.to("cpu").state_dict()
                inner_product_sum = 0
                for param_name, param_value in model_to_merge.items():
                    # calculate cosine similarity
                    inner_product_sum += self._gradient_alignment(
                        gradients[param_name], param_value
                    )
                log_dict[model_name] = inner_product_sum.item()
                if inner_product_sum < min_inner_product and model_name not in excluded:
                    min_inner_product = inner_product_sum
                    min_model = deepcopy(model_to_merge)
                    min_model_name = model_name
        else:
            min_model = {}
            min_inner_product = {}
            min_model_name = {}
            for model_name, model_to_merge in checkpoints.items():
                model_to_merge = model_to_merge.to("cpu").state_dict()
                for param_name, param_value in model_to_merge.items():
                    # calculate cosine similarity
                    param_alignment = self._gradient_alignment(
                        gradients[param_name], param_value
                    )
                    is_better = (
                        param_name not in min_inner_product
                        or param_alignment < min_inner_product[param_name]
                    )
                    if is_better and model_name not in excluded.get(param_name, ()):
                        min_inner_product[param_name] = param_alignment
                        min_model[param_name] = param_value
                        min_model_name[param_name] = model_name
            min_inner_product = sum(min_inner_product.values())
            # Count how many parameters were taken from each checkpoint.
            log_dict = dict.fromkeys(checkpoints, 0)
            for chosen_name in min_model_name.values():
                log_dict[chosen_name] += 1

        return min_model, min_model_name, min_inner_product, log_dict

    def run(self, modelpool: HuggingFaceClipVisionPool):
        """
        Run Frank-Wolfe merging on the models of ``modelpool``.

        Args:
            modelpool (HuggingFaceClipVisionPool): Pool that provides the
                pre-trained model, the fine-tuned models and the datasets.

        Returns:
            The merged model, moved to CUDA and set to eval mode.

        Raises:
            AssertionError: If the pool has no pre-trained model.
            ValueError: If ``self.init_weight`` has an unsupported format.

        Note:
            Gradient computation is left globally disabled
            (``torch.set_grad_enabled(False)``) when this method returns.
        """
        log.info("Fusing models using FW merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)
        self.on_frank_wolfe_iteration_start()

        assert modelpool.has_pretrained, "Pretrained model is required."
        finetuned_models = {
            name: modelpool.load_model(name)
            for name in modelpool.model_names[: self.max_num_models]
        }
        # Only used directly when init_weight == "base".
        pretrained_model = modelpool.load_model("_pretrained_")

        if self.init_weight:
            if self.init_weight == "base":
                log.info("Initializing the merged model with the base model")
                merged_model = pretrained_model
            else:
                log.info("Initializing the merged model with the initial weight")
                if isinstance(self.init_weight, str):
                    # self.init_weight is a path to a saved tensor
                    layer_wise_weight = load_tensor_from_file(self.init_weight)
                else:
                    raise ValueError(f"Unsupported weights format: {self.init_weight}")

                merged_model = LayerWiseMergedModel(
                    layer_wise_weight=layer_wise_weight,
                    pretrained_model=modelpool.load_model("_pretrained_"),
                    finetuned_models=list(finetuned_models.values()),
                    clamp_weights=False,
                    tie_weights=True,
                    strict=False,
                ).cuda()
                merged_model = merged_model.merge_and_unload()
        else:
            log.info("Initializing the merged model with merge function")
            merged_model = self.merge_fn(
                pretrained_model=modelpool.load_model("_pretrained_"),
                finetuned_models=list(finetuned_models.values()),
                scaling_factor=self.scaling_factor,
            ).cuda()

        # Register a snapshot of the initial merged model as a candidate.
        initial_model = modelpool.load_model("_pretrained_")
        initial_model.load_state_dict(deepcopy(merged_model.state_dict()))
        finetuned_models["initial"] = initial_model
        for step_idx in (
            pbar := tqdm(
                range(self.max_iters if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "") + "Frank-Wolfe Merging",
                dynamic_ncols=True,
            )
        ):
            torch.cuda.empty_cache()
            torch.set_grad_enabled(True)
            gradients = self.frank_wolfe_iteration(merged_model.cuda())
            torch.set_grad_enabled(False)
            grad_norm = torch.norm(
                torch.stack([torch.norm(g) for g in gradients.values()])
            )

            # Nothing is excluded: an empty list for "task" granularity,
            # an empty list per parameter otherwise.
            model_to_merge_names = (
                []
                if self.granularity == "task"
                else {name: [] for name in merged_model.state_dict().keys()}
            )
            min_model, min_model_name, min_alignment, chosen_model = (
                self.frank_wolfe_selection(
                    gradients,
                    finetuned_models,
                    model_to_merge_names=model_to_merge_names,
                    type=self.granularity,
                )
            )

            # Determine step size
            step = 2 / (step_idx + 2) * self.step_size

            # print iteration information
            log.info(
                f"Iteration {step_idx+1}, Task Vector: {min_model_name}, Gradient Norm: {grad_norm:.6f}, Inner Products: {min_alignment:.6f}, Chosen Model: {chosen_model}"
            )

            # Merge the selection into the current model, which acts as the
            # "pretrained" model of the merge function.
            merged_model = self.merge_fn(
                pretrained_model=merged_model.to("cpu"),
                finetuned_models=[min_model],
                scaling_factor=step * self.scaling_factor,
            )

        torch.set_grad_enabled(False)
        merged_model = merged_model.cuda().eval()
        return merged_model
|