fusion-bench 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +5 -0
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/method/DOGE_TA/DOGE_TA.py +364 -0
- fusion_bench/method/DOGE_TA/__init__.py +2 -0
- fusion_bench/method/DOGE_TA/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/DOGE_TA/layer_wise_adamerging.py +250 -0
- fusion_bench/method/__init__.py +22 -0
- fusion_bench/method/classification/continual_clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/__init__.py +8 -0
- fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
- fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
- fusion_bench/method/isotropic_merging/__init__.py +15 -0
- fusion_bench/method/isotropic_merging/iso.py +114 -0
- fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
- fusion_bench/method/task_singular_vector/TSVM.py +22 -2
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +531 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/RECORD +30 -13
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/DOGE_TA/DOGE_TA.yaml +4 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/top_level.txt +0 -0
|
@@ -20,12 +20,17 @@ class AlgorithmFactory:
|
|
|
20
20
|
# model merging methods
|
|
21
21
|
"clip_task_wise_adamerging": ".adamerging.clip_task_wise_adamerging.CLIPTaskWiseAdaMergingAlgorithm",
|
|
22
22
|
"clip_layer_wise_adamerging": ".adamerging.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
|
|
23
|
+
"clip_layer_wise_adamerging_doge_ta": ".DOGE_TA.clip_layer_wise_adamerging.CLIPLayerWiseAdaMergingAlgorithm",
|
|
23
24
|
"singular_projection_merging": "fusion_bench.method.smile_upscaling.singular_projection_merging.SingularProjectionMergingAlgorithm",
|
|
24
25
|
"clip_layer_wise_adamerging_surgery": ".surgery.clip_layer_wise_adamerging_surgery.CLIPLayerWiseAdaMergingSurgeryAlgorithm",
|
|
25
26
|
# plug-and-play model merging methods
|
|
26
27
|
"clip_concrete_task_arithmetic": ".concrete_subspace.clip_concrete_task_arithmetic.ConcreteTaskArithmeticAlgorithmForCLIP",
|
|
27
28
|
"clip_concrete_task_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteTaskWiseAdaMergingForCLIP",
|
|
28
29
|
"clip_concrete_layer_wise_adamerging": ".concrete_subspace.clip_concrete_adamerging.ConcreteLayerWiseAdaMergingForCLIP",
|
|
30
|
+
"clip_post_defense_AWM": ".concrete_subspace.clip_post_defense.PostDefenseAWMAlgorithmForCLIP",
|
|
31
|
+
"clip_post_defense_SAU": ".concrete_subspace.clip_post_defense.PostDefenseSAUAlgorithmForCLIP",
|
|
32
|
+
"clip_safe_concrete_layer_wise_adamerging": ".concrete_subspace.clip_safe_concrete_adamerging.ConcreteSafeLayerWiseAdaMergingForCLIP",
|
|
33
|
+
"clip_safe_concrete_task_wise_adamerging": ".concrete_subspace.clip_safe_concrete_adamerging.ConcreteSafeTaskWiseAdaMergingForCLIP",
|
|
29
34
|
# model mixing methods
|
|
30
35
|
"clip_weight_ensembling_moe": ".we_moe.clip_we_moe.CLIPWeightEnsemblingMoEAlgorithm",
|
|
31
36
|
"sparse_clip_weight_ensembling_moe": "fusion_bench.method.SparseCLIPWeightEnsemblingMoEAlgorithm",
|
fusion_bench/method/DOGE_TA/DOGE_TA.py
CHANGED
|
@@ -0,0 +1,364 @@
|
|
|
1
|
+
R"""
|
|
2
|
+
This script contains the general implementation of Modeling Multi-Task Model Merging as Adaptive Projective Gradient Descent.
|
|
3
|
+
|
|
4
|
+
https://arxiv.org/abs/2501.01230
|
|
5
|
+
|
|
6
|
+
Example Usage:
|
|
7
|
+
|
|
8
|
+
```bash
|
|
9
|
+
fusion_bench \
|
|
10
|
+
method=DOGE_TA/DOGE_TA \
|
|
11
|
+
modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
|
|
12
|
+
taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
|
|
13
|
+
|
|
14
|
+
fusion_bench \
|
|
15
|
+
method=adamerging \
|
|
16
|
+
method.name=clip_layer_wise_adamerging_doge_ta \
|
|
17
|
+
modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
|
|
18
|
+
taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
|
|
19
|
+
```
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import copy
|
|
23
|
+
import logging
|
|
24
|
+
import time
|
|
25
|
+
from collections import OrderedDict
|
|
26
|
+
from copy import deepcopy
|
|
27
|
+
from functools import reduce
|
|
28
|
+
from typing import Dict, List, Mapping, TypeVar, Union # noqa: F401
|
|
29
|
+
|
|
30
|
+
import lightning as L
|
|
31
|
+
import torch
|
|
32
|
+
from torch import nn
|
|
33
|
+
|
|
34
|
+
from fusion_bench.method.base_algorithm import BaseAlgorithm
|
|
35
|
+
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
|
|
36
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
37
|
+
from fusion_bench.modelpool import BaseModelPool
|
|
38
|
+
from fusion_bench.utils.state_dict_arithmetic import (
|
|
39
|
+
state_dict_add,
|
|
40
|
+
state_dict_mul,
|
|
41
|
+
state_dict_sub,
|
|
42
|
+
)
|
|
43
|
+
from fusion_bench.utils.type import StateDictType
|
|
44
|
+
|
|
45
|
+
log = logging.getLogger(__name__)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class DOGE_TA_Algorithm(
    BaseAlgorithm,
    SimpleProfilerMixin,
    LightningFabricMixin,
):
    """
    Task Arithmetic Algorithm for model fusion with learnable delta (DOGE-TA).

    This class extends the Task Arithmetic method to include a learnable delta
    for task vectors, optimized to maximize cosine similarity among the task vectors.

    Attributes:
        subspace: Divisor controlling the rank of the per-layer projection subspace.
        K: Top-K percentage (or fraction) of entries kept by the sparsity mask.
        lamda: Scaling coefficient used to derive per-layer merging weights.
        delta (StateDictType): A learnable parameter to adjust task vectors,
            initialized lazily as zeros in :meth:`optimize_delta`.
    """

    _config_mapping = BaseAlgorithm._config_mapping | {
        "subspace": "subspace",
        "K": "K",
        "lamda": "lamda",
    }

    def __init__(self, subspace, K, lamda):
        """Store hyperparameters; ``delta`` is created lazily during :meth:`run`."""
        self.delta = None  # Initialize delta as None; will be set during run
        self.subspace = subspace
        self.K = K
        self.lamda = lamda
        super().__init__()

    @property
    def device(self) -> torch.device:
        """Device of the underlying Lightning Fabric instance."""
        return self.fabric.device

    @torch.no_grad()
    def compute_task_vectors(
        self, modelpool: BaseModelPool, pretrained_model: nn.Module
    ) -> List[StateDictType]:
        """
        Compute task vectors (fine-tuned minus pretrained weights) for each model
        in the pool, restricted to encoder weight matrices (layer norms excluded).
        """
        task_vectors = []
        pretrained_sd = pretrained_model.state_dict(keep_vars=True)
        # Only 2-D encoder weights participate; layer norms are skipped because
        # the SVD-based projection below assumes matrix-shaped parameters.
        filtered_keys = [
            k
            for k in pretrained_sd.keys()
            if ("encoder" in k and "layer_norm" not in k and "weight" in k)
        ]  # Flan T5: "layer_norm" not in k and ("q.weight" in k or "v.weight" in k)

        for model_name in modelpool.model_names:
            model = modelpool.load_model(model_name)
            model_sd = model.state_dict(keep_vars=True)

            filtered_task_vector = {
                k: (model_sd[k] - pretrained_sd[k]) for k in filtered_keys
            }
            task_vectors.append(filtered_task_vector)

        return task_vectors

    def taskvector_loss(self, layer_vectors, layer_delta, layer_lamdas) -> torch.Tensor:
        """
        Compute the delta-optimization loss for a single layer.

        Args:
            layer_vectors: Stacked task vectors for this layer, shape (num_models, *matrix).
            layer_delta: The learnable delta for this layer.
            layer_lamdas: Per-model scaling factors for this layer.

        Returns:
            torch.Tensor: Scalar loss accumulated over all task vectors.
        """
        total_loss = 0.0

        layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
        sum_over_num_vectors = layer_vectors_scale.sum(dim=0)

        layer_delta_scale = layer_delta.unsqueeze(0) * layer_lamdas.view(-1, 1, 1)
        sum_over_delta = layer_delta_scale.sum(dim=0)

        # Iterate through each vector and accumulate its contribution.
        for v_j in layer_vectors:
            part1 = -v_j * sum_over_num_vectors
            part2 = -v_j * sum_over_delta
            part3 = v_j * v_j

            expression = part1 + part2 + part3
            layer_loss = expression.sum(dim=1).pow(2).sum()

            total_loss += layer_loss
        return total_loss

    @torch.enable_grad()
    def optimize_delta(self, task_vectors: List[StateDictType]) -> None:
        """
        Optimize ``self.delta`` layer by layer with Adam, projecting each gradient
        step onto the orthogonal complement of ``self.projection[layer]``.
        """
        if self.delta is None:
            self.delta = {
                k: nn.Parameter(torch.zeros_like(v, device=self.device).detach())
                for k, v in task_vectors[0].items()
            }

        optimizer = torch.optim.Adam(self.delta.values(), lr=1e-4)
        # Guard the CUDA memory probes so the algorithm also runs on CPU-only builds,
        # where torch.cuda.memory_allocated() would raise.
        cuda_available = torch.cuda.is_available()
        initial_mem = torch.cuda.memory_allocated() if cuda_available else 0
        start_time = time.time()
        for layer_name in task_vectors[0].keys():
            layer_vectors = torch.stack([vec[layer_name] for vec in task_vectors]).to(
                self.device
            )
            layer_lamdas = torch.stack(
                [lamdas[layer_name] for lamdas in self.lamdas]
            ).to(self.device)
            # 400 projected-gradient steps per layer, as in the reference implementation.
            for _ in range(400):
                optimizer.zero_grad()
                loss = self.taskvector_loss(
                    layer_vectors, self.delta[layer_name], layer_lamdas
                )
                self.fabric.backward(loss)
                # Remove the gradient component lying inside the shared subspace.
                grad_proj = (
                    self.projection[layer_name] @ self.delta[layer_name].grad.detach()
                )
                self.delta[layer_name].grad.data = self.delta[
                    layer_name
                ].grad.data.sub_(grad_proj)
                optimizer.step()
                self.delta[layer_name].grad = None
        end_time = time.time()
        print(f"Running time: {end_time - start_time} s")
        final_mem = torch.cuda.memory_allocated() if cuda_available else 0
        print(f"Memory usage: {(final_mem - initial_mem) / (1024 ** 2)} MB")
        print("Optimization completed.")

    @torch.no_grad()
    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
        """
        Runs the Algorithm with learnable delta to fuse models in the given model pool.

        Args:
            modelpool (Union[BaseModelPool, Dict[str, nn.Module]]): The pool of models to fuse.

        Returns:
            nn.Module: The pre-trained model with the merged task vectors after optimizing delta.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        log.info("Fusing models using DOGE_TA with learnable delta.")
        with self.profile("load model"):
            pretrained_model = modelpool.load_model("_pretrained_")

        task_vectors = self.compute_task_vectors(modelpool, pretrained_model)

        self.lamdas = self.compute_layer_lamdas(task_vectors)
        # Build a per-layer projection matrix from the concatenated top singular
        # directions of all task vectors for that layer.
        self.projection = {}
        for layer_name in task_vectors[0].keys():
            for i, vector in enumerate(task_vectors):
                layer_vector = vector[layer_name].to(self.device)
                u, s, v = torch.linalg.svd(layer_vector, full_matrices=False)
                if i == 0:
                    print(f"Computed SVD for {layer_name}...")
                    sum_u = torch.zeros_like(u, device=layer_vector.device)
                    sum_s = torch.zeros_like(s, device=layer_vector.device)
                    sum_v = torch.zeros_like(v, device=layer_vector.device)

                reduced_index_s = int(s.shape[0] / len(task_vectors))

                # select only the first reduced_index_s columns of u and place them
                sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                    :, :reduced_index_s
                ]
                sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                    :reduced_index_s
                ]
                # select only the first reduced_index_s rows of v and place them
                sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                    :reduced_index_s, :
                ]
            u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
            layer_proj = torch.matmul(
                u_u[:, : int(s.shape[0] / self.config.subspace)],
                u_u[:, : int(s.shape[0] / self.config.subspace)].T,
            )
            self.projection[layer_name] = layer_proj

        self.optimize_delta(task_vectors)

        # Free the projections and move the optimized state to CPU before merging.
        del self.projection
        self.delta = {key: param.detach().cpu() for key, param in self.delta.items()}
        self.lamdas = [
            {key: param.cpu() for key, param in lamdas.items()}
            for lamdas in self.lamdas
        ]
        task_vectors = [
            {k: v.cpu() for k, v in task_vector.items()} for task_vector in task_vectors
        ]
        flat_vectors = []
        vector_masks = []
        for idx, task_vector in enumerate(task_vectors):
            flat_vector = self.state_dict_to_vector(task_vector)
            vector_mask = self.topk_values_mask(flat_vector, K=self.config.K)
            flat_vectors.append(flat_vector)
            vector_masks.append(vector_mask)
        flat_delta = self.state_dict_to_vector(self.delta)

        # Apply delta, then sparsify with each model's top-K mask.
        adjusted_vectors = [
            self.vector_to_state_dict(
                (flat_vector + flat_delta) * vector_mask, self.delta
            )
            for flat_vector, vector_mask in zip(flat_vectors, vector_masks)
        ]

        for layer_name in adjusted_vectors[0].keys():
            layer_vectors = torch.stack(
                [vec[layer_name] for vec in adjusted_vectors], dim=0
            )
            layer_lamdas = torch.stack(
                [lamdas[layer_name] for lamdas in self.lamdas], dim=0
            )
            layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
            # Accumulate the scaled, adjusted vectors into the first task vector,
            # which then holds the merged update for this layer.
            task_vectors[0][layer_name] = layer_vectors_scale.sum(dim=0)

        final_state_dict = {}
        pretrained_sd = pretrained_model.state_dict(keep_vars=True)
        for k, v in pretrained_sd.items():
            if k in task_vectors[0]:
                final_state_dict[k] = v + task_vectors[0][k]
            else:
                final_state_dict[k] = v

        pretrained_model.load_state_dict(final_state_dict)

        self.print_profile_summary()
        return pretrained_model

    def compute_lamdas(self, vectors: List[StateDictType]) -> torch.Tensor:
        """One global lamda per task vector: ``lamda / ||vector||``."""
        lamdas = []
        for vec in vectors:
            norm_vec = torch.norm(
                torch.cat([param.flatten() for param in vec.values()])
            )
            # norm_vec = sum([torch.norm(param) for param in vec.values()])
            lamdas.append(self.config.lamda / norm_vec)
        print(lamdas)
        return lamdas

    def compute_layer_lamdas(self, vectors: List[StateDictType]) -> torch.Tensor:
        """Per-layer lamdas: ``lamda / ||layer||`` for each layer of each vector."""
        lamdas = []
        for vec in vectors:
            tmp = {}
            for layer_name in vec.keys():
                norm_vec = torch.norm(vec[layer_name])
                tmp[layer_name] = self.config.lamda / norm_vec
            lamdas.append(tmp)
        return lamdas

    def topk_values_mask(self, M, K):
        """
        Return a boolean mask keeping the top-K fraction of entries of ``M`` by
        magnitude (row-wise). ``K > 1`` is interpreted as a percentage.
        """
        if K > 1:
            K /= 100

        original_shape = M.shape
        if M.dim() == 1:
            M = M.unsqueeze(0)

        n, d = M.shape
        k = int(d * K)
        k = d - k  # Keep top k elements instead of bottom k elements

        # Find the k-th smallest element by magnitude for each row
        kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
        # Create a mask tensor with True for the top k elements in each row
        mask = M.abs() >= kth_values
        final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask

        return final_mask

    def state_dict_to_vector(self, state_dict, remove_keys=None):
        """
        Convert a state dictionary to a vector, removing specified keys.

        Args:
            state_dict (dict): The state dictionary to convert.
            remove_keys (list): List of keys to remove from the state dictionary.

        Returns:
            Tensor: A vector representation of the state dictionary.
        """
        # Avoid a mutable default argument; None means "remove nothing".
        remove_keys = [] if remove_keys is None else remove_keys
        shared_state_dict = copy.deepcopy(state_dict)
        for key in remove_keys:
            if key in shared_state_dict:
                del shared_state_dict[key]
        # Sort keys so the flattening order matches vector_to_state_dict.
        sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
        return nn.utils.parameters_to_vector(
            [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
        )

    def vector_to_state_dict(self, vector, state_dict, remove_keys=None):
        """
        Convert a vector back to a state dictionary, removing specified keys.

        Args:
            vector (Tensor): The vector to convert.
            state_dict (dict): The reference state dictionary.
            remove_keys (list): List of keys to remove from the state dictionary.

        Returns:
            dict: A state dictionary representation of the vector.
        """
        # Avoid a mutable default argument; None means "remove nothing".
        remove_keys = [] if remove_keys is None else remove_keys
        # create a reference dict to define the order of the vector
        reference_dict = copy.deepcopy(state_dict)
        for key in remove_keys:
            if key in reference_dict:
                del reference_dict[key]
        sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

        # create a shared state dict using the reference dict
        nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

        # add back the encoder and decoder embedding weights.
        if "transformer.shared.weight" in sorted_reference_dict:
            for key in remove_keys:
                sorted_reference_dict[key] = sorted_reference_dict[
                    "transformer.shared.weight"
                ]
        return sorted_reference_dict
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Example Usage:
|
|
3
|
+
|
|
4
|
+
```bash
|
|
5
|
+
fusion_bench \
|
|
6
|
+
method=adamerging \
|
|
7
|
+
method.name=clip_layer_wise_adamerging \
|
|
8
|
+
method.save_merging_weights=merging_weights.pt \
|
|
9
|
+
modelpool=clip-vit-base-patch32_TA8 \
|
|
10
|
+
taskpool=clip-vit-classification_TA8 \
|
|
11
|
+
fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
|
|
12
|
+
fabric.loggers.name=clip_layer_wise_adamerging_adam
|
|
13
|
+
```
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import functools
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
from torch.utils.data import DataLoader
|
|
20
|
+
|
|
21
|
+
from fusion_bench.dataset.clip_dataset import CLIPDataset
|
|
22
|
+
from fusion_bench.mixins import CLIPClassificationMixin
|
|
23
|
+
from fusion_bench.utils.data import InfiniteDataLoader
|
|
24
|
+
|
|
25
|
+
from .layer_wise_adamerging import LayerWiseAdaMergingAlgorithm
|
|
26
|
+
|
|
27
|
+
log = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class CLIPLayerWiseAdaMergingAlgorithm(
    CLIPClassificationMixin,
    LayerWiseAdaMergingAlgorithm,
):
    """Layer-wise AdaMerging specialized for CLIP vision models."""

    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot classification head for each task.
        """
        self.setup_zero_shot_classification_head()

    def get_shuffled_test_loader_iter(self, task: str):
        """
        Return a (cached) shuffled test-loader iterator for ``task``.

        The cache lives on the instance instead of using ``functools.cache`` on
        the method: a module-level cache would key on ``self`` and keep every
        instance (and its data loaders) alive for the process lifetime (ruff B019).
        """
        if not hasattr(self, "_shuffled_test_loader_iters"):
            self._shuffled_test_loader_iters = {}
        if task not in self._shuffled_test_loader_iters:
            self._shuffled_test_loader_iters[task] = (
                super().get_shuffled_test_loader_iter(
                    task,
                    batch_size=self.config.batch_size,
                    num_workers=self.config.num_workers,
                )
            )
        return self._shuffled_test_loader_iters[task]
|
|
@@ -0,0 +1,250 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, cast # noqa: F401
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
8
|
+
from omegaconf import DictConfig
|
|
9
|
+
from torch import Tensor, nn
|
|
10
|
+
from torch.utils.data import DataLoader
|
|
11
|
+
from tqdm.autonotebook import tqdm
|
|
12
|
+
|
|
13
|
+
from fusion_bench.compat.method import ModelFusionAlgorithm
|
|
14
|
+
from fusion_bench.compat.modelpool import ModelPool
|
|
15
|
+
from fusion_bench.method.adamerging.entropy_loss import entropy_loss
|
|
16
|
+
from fusion_bench.method.adamerging.utils import get_memory_usage
|
|
17
|
+
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
|
|
18
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
19
|
+
from fusion_bench.models.wrappers.layer_wise_fusion_doge_ta import (
|
|
20
|
+
LayerWiseMergedModel,
|
|
21
|
+
get_layer_wise_weights,
|
|
22
|
+
)
|
|
23
|
+
from fusion_bench.utils.data import load_tensor_from_file
|
|
24
|
+
from fusion_bench.utils.type import TorchModelType
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram
|
|
28
|
+
|
|
29
|
+
log = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class LayerWiseAdaMergingAlgorithm(
    ModelFusionAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    """
    Implements the Layer-Wise AdaMerging Algorithm.

    This class merges the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.

    Note: the docstring is placed first in the class body so it is actually
    picked up as ``__doc__`` (a stray string after the annotation is not).
    """

    _program: "FabricModelFusionProgram"  # The program that this algorithm is running on.

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm.
        """
        super().__init__(algorithm_config)

    def construct_layer_wise_merged_model(self, modelpool: "ModelPool"):
        """
        Constructs a wrapped layer-wise merged model from model pool.

        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
        The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
        The merging weights can be initialized based on a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
        """
        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
        if self.config.weights is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                num_layers=len(
                    tuple(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a saved tensor
                layer_wise_weight = load_tensor_from_file(self.config.weights)
            else:
                raise ValueError(f"Unsupported weights format: {self.config.weights}")

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        """
        Save the merging weights to a file.

        Args:
            file_path (str): The path to save the merging weights.
            merging_weights (torch.Tensor): The merging weights to save.
        """
        if self.fabric.is_global_zero and self.config.get(
            "save_merging_weights", False
        ):
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to current working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            # lazy %-style args so the message is only formatted when emitted
            log.info("saving merging weights to %s.", save_path)
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def run(self, modelpool: ModelPool, **kwargs):
        """
        Run the Layer-Wise AdaMerging Algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            LayerWiseMergedModel: The merged model after test-time adaptation.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)

        with self.profile("construct the wrapped model"):
            module = self.construct_layer_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            with self.profile("test-time adaptation"):
                module = self.test_time_adaptation(module)
            if self.config.get("save_merging_weights", False):
                self.save_merging_weights(
                    self.config.save_merging_weights, module.merge_weight
                )
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        """
        Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of test dataset for test-time adaptation. labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
        """
        Compute the logits for the given images and task.

        Args:
            module: The model module.
            images (Tensor): The input images.
            task (str): The name of the task.

        Returns:
            Tensor: The computed logits.
        """
        pass

    def test_time_adaptation(self, module: "LayerWiseMergedModel[TorchModelType]"):
        """
        Perform test-time adaptation on the merged model.

        This method adapts the merging weights during test-time to improve performance.

        Args:
            module (LayerWiseMergedModel): The merged model.

        Returns:
            LayerWiseMergedModel: The adapted merged model.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
            print(f"{optimizer=}")
            module, optimizer = self.fabric.setup(module, optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            # default behavior for first-order optimizers
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch[0], task)
                    loss = entropy_loss(logits)
                with self.profile("backward pass"):
                    # retain_graph: the merged weights are reused across tasks
                    # within one step, so the graph must survive multiple backwards.
                    self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                module.merge_weights()

            metrics = {
                "train/loss": loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        # no placeholders in the message, so a plain string (no f-prefix) suffices
        log.info(get_memory_usage("after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return module
|