fusion-bench 0.2.15__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff reflects the contents of the two publicly released package versions as published to their registry, and is provided for informational purposes only.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/programs/fabric_fusion_program.py +5 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/METADATA +6 -2
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/RECORD +77 -21
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.16.dist-info}/top_level.txt +0 -0
fusion_bench/method/fw_merging/fw_soft.py (new file, +519 lines)
@@ -0,0 +1,519 @@
"""
This script contains the implementation of the soft Frank-Wolfe merging method
(FrankWolfeSoftAlgorithm), which builds on the Task Arithmetic method.

http://arxiv.org/abs/2212.04089
"""

import functools
import logging
import os
from abc import abstractmethod
from collections import defaultdict
from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, TypeVar, Union

import torch
from lightning.fabric.utilities.rank_zero import rank_zero_only
from omegaconf import DictConfig
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

from fusion_bench.compat.method import ModelFusionAlgorithm
from fusion_bench.compat.modelpool import HuggingFaceClipVisionPool, ModelPool
from fusion_bench.dataset.clip_dataset import CLIPDataset
from fusion_bench.mixins import CLIPClassificationMixin
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
from fusion_bench.models.wrappers.layer_wise_fusion import (
    LayerWiseMergedModel,
    get_layer_wise_weights,
)
from fusion_bench.utils.data import load_tensor_from_file
from fusion_bench.utils.type import TorchModelType

from .utils import *

if TYPE_CHECKING:
    from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram

from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
from fusion_bench.modelpool import BaseModelPool
from fusion_bench.utils import instantiate
from fusion_bench.utils.data import InfiniteDataLoader
from fusion_bench.utils.state_dict_arithmetic import (
    state_dict_add,
    state_dict_mul,
    state_dict_sub,
)
from fusion_bench.utils.type import StateDictType

log = logging.getLogger(__name__)


def projection_simplex_sort(v, z=1):
    # print(v.shape)
    n_features = v.shape[0]  # Get the number of elements in v
    u, _ = torch.sort(v, descending=True)  # Sort v in descending order
    cssv = torch.cumsum(u, dim=0) - z  # Compute cumulative sum and subtract z
    ind = torch.arange(
        1, n_features + 1, dtype=torch.long, device=v.device
    )  # Create index tensor (1 to n_features)
    cond = u - cssv / ind > 0  # Condition to find rho
    if cond.any():  # Ensure there is at least one valid rho
        rho = ind[cond][-1]  # Find the largest index satisfying the condition
        theta = cssv[rho - 1] / rho  # Compute the correct threshold theta
    else:
        theta = 0  # Default case when all values are zero or negative
    w = torch.clamp(
        v - theta, min=0
    )  # Compute the projected vector, ensuring non-negativity
    return w
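
# Usage note (editorial annotation, not part of the fw_soft.py diff above):
# projection_simplex_sort performs a sort-based Euclidean projection onto the
# simplex {w : w >= 0, sum(w) = z}, so the output is non-negative and sums to z.
# A minimal check, assuming the function defined above is in scope:
#
#   w = projection_simplex_sort(torch.tensor([0.8, 0.6, -0.2]))
#   print(w)        # tensor([0.6000, 0.4000, 0.0000])
#   print(w.sum())  # tensor(1.)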


@torch.no_grad()
def task_arithmetic_merge(
    pretrained_model: nn.Module,
    finetuned_models: List[Dict[str, Tensor]],
    scaling_factor: float,
    inplace: bool = True,
) -> nn.Module:
    """
    Merges the task vectors from multiple fine-tuned models into a single pre-trained model.

    Args:
        pretrained_model (nn.Module): The pre-trained model to which the task vectors will be added.
        finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated.
        scaling_factor (float): A factor by which the task vectors will be scaled before merging.
        inplace (bool, optional): If True, the pre-trained model will be modified in place.
            If False, a copy of the pre-trained model will be modified. Defaults to True.

    Returns:
        nn.Module: The pre-trained model with the merged task vectors.
    """
    if not inplace:
        pretrained_model = deepcopy(pretrained_model)
    if isinstance(finetuned_models[0], nn.Module):
        finetuned_models = [
            deepcopy(model.state_dict(keep_vars=True)) for model in finetuned_models
        ]
    task_vector: StateDictType = None
    # Calculate the total task vector
    for model in finetuned_models:
        if task_vector is None:
            task_vector = state_dict_sub(
                model,
                pretrained_model.state_dict(keep_vars=True),
            )
        else:
            task_vector = state_dict_add(
                task_vector,
                state_dict_sub(
                    model,
                    pretrained_model.state_dict(keep_vars=True),
                ),
            )
    # scale the task vector
    task_vector = state_dict_mul(task_vector, scaling_factor)
    # add the task vector to the pretrained model
    state_dict = state_dict_add(
        pretrained_model.state_dict(keep_vars=True), task_vector
    )
    pretrained_model.load_state_dict(state_dict)
    return pretrained_model
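
# Usage sketch (editorial annotation, not part of the fw_soft.py diff above):
# task_arithmetic_merge accepts either nn.Module instances or plain state dicts
# for `finetuned_models`. With two toy "fine-tuned" models, a non-destructive
# merge might look like:
#
#   base = nn.Linear(4, 2)
#   experts = [nn.Linear(4, 2), nn.Linear(4, 2)]
#   merged = task_arithmetic_merge(base, experts, scaling_factor=0.3, inplace=False)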


def entropy_loss(logits: Tensor, pred=None, eps: float = 1e-8) -> Tensor:
    """
    Compute the entropy loss of a set of logits.

    Args:
        logits (Tensor): The logits to compute the entropy loss of.
        eps (float): A small value to avoid log(0). Default is 1e-8.

    Returns:
        Tensor: The entropy loss of the logits.
    """
    # Ensure the logits tensor has 2 dimensions
    assert (
        logits.dim() == 2
    ), f"Expected logits to have 2 dimensions, found {logits.dim()}, {logits.size()=}"

    # Compute the softmax probabilities
    probs = torch.softmax(logits, dim=-1)

    # Compute the entropy loss
    return -torch.sum(probs * torch.log(probs + eps), dim=-1).mean()
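
# Numerical illustration (editorial annotation, not part of the diff above): for
# uniform logits the mean prediction entropy equals log(num_classes), e.g.
#
#   entropy_loss(torch.zeros(2, 3))  # ~ log(3) ~ 1.0986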


class FrankWolfeSoftAlgorithm(
    CLIPClassificationMixin,
    ModelFusionAlgorithm,
    SimpleProfilerMixin,
):
    def __init__(
        self,
        max_iters: int,
        dataset_size: int,
        ada_iters: int,
        ada_coeff: float,
        merge_fn: str,
        granularity: str = "task",
        max_num_models: int = 100,
        step_size: float = 0.3,
        tasks: List[str] = [],
        init_weight: str = "",
        ada_loss="entropy_loss",
        **kwargs,
    ):
        """
        Initializes the FrankWolfeSoftAlgorithm.

        Args:
            max_iters (int): Number of Frank-Wolfe iterations.
            dataset_size (int): Number of batches sampled per task when estimating gradients.
            ada_iters (int): Number of AdaMerging optimization steps per Frank-Wolfe iteration.
            ada_coeff (float): Initial value of the layer-wise merging weights used by AdaMerging after the first iteration.
            merge_fn (str): Merging function; "adamerging" or a task-arithmetic style update.
            granularity (str): Frank-Wolfe selection granularity, "task" or "layer".
            max_num_models (int): Maximum number of models to merge.
            step_size (float): Scaling factor for the Frank-Wolfe step size.
            tasks (List[str]): Tasks to merge; defaults to all tasks in the model pool.
            init_weight (str): Path to initial layer-wise weights, or "base"/"" to start from the pretrained model.
            ada_loss (str): Loss used by AdaMerging, "entropy_loss" or cross-entropy.
        """
        self.merge_fn = merge_fn

        self.init_weight = init_weight
        self.max_iters = max_iters
        self.ada_iters = ada_iters
        self.ada_coeff = ada_coeff
        self.granularity = granularity
        self.tasks = tasks
        self.step_size = step_size
        self.dataset_size = dataset_size
        self.max_num_models = max_num_models
        self.ada_loss = ada_loss
        super().__init__(**kwargs)

    def on_frank_wolfe_iteration_start(self):
        self.setup_zero_shot_classification_head()

    @functools.cache
    def get_shuffled_train_loader_iter(self, task: str, batch_size: int = 1):
        # get dataloader kwargs
        dataloader_kwargs = self._dataloader_kwargs.copy()
        dataloader_kwargs["shuffle"] = True
        dataloader_kwargs["batch_size"] = batch_size

        # get the train dataset
        clip_dataset = CLIPDataset(
            self.modelpool.load_train_dataset(task), self.clip_processor
        )
        # create the dataloader
        loader = DataLoader(clip_dataset, **dataloader_kwargs)
        loader = self.fabric.setup_dataloaders(loader)
        return iter(InfiniteDataLoader(loader))

    @functools.cache
    def get_shuffled_test_loader_iter(self, task: str, batch_size: int = 1):
        return super().get_shuffled_test_loader_iter(task, batch_size=batch_size)

    def run_adamerging(self, module):
        use_entropy_loss = self.ada_loss == "entropy_loss"

        optimizer = torch.optim.Adam([module.merge_weight], lr=1e-3)
        module, optimizer = self.fabric.setup(module, optimizer)
        module.train()
        for step_idx in (
            pbar := tqdm(
                range(self.ada_iters),
                "AdaMerging (2/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            with self.profile("merge weights"):
                module.merge_weights()

            metrics = {}
            total_loss = None
            tasks = self.modelpool.model_names if self.tasks == [] else self.tasks
            if not use_entropy_loss:
                loss_fn = nn.CrossEntropyLoss()
            for task in tasks:
                with self.profile("data loading"):
                    if use_entropy_loss:
                        batch = next(
                            self.get_shuffled_test_loader_iter(task, batch_size=16)
                        )
                    else:
                        batch = next(
                            self.get_shuffled_train_loader_iter(task, batch_size=16)
                        )
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    if use_entropy_loss:
                        loss = entropy_loss(logits)
                    else:
                        loss = loss_fn(logits, batch[1])
                    total_loss = loss if total_loss is None else total_loss + loss

            optimizer.zero_grad()
            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                optimizer.step()

            metrics.update({"train/loss": loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)
        return module

    def frank_wolfe_iteration(self, merged_model, task):

        merged_model.train()
        # zero the gradients
        requires_grad_dict = {}
        for name, param in merged_model.named_parameters():
            requires_grad_dict[name] = param.requires_grad
            param.requires_grad = True
            param.grad = None

        loss_fn = nn.CrossEntropyLoss()
        avg_loss = defaultdict(list)
        log.info(f"Processing task {task}")
        for i in range(self.dataset_size):
            with self.profile("data loading"):
                batch = next(self.get_shuffled_train_loader_iter(task))
            with self.profile("forward pass"):
                logits = self.compute_logits(merged_model, batch[0], task)
                loss = loss_fn(logits, batch[1]) / (
                    self.dataset_size * len(self.modelpool.model_names)
                )
            with self.profile("backward pass"):
                loss.backward()
            avg_loss[task].append(loss.item())

        # calculate the loss
        avg_loss = {
            task: sum(losses) / len(losses) for task, losses in avg_loss.items()
        }
        log.info(
            f"Average Loss: {avg_loss}, Total Loss: {sum(avg_loss.values()) / len(avg_loss)}"
        )

        gradients = {
            name: param.grad.clone().to("cpu")
            for name, param in merged_model.named_parameters()
            if param.requires_grad
        }
        for name, param in merged_model.named_parameters():
            param.requires_grad = requires_grad_dict[name]
            param.grad = None
        merged_model.eval()

        return gradients

    def frank_wolfe_selection(
        self, gradients, checkpoints, model_to_merge_names=[], type="task"
    ):
        assert type in [
            "task",
            "layer",
        ], f"Unsupported FW selection type: {type}, supported types are ['task', 'layer']"
        min_inner_product = float("inf")
        min_model = None
        min_model_name = None
        log_dict = {}
        if type == "task":
            for model_name, model_to_merge in checkpoints.items():
                model_to_merge = model_to_merge.to("cpu").state_dict()
                inner_product_sum = 0
                for param_name, param_value in model_to_merge.items():
                    # calculate cosine similarity
                    grad = gradients[param_name]
                    ckpt = model_to_merge[param_name]
                    param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
                        torch.norm(grad) * torch.norm(ckpt)
                    )
                    inner_product_sum += param_alignment
                log_dict[model_name] = inner_product_sum.item()
                if (
                    inner_product_sum < min_inner_product
                    and model_name not in model_to_merge_names
                ):
                    min_inner_product = inner_product_sum
                    min_model = deepcopy(model_to_merge)
                    min_model_name = model_name
        else:
            min_model = {}
            min_inner_product = {}
            min_idx = {}
            min_model_name = {}
            for model_name, model_to_merge in checkpoints.items():
                model_to_merge = model_to_merge.to("cpu").state_dict()
                for param_name, param_value in model_to_merge.items():
                    # calculate cosine similarity
                    grad = gradients[param_name]
                    ckpt = model_to_merge[param_name]
                    param_alignment = torch.dot(grad.flatten(), ckpt.flatten()) / (
                        torch.norm(grad) * torch.norm(ckpt)
                    )
                    if (
                        param_name not in min_inner_product
                        or param_alignment < min_inner_product[param_name]
                    ) and model_name not in model_to_merge_names[param_name]:
                        min_inner_product[param_name] = param_alignment
                        min_model[param_name] = param_value
                        min_idx[param_name] = model_name
                        min_model_name[param_name] = model_name
            min_inner_product = sum(min_inner_product.values())
            log_dict = {model_name: 0 for model_name in checkpoints.keys()}
            for k in min_idx.values():
                log_dict[k] += 1

        return min_model, min_model_name, min_inner_product, log_dict

    def run(self, modelpool: HuggingFaceClipVisionPool):
        log.info("Fusing models using FW merging.")
        self.modelpool = modelpool
        tasks = self.tasks if self.tasks else self.modelpool.model_names
        self.log_hyperparams(self.config)
        self.on_frank_wolfe_iteration_start()

        assert modelpool.has_pretrained, "Pretrained model is required."
        finetuned_models = {
            name: modelpool.load_model(name)
            for name in modelpool.model_names[: self.max_num_models]
        }

        if self.init_weight == "base" or self.init_weight == "":
            merged_model = modelpool.load_model("_pretrained_")
        else:
            log.info("Initializing the merged model with the initial weight")
            if isinstance(self.init_weight, str):
                # self.config.weights is a path to a saved tensor
                layer_wise_weight = load_tensor_from_file(self.init_weight)
            else:
                raise ValueError(f"Unsupported weights format: {self.init_weight}")

            pretrained_model = modelpool.load_model("_pretrained_")
            layerwise_merged_model = LayerWiseMergedModel(
                layer_wise_weight=layer_wise_weight,
                pretrained_model=pretrained_model,
                finetuned_models=list(finetuned_models.values())[: self.max_num_models],
                clamp_weights=False,
                tie_weights=True,
                strict=False,
            ).cuda()
            merged_model = layerwise_merged_model.merge_and_unload()

        initial_model = modelpool.load_model("_pretrained_")
        self.set_requires_grad(merged_model, initial_model)
        # initial_model.load_state_dict(deepcopy(merged_model.state_dict()))
        # finetuned_models['initial'] = initial_model
        for step_idx in (
            pbar := tqdm(
                range(self.max_iters if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "") + "Frank-Wolfe Merging",
                dynamic_ncols=True,
            )
        ):
            # Find the task vector with the most alignment to the gradient
            models_dict_to_merge = []
            model_to_merge_names = (
                []
                if self.granularity == "task"
                else {name: [] for name in merged_model.state_dict().keys()}
            )
            inner_products = []
            for task in tasks:
                torch.set_grad_enabled(True)
                torch.cuda.empty_cache()
                gradients = self.frank_wolfe_iteration(merged_model.cuda(), task)
                torch.set_grad_enabled(False)
                grad_norm = torch.norm(
                    torch.stack([torch.norm(g) for g in gradients.values()])
                )

                min_model, min_model_name, min_inner_product, log_dict = (
                    self.frank_wolfe_selection(
                        gradients,
                        finetuned_models,
                        model_to_merge_names,
                        type=self.granularity,
                    )
                )
                if self.granularity == "task":
                    model_to_merge_names.append(min_model_name)
                else:
                    for k, v in min_model_name.items():
                        model_to_merge_names[k].append(v)
                models_dict_to_merge.append(min_model)
                inner_products.append(min_inner_product)

                log.info(f"Task: {task}, Inner Products: {log_dict}")
                if (
                    len(models_dict_to_merge) >= len(self.modelpool.model_names)
                    or len(models_dict_to_merge) >= self.max_num_models
                ):
                    log.info(f"Breaking at {len(models_dict_to_merge)}")
                    break

            # print iteration information
            log.info(
                f"Iteration {step_idx+1}, Task Vector: {model_to_merge_names}, Gradient Norm: {grad_norm:.6f}, Inner Products: {inner_products}"
            )

            if self.merge_fn == "adamerging":
                models_to_merge = [
                    modelpool.load_model("_pretrained_")
                    for _ in range(len(models_dict_to_merge))
                ]
                layer_wise_weight = get_layer_wise_weights(
                    num_models=len(models_to_merge),
                    num_layers=len(
                        tuple(
                            filter(
                                lambda p: p.requires_grad,
                                models_to_merge[0].parameters(),
                            )
                        )
                    ),
                    init_values=self.ada_coeff if step_idx > 0 else 0.3,
                )
                for model_to_merge, model_to_merge_dict in zip(
                    models_to_merge, models_dict_to_merge
                ):
                    model_to_merge.load_state_dict(model_to_merge_dict)
                layerwise_merged_model = LayerWiseMergedModel(
                    layer_wise_weight=layer_wise_weight,
                    pretrained_model=merged_model.to("cpu"),
                    finetuned_models=models_to_merge,
                    clamp_weights=False,
                    tie_weights=True,
                    strict=False,
                ).cuda()
                torch.set_grad_enabled(True)
                layerwise_merged_model = self.run_adamerging(layerwise_merged_model)
                torch.set_grad_enabled(False)
                with torch.no_grad():
                    merged_model = layerwise_merged_model.merge_and_unload()
                    self.set_requires_grad(merged_model, initial_model)
                del (
                    models_to_merge,
                    layerwise_merged_model,
                    layer_wise_weight,
                    models_dict_to_merge,
                )
            else:
                step = 2 / (step_idx + 2) * self.step_size if step_idx > 0 else 1
                merged_model = task_arithmetic_merge(
                    merged_model.to("cpu"), models_dict_to_merge, 0.3 * step
                )
                del models_dict_to_merge

        torch.set_grad_enabled(False)
        merged_model = merged_model.cuda().eval()
        return merged_model

    def set_requires_grad(self, merged_model, initial_model):
        for name, param in initial_model.named_parameters():
            for n, p in merged_model.named_parameters():
                if name == n:
                    p.requires_grad = param.requires_grad
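
For reference on the non-AdaMerging path: when `merge_fn` is anything other than "adamerging", each Frank-Wolfe iteration above rescales the selected task vectors by `0.3 * step`, where `step` follows the classic Frank-Wolfe schedule `2 / (step_idx + 2)` scaled by `step_size`, and the first iteration always takes a full step. A minimal, self-contained sketch of the resulting coefficients, assuming `step_size=0.3` (an illustrative value, not necessarily the default shipped in fw_soft.yaml):

    step_size = 0.3
    for step_idx in range(4):
        step = 2 / (step_idx + 2) * step_size if step_idx > 0 else 1
        print(step_idx, round(0.3 * step, 4))
    # prints: 0 0.3, 1 0.06, 2 0.045, 3 0.036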