fusion-bench 0.2.10-py3-none-any.whl → 0.2.12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +5 -0
- fusion_bench/dataset/fer2013.py +0 -1
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/concrete_subspace/__init__.py +8 -0
- fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
- fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/doge_ta/doge_ta.py +364 -0
- fusion_bench/method/doge_ta/layer_wise_adamerging.py +250 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/isotropic_merging/iso.py +2 -2
- fusion_bench/method/opcm/opcm.py +93 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/task_singular_vector/TSVM.py +3 -3
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +416 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/RECORD +32 -19
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
- fusion_bench_config/method/doge_ta/doge_ta.yaml +4 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -8
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +68 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/top_level.txt +0 -0
fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py (new file, +416 lines):

```python
import copy
import functools
import logging
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Callable, Dict, Iterator, List, Optional  # noqa: F401

import lightning as L
import torch
from torch import Tensor, nn
from torch.func import functional_call

from fusion_bench.models.utils import del_attr, get_attr, set_attr
from fusion_bench.utils.state_dict_arithmetic import state_dict_add
from fusion_bench.utils.type import StateDictType

from .layer_wise_fusion import fuse_weights, get_layer_wise_weights

__all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]

log = logging.getLogger(__name__)


class LayerWiseMergedModel(nn.Module):
    _merged_state_dict: StateDictType = None

    def __init__(
        self,
        layer_wise_weight: Tensor,
        pretrained_model: nn.Module,
        finetuned_models: List[nn.Module],
        clamp_weights: bool = True,
        tie_weights: bool = False,
        strict: bool = True,
        sparsity_ratio: Optional[float] = None,
        normalized_merging_weights: bool = False,
    ):
        R"""
        This class wraps a pretrained model and a list of finetuned models, and merges the weights of the finetuned models into the pretrained model using layer-wise fusion.

        Args:
            layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) representing the weight of each layer for each model.
            pretrained_model (nn.Module): The pretrained model to merge the weights into.
            finetuned_models (List[nn.Module]): A list of finetuned models to merge the weights from. These should have the same architecture as the pretrained model. We use these models to compute the task vectors.
            clamp_weights (bool, optional): If True, the layer-wise weights will be clamped to [0, 1]. Defaults to True.
            tie_weights (bool, optional): This option passes the `tie_weights` argument to the `functional_call` function. Defaults to False.
            strict (bool, optional): This option passes the `strict` argument to the `functional_call` function. Defaults to True.
            sparsity_ratio (float, optional): If `sparsity_ratio` is provided, the task vector will be pruned before merging. A high sparsity level can reduce memory usage during merging.
            normalized_merging_weights (bool, optional): If True, the layer-wise weights will be normalized for each layer, so that the sum of weights across models for each layer is 1. Defaults to False.
        """
        super().__init__()
        if torch.cuda.is_available():
            self._fabric = L.Fabric(devices=1)
            self._fabric.launch()
        else:
            # fall back to CPU; the `if self._fabric is not None` guards below rely on this
            self._fabric = None
        self.clamp_weights = clamp_weights
        self.tie_weights = tie_weights
        self.strict = strict
        self.sparsity_ratio = sparsity_ratio
        self.normalized_merging_weights = normalized_merging_weights

        pretrained_sd = pretrained_model.state_dict(keep_vars=True)
        filtered_keys = [
            k
            for k in pretrained_sd.keys()
            if ("encoder" in k and "layer_norm" not in k and "weight" in k)
        ]
        self.merge_weight = nn.Parameter(
            layer_wise_weight[:, : len(filtered_keys)], requires_grad=True
        )
        task_vectors = []
        for m in finetuned_models:
            m.requires_grad_(False)
        self.pretrained_model = pretrained_model.requires_grad_(False)
        for model in finetuned_models:
            model_sd = model.state_dict(keep_vars=True)
            filtered_task_vector = {
                k: (model_sd[k] - pretrained_sd[k]) for k in filtered_keys
            }
            if self._fabric is not None:
                filtered_task_vector = self._fabric.to_device(filtered_task_vector)
            task_vectors.append(filtered_task_vector)

        # For each layer, pool the leading singular directions of every task
        # vector and build a projector onto their common span.
        self.projection = {}
        for layer_name in task_vectors[0].keys():
            for i, vector in enumerate(task_vectors):
                layer_vector = vector[layer_name]
                u, s, v = torch.linalg.svd(layer_vector, full_matrices=False)
                if i == 0:
                    print(f"Computed SVD for {layer_name}...")
                    sum_u = torch.zeros_like(u, device=layer_vector.device)
                    sum_s = torch.zeros_like(s, device=layer_vector.device)
                    sum_v = torch.zeros_like(v, device=layer_vector.device)

                reduced_index_s = int(s.shape[0] / len(task_vectors))

                # select only the first reduced_index_s columns of u and place them
                sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                    :, :reduced_index_s
                ]
                sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                    :reduced_index_s
                ]
                # select only the first reduced_index_s rows of v and place them
                sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                    :reduced_index_s, :
                ]
            u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
            # u_v, s_v, v_v = torch.linalg.svd(sum_v, full_matrices=False)
            layer_proj = torch.matmul(
                u_u[:, : int(s.shape[0] / len(task_vectors))],
                u_u[:, : int(s.shape[0] / len(task_vectors))].T,
            )
            self.projection[layer_name] = layer_proj

        self.delta = [
            {
                k: torch.zeros_like(v).clone().requires_grad_()
                for k, v in task_vector.items()
            }
            for task_vector in task_vectors
        ]
        if self._fabric is not None:
            self.delta = self._fabric.to_device(self.delta)
        self.lamdas = self.compute_layer_lamdas(task_vectors)

        # Optimize the per-task deltas layer by layer; gradient components that
        # fall inside the pooled subspace are projected out, so the deltas are
        # updated only along directions orthogonal to it.
        for layer_name in task_vectors[0].keys():
            optimizer = torch.optim.Adam(
                [delta[layer_name] for delta in self.delta], lr=1e-4
            )
            layer_vectors = torch.stack([vec[layer_name] for vec in task_vectors])
            layer_lamdas = torch.stack([lamdas[layer_name] for lamdas in self.lamdas])
            for step in range(400):
                optimizer.zero_grad()
                layer_delta = torch.stack([de[layer_name] for de in self.delta])
                loss = self.taskvector_loss(layer_vectors, layer_delta, layer_lamdas)
                print(f"Step: {step}, Layer: {layer_name}, Loss: {loss.item()}")
                if self._fabric is not None:
                    self._fabric.backward(loss)
                else:
                    loss.backward()
                for delta in self.delta:
                    grad_proj = (
                        self.projection[layer_name] @ delta[layer_name].grad.detach()
                    )
                    delta[layer_name].grad.data = delta[layer_name].grad.data.sub_(
                        grad_proj
                    )
                optimizer.step()
            for delta in self.delta:
                for param in delta.values():
                    param.grad = None
        del self.projection
        self.delta = [
            {key: param.detach().cpu() for key, param in delta.items()}
            for delta in self.delta
        ]
        self.lamdas = [
            {key: param.cpu() for key, param in lamdas.items()}
            for lamdas in self.lamdas
        ]
        task_vectors = [
            {k: v.cpu() for k, v in task_vector.items()} for task_vector in task_vectors
        ]
        # Sparsify each (task vector + delta) with a top-30% magnitude mask.
        flat_vectors = []
        vector_masks = []
        for idx, task_vector in enumerate(task_vectors):
            flat_vector = self.state_dict_to_vector(task_vector)
            vector_mask = self.topk_values_mask(flat_vector, K=30)
            flat_vectors.append(flat_vector)
            vector_masks.append(vector_mask)
        flat_deltas = [self.state_dict_to_vector(delta) for delta in self.delta]
        self.task_vectors = [
            self.vector_to_state_dict(
                (flat_vector + flat_delta) * vector_mask, self.delta[0]
            )
            for flat_vector, flat_delta, vector_mask in zip(
                flat_vectors, flat_deltas, vector_masks
            )
        ]
        if self._fabric is not None:
            self.task_vectors = self._fabric.to_device(self.task_vectors)

        # if `sparsity_ratio` is given, prune the task vectors.
        if sparsity_ratio is not None:
            from fusion_bench.method.pruning.prune_utils import (
                unstructured_magnitude_prune_,
            )

            # the task vectors are stored as plain dicts here, so iterate the
            # entries directly rather than calling `named_parameters`
            for task_vector in self.task_vectors:
                for name, param in task_vector.items():
                    if param.dim() != 2:
                        continue
                    print(f"pruning {name}")
                    pruned_param = unstructured_magnitude_prune_(
                        param.data.clone(), torch.abs, sparsity_ratio=sparsity_ratio
                    )
                    task_vector[name] = pruned_param.to_sparse()

    def topk_values_mask(self, M, K):
        # `K` may be given as a percentage (e.g. 30) or a fraction (e.g. 0.3)
        if K > 1:
            K /= 100

        original_shape = M.shape
        if M.dim() == 1:
            M = M.unsqueeze(0)

        n, d = M.shape
        k = int(d * K)
        k = d - k  # Keep top k elements instead of bottom k elements

        # Find the k-th smallest element by magnitude for each row
        kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
        # Create a mask tensor with True for the top k elements in each row
        mask = M.abs() >= kth_values
        final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask

        return final_mask

    def state_dict_to_vector(self, state_dict, remove_keys=[]):
        """
        Convert a state dictionary to a vector, removing specified keys.

        Args:
            state_dict (dict): The state dictionary to convert.
            remove_keys (list): List of keys to remove from the state dictionary.

        Returns:
            Tensor: A vector representation of the state dictionary.
        """
        shared_state_dict = copy.deepcopy(state_dict)
        for key in remove_keys:
            if key in shared_state_dict:
                del shared_state_dict[key]
        sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
        return nn.utils.parameters_to_vector(
            [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
        )

    def vector_to_state_dict(self, vector, state_dict, remove_keys=[]):
        """
        Convert a vector back to a state dictionary, removing specified keys.

        Args:
            vector (Tensor): The vector to convert.
            state_dict (dict): The reference state dictionary.
            remove_keys (list): List of keys to remove from the state dictionary.

        Returns:
            dict: A state dictionary representation of the vector.
        """
        # create a reference dict to define the order of the vector
        reference_dict = copy.deepcopy(state_dict)
        for key in remove_keys:
            if key in reference_dict:
                del reference_dict[key]
        sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

        # create a shared state dict using the reference dict
        nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

        # add back the encoder and decoder embedding weights.
        if "transformer.shared.weight" in sorted_reference_dict:
            for key in remove_keys:
                sorted_reference_dict[key] = sorted_reference_dict[
                    "transformer.shared.weight"
                ]
        return sorted_reference_dict

    def taskvector_loss(self, layer_vectors, layer_delta, layer_lamdas) -> torch.Tensor:
        """
        Computes the loss based on delta and task vectors for a specific layer.
        """
        total_loss = 0.0

        layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
        sum_over_num_vectors = layer_vectors_scale.sum(dim=0)

        layer_delta_scale = layer_delta * layer_lamdas.view(-1, 1, 1)
        sum_over_delta = layer_delta_scale.sum(dim=0)

        # Iterate through each vector and calculate the loss one by one
        for v_j in layer_vectors:
            part1 = -v_j * sum_over_num_vectors
            part2 = -v_j * sum_over_delta
            part3 = v_j * v_j

            expression = part1 + part2 + part3
            layer_loss = expression.sum(dim=1).pow(2).sum()

            # Cumulative total loss
            total_loss += layer_loss
        return total_loss

    def compute_layer_lamdas(
        self, vectors: List[StateDictType]
    ) -> List[Dict[str, Tensor]]:
        # scale each layer so that its task vector has a fixed norm (0.07)
        lamdas = []
        for vec in vectors:
            tmp = {}
            for layer_name in vec.keys():
                norm_vec = torch.norm(vec[layer_name])
                tmp[layer_name] = 0.07 / norm_vec
            lamdas.append(tmp)
        return lamdas

    @property
    def forward_model(self):
        return functools.partial(
            functional_call,
            self.pretrained_model,
            self._merged_state_dict,
            tie_weights=self.tie_weights,
            strict=self.strict,
        )

    def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        self.merge_weights(task_vector_mask=task_vector_mask)
        self.pretrained_model.load_state_dict(self._merged_state_dict)
        return self.pretrained_model

    def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        """
        Merges the weights of the model.
        Call this after each update step.
        """
        if self.clamp_weights:
            layer_wise_weight = self.merge_weight.clamp(0, 1)
        else:
            layer_wise_weight = self.merge_weight
        if self.normalized_merging_weights:
            # normalize the weights for each layer, so that the sum of weights across models for each layer is 1.
            layer_wise_weight = layer_wise_weight.softmax(dim=0)

        state_dict = self.pretrained_model.state_dict(keep_vars=True)
        # shape of layer_wise_weight: (num_models, num_layers)
        for weight, task_vector in zip(layer_wise_weight, self.task_vectors):
            task_vector_items = list(task_vector.items())
            for w, (name, param) in zip(weight, task_vector_items):
                state_dict[name] = state_dict[name] + param * w
        self._merged_state_dict = state_dict

        return state_dict

    def forward(self, *args, **kwargs):
        if self._merged_state_dict is None:
            self.merge_weights()
        return self.forward_model(args=args, kwargs=kwargs)

    # def __getattr__(self, name: str) -> Any:
    #     try:
    #         return super().__getattr__(name)
    #     except AttributeError:
    #         attr = getattr(self.model, name)
    #         if isinstance(attr, Callable):
    #             warnings.warn(
    #                 f"forwarding `{name}` to the underlying model", UserWarning
    #             )
    #         return attr

    # def __setattr__(self, name: str, value: Any) -> None:
    #     try:
    #         super().__setattr__(name, value)
    #     except AttributeError:
    #         setattr(self.model, name, value)


def merge_weights(module: nn.Module):
    """
    Merges the weights for all `LayerWiseMergedModel` instances within the given module.

    Args:
        module (nn.Module): The module to process.
    """
    if isinstance(module, LayerWiseMergedModel):
        module.merge_weights()
        return
    else:
        for submodule in module.children():
            merge_weights(submodule)


def merge_and_unload(module: nn.Module):
    """
    Merges and unloads all `LayerWiseMergedModel` instances within the given module.

    Args:
        module (nn.Module): The module to process.

    Returns:
        nn.Module: The updated module with merged weights.
    """
    if isinstance(module, LayerWiseMergedModel):
        return module.merge_and_unload()
    else:
        for name, submodule in module.named_children():
            need_merge = isinstance(submodule, LayerWiseMergedModel)
            submodule = merge_and_unload(submodule)
            if need_merge:
                setattr(module, name, submodule)
        return module


def fix_other_parts(module: nn.Module):
    """
    Sets all parameters in the module to not require gradients, except for the merge weights
    in `LayerWiseMergedModel` instances.

    Args:
        module (nn.Module): The module to process.

    Returns:
        nn.Module: The module with updated parameter requirements.
    """
    module.requires_grad_(False)
    for submodule in module.modules():
        if isinstance(submodule, LayerWiseMergedModel):
            submodule.merge_weight.requires_grad_(True)
    return module
```
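For orientation, a minimal usage sketch of the wrapper above. The `Toy` class is a hypothetical stand-in (the wrapper only merges 2D "encoder" weights that are not layer norms), and it assumes `get_layer_wise_weights(num_models, num_layers, init_values)` behaves as in `layer_wise_fusion.py`; extra weight columns are sliced off inside `__init__`:

```python
import torch
from torch import nn

from fusion_bench.models.wrappers.layer_wise_fusion_doge_ta import (
    LayerWiseMergedModel,
    get_layer_wise_weights,
)


class Toy(nn.Module):
    # Hypothetical model: parameter names contain "encoder" so they pass the filter.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

    def forward(self, x):
        return self.encoder(x)


pretrained, finetuned = Toy(), [Toy(), Toy()]

# One merging coefficient per (model, layer); 0.3 mirrors the usual
# task-arithmetic scaling. Columns beyond the filtered keys are dropped.
weights = get_layer_wise_weights(
    num_models=len(finetuned),
    num_layers=len(pretrained.state_dict()),
    init_values=0.3,
)

# Note: the constructor itself runs the SVD-based subspace construction and
# 400 optimization steps per merged layer, so this is cheap only for toy models.
merged = LayerWiseMergedModel(weights, pretrained, finetuned)
out = merged(torch.randn(4, 8))    # merges lazily on the first forward
model = merged.merge_and_unload()  # plain nn.Module with merged weights
```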
{fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/METADATA:

```diff
@@ -1,6 +1,6 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: fusion_bench
-Version: 0.2.10
+Version: 0.2.12
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 License: MIT License
@@ -45,6 +45,7 @@ Requires-Dist: rich
 Requires-Dist: scipy
 Requires-Dist: h5py
 Requires-Dist: pytest
+Dynamic: license-file
 
 <div align='center'>
 
@@ -69,6 +70,18 @@ FusionBench is a benchmark suite designed to evaluate the performance of various
 
 Projects based on FusionBench and news from the community (descending order of date):
 
+<details>
+<summary>Hao Mark Chen, et al. FW-Merging: Scaling Model Merging with Frank-Wolfe Optimization. Mar 2025. https://arxiv.org/abs/2503.12649</summary>
+
+Model merging has emerged as a promising approach for multi-task learning (MTL), offering a data-efficient alternative to conventional fine-tuning. However, with the rapid development of the open-source AI ecosystem and the increasing availability of fine-tuned foundation models, existing model merging methods face two key limitations: (i) They are primarily designed for in-house fine-tuned models, making them less adaptable to diverse model sources with partially unknown model and task information, (ii) They struggle to scale effectively when merging numerous model checkpoints. To address these challenges, we formulate model merging as a constrained optimization problem and introduce a novel approach: Frank-Wolfe Merging (FW-Merging). Inspired by Frank-Wolfe optimization, our approach iteratively selects the most relevant model in the pool to minimize a linear approximation of the objective function and then executes a local merging similar to the Frank-Wolfe update. The objective function is designed to capture the desired behavior of the target-merged model, while the fine-tuned candidate models define the constraint set. More importantly, FW-Merging serves as an orthogonal technique for existing merging methods, seamlessly integrating with them to further enhance accuracy performance. Our experiments show that FW-Merging scales across diverse model sources, remaining stable with 16 irrelevant models and improving by 15.3% with 16 relevant models on 20 CV tasks, while maintaining constant memory overhead, unlike the linear overhead of data-informed merging methods. Compared with the state-of-the-art approaches, FW-Merging surpasses the data-free merging method by 32.8% and outperforms the data-informed Adamerging by 8.39% when merging 20 ViT models.
+</details>
+
+<details>
+<summary>Daniel Marczak, et al. No Task Left Behind: Isotropic Model Merging with Common and Task-Specific Subspaces. Feb 2025. https://arxiv.org/abs/2502.04959</summary>
+
+Model merging integrates the weights of multiple task-specific models into a single multi-task model. Despite recent interest in the problem, a significant performance gap between the combined and single-task models remains. In this paper, we investigate the key characteristics of task matrices -- weight update matrices applied to a pre-trained model -- that enable effective merging. We show that alignment between singular components of task-specific and merged matrices strongly correlates with performance improvement over the pre-trained model. Based on this, we propose an isotropic merging framework that flattens the singular value spectrum of task matrices, enhances alignment, and reduces the performance gap. Additionally, we incorporate both common and task-specific subspaces to further improve alignment and performance. Our proposed approach achieves state-of-the-art performance across multiple scenarios, including various sets of tasks and model scales. This work advances the understanding of model merging dynamics, offering an effective methodology to merge models without requiring additional training.
+</details>
+
 <details>
 <summary>Anke Tang, et al. Merging Models on the Fly Without Retraining: A Sequential Approach to Scalable Continual Model Merging. Jan 2025. https://arxiv.org/pdf/2501.09522</summary>
 
```
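The metadata bump to 2.4, the new `Dynamic: license-file` field, and the `dist-info/licenses/LICENSE` entry in RECORD below all indicate a rebuild with a newer setuptools/wheel toolchain. A small sketch of inspecting these fields from an installed copy, assuming fusion-bench 0.2.12 is installed (standard library only):

```python
from importlib.metadata import metadata

# `metadata()` returns the parsed METADATA file as an email.Message-like object.
md = metadata("fusion-bench")
print(md["Metadata-Version"])  # 2.4
print(md["Version"])           # 0.2.12
print(md.get_all("Dynamic"))   # ['license-file']
```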
{fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/RECORD:

```diff
@@ -1,7 +1,7 @@
 fusion_bench/__init__.py,sha256=68dF-zPvb8E2MgYnmgIJsxIHJBy1MApKeOrRZvQEVlg,421
 fusion_bench/__main__.py,sha256=weUjxpP3ULnDgUxCehdbmoCM9cqfkhDhGB85tAF5qoE,81
 fusion_bench/compat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-fusion_bench/compat/method/__init__.py,sha256=
+fusion_bench/compat/method/__init__.py,sha256=qbm_0o4Y-X2FY3skmsQpYnKQ3qnR24Z0-uLOEnzO59M,5566
 fusion_bench/compat/method/base_algorithm.py,sha256=63_AQDj1eJOO6RyTSGXVC6G2DsG8yg9E4pT3RJXgP3A,1952
 fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py,sha256=m68BRGy4P-P9lLB10oXOBI-p58a-0FOPcrJ4r4MU32k,1100
 fusion_bench/compat/modelpool/__init__.py,sha256=KD8Ddr9D7rJ5YdHEQsTuNmQ0bgQfqF4l3WNMtHmRHD8,4687
@@ -15,7 +15,7 @@ fusion_bench/constants/__init__.py,sha256=Pyc4dLbl6oNduOCdnpeXQ9LDyVoIrkdl9eZ_l2
 fusion_bench/constants/paths.py,sha256=DVZyQ9FLhkyUdw6ARpXUCAMf_B8hFyJ6UNI-oYly3pE,591
 fusion_bench/dataset/__init__.py,sha256=OJiYmcqz0Vm5O7mE4PB5QFJeL_KjrsseQTRsQATGTm4,1050
 fusion_bench/dataset/clip_dataset.py,sha256=XLpCOiXlLEP3DffAlBn4P2PpUenbEFl-Yk9MNy6nbbI,2790
-fusion_bench/dataset/fer2013.py,sha256=
+fusion_bench/dataset/fer2013.py,sha256=bAdujQSj1PcUVFlKJgqcHAuE9AWz7JE1fzZ6scFVvmc,403
 fusion_bench/dataset/gpt2_glue.py,sha256=Qq1ZkEIQsTjj8tImvkZDNlduocSYwlEfVrDReZqDWdw,8761
 fusion_bench/dataset/gsm8k.py,sha256=CmANZ0A89PfPwVu_myKhXk1D9IwypOpjH3iqDo1KxcQ,2233
 fusion_bench/dataset/image_dataset.py,sha256=MSZE_UESyRRQDwnkm2KpyIARUg9SWcwqnH4fDNstzS4,1870
@@ -41,7 +41,7 @@ fusion_bench/dataset/llama/stanford_shp.py,sha256=6ueXKnFXIBBobacU1h5WxGLZrSOtBk
 fusion_bench/dataset/llama/ultrachat.py,sha256=Go7WvrDAYnm184fdazHGRYLbSY6Xd7jrESyQeUJtOww,1736
 fusion_bench/dataset/llama/wikitext.py,sha256=9ZHR-nMfXRumd3o-PIj3n7B83YlVeqpGkZ2zJs2B-9Y,2883
 fusion_bench/dataset/llama/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-fusion_bench/method/__init__.py,sha256=
+fusion_bench/method/__init__.py,sha256=7S1ODkq2Zppx59o80qcIwDlRtfOC2EU58ooGFlDdJIU,7040
 fusion_bench/method/base_algorithm.py,sha256=5dutGZfPqNhO8F8FOlo3UFR91TZu2Xj7O0pTB40JvWo,1135
 fusion_bench/method/dummy.py,sha256=hb1y6LR_geRZ5eRgGwt5zJUcHYorCeIbs5i76CvurUc,1031
 fusion_bench/method/ensemble.py,sha256=rGxvJTeorfcBuE_e0XO-0-MAc9un7ZCC46ikKGuAcN4,3077
@@ -50,7 +50,7 @@ fusion_bench/method/simple_average.py,sha256=2ghcL1E-eLbIYDCHYCoR9WtiYSb1GvFAH16
 fusion_bench/method/ada_svd/__init__.py,sha256=4XzQbbvE9HI3NtEmEFvo8iC3ds_85vJXe7P7qJfL7kk,77
 fusion_bench/method/ada_svd/clip_vision.py,sha256=QrT6cSwgVEGxXEpVhkvKQVQaoRW5P9V52Y3_8NX0f-o,12556
 fusion_bench/method/adamerging/__init__.py,sha256=nt0saBT_3bqghk-pINQ-XCWm9UWwSZllu4R1sDuAJAA,376
-fusion_bench/method/adamerging/clip_layer_wise_adamerging.py,sha256=
+fusion_bench/method/adamerging/clip_layer_wise_adamerging.py,sha256=UUSldRPBxHVOfkMM7ZwqZay5Wjc6XQ3Vy9PgyqV_TZo,1311
 fusion_bench/method/adamerging/clip_task_wise_adamerging.py,sha256=Tys9pDJzz5YNUCO43pO44fGAnizfSaeAwgH4-vVxRN4,6948
 fusion_bench/method/adamerging/entropy_loss.py,sha256=ZeVe0Hq1PaMfppLqDbB0MOscZUZRNh4CALrvt8pmQC0,736
 fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py,sha256=osc6ueCgiS4u8KUV_sZkHGFBYC8dThnTSp4NB0wkQIg,12915
@@ -66,9 +66,11 @@ fusion_bench/method/analysis/task_vector_violin_plot.py,sha256=ie8hPl6QsVz9MQ6C2
 fusion_bench/method/classification/__init__.py,sha256=emB06UOMDHK5pfQ1WuvLG9Fm0aEEtZxSjpVw8fVE0fM,167
 fusion_bench/method/classification/clip_finetune.py,sha256=DlV1isp8vz6jwXNYQ6zbblAoUfnssL-WBpDeaXI5BVw,15727
 fusion_bench/method/classification/continual_clip_finetune.py,sha256=OLhZKS-6aCnafevZkZYcNMKTWDDj3DATB27eZl_i8EY,11530
-fusion_bench/method/concrete_subspace/__init__.py,sha256=
+fusion_bench/method/concrete_subspace/__init__.py,sha256=jJoFcjnQe-jvccsm9DuCXna378m9XBT9vV1fEZbdfR0,464
 fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py,sha256=90_0HkOIl0XQG89xMa0UiBhrwfV2YqfLxlS04AouR3o,24755
 fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py,sha256=Nx-3AiAeIt5zmcC21Ta2_-4cAQg9hOWvThurXNZzA-w,10580
+fusion_bench/method/concrete_subspace/clip_post_defense.py,sha256=h-c0ioxDopg7pUoRjxx3epqQxVKZAZWz8s7yHjM88mg,32355
+fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py,sha256=eEKKUBgHufYTBaWWxkIKDF0lkuLI2bBgNHVr1JqT41c,35694
 fusion_bench/method/dare/__init__.py,sha256=63Xwkawyl_Ooy4xFxoDlP6wf-rgEWNqPuWTT9-6Ku5o,156
 fusion_bench/method/dare/simple_average.py,sha256=jR08PokPIr5PWSZbGVOp3IApgKvxAIovg3vnB2KiTwk,906
 fusion_bench/method/dare/task_arithmetic.py,sha256=Seno_2BhuogdRxXOni8alnHG-fdW15_OWoAvMoBoJj0,2780
@@ -81,12 +83,16 @@ fusion_bench/method/dawe/warppers/dawe_model.py,sha256=Z1L91vu3UzEHWrHs9i9UbwZpn
 fusion_bench/method/depth_upscaling/__init__.py,sha256=heVUh4tTzK427A10RFknf9eHwoZ1cpn1_0xyNXRU7YM,135
 fusion_bench/method/depth_upscaling/depth_upscaling.py,sha256=pf08zEae-WaWM4oUwn6_Dm65K59wf9AbTQ5iZU0ydsc,3256
 fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py,sha256=bSMhnrG-JtR0JBnOFy7aWAhD6A-YBB84qm_YnWjc7pA,2180
+fusion_bench/method/doge_ta/__init__.py,sha256=dixO0i5fmhgC_W2_DAQ4PzYnkMCZX5D8tDz84soqQ-Q,59
+fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py,sha256=UUSldRPBxHVOfkMM7ZwqZay5Wjc6XQ3Vy9PgyqV_TZo,1311
+fusion_bench/method/doge_ta/doge_ta.py,sha256=ec0qIq3F72nhbCVlfqdk1PYFM7QIlfMofeVFVvmDKiE,13785
+fusion_bench/method/doge_ta/layer_wise_adamerging.py,sha256=rLk3Nep5d6wMUNCp6q7pC7L0pfBvUwGBIuiGM7CQOf4,9780
 fusion_bench/method/fisher_merging/__init__.py,sha256=KWsjrtxKkPYwcUA5rB_6UNIqvesqk2NJw5AY_1ztLVE,225
 fusion_bench/method/fisher_merging/clip_fisher_merging.py,sha256=QCutGqjkfW3OWETPZsCChqLRAhvfJp4QKD9TGSpTyV0,7635
 fusion_bench/method/fisher_merging/fisher_merging.py,sha256=CPU-tJiDv9FCIBYl7Pn0zA5cdRB1Md5kWchRDlJgly0,20456
 fusion_bench/method/fisher_merging/gpt2_fisher_merging.py,sha256=LZmz41jZ5dSsAHxfOUpr3u2rlCgUPTDR7xMsIlQM-jc,7576
-fusion_bench/method/isotropic_merging/__init__.py,sha256=
-fusion_bench/method/isotropic_merging/iso.py,sha256=
+fusion_bench/method/isotropic_merging/__init__.py,sha256=0mxrl1UIjeFAPQcPcZtbgoCJO-DMW_49GKAhgcG-vEA,585
+fusion_bench/method/isotropic_merging/iso.py,sha256=MwKqfk0oyxqtdOzeSx_9jFXX1a4Rd0WcEPsYvQhBSCg,3773
 fusion_bench/method/isotropic_merging/iso_utils.py,sha256=7L8PYUIJROwHJQmhFY-tdEhkLAnzVKXr-ae55FQ1QSo,6928
 fusion_bench/method/linear/__init__.py,sha256=ChfkoOEAb-rUKwpowFPel-a1hRfS8gCrbnWD-jlRbe4,283
 fusion_bench/method/linear/expo.py,sha256=LCHTWlsPm1Mjhrq0mfpWLVC7skkI9ZksGduy3TxULoU,3939
@@ -103,9 +109,9 @@ fusion_bench/method/mixture_of_experts/__init__.py,sha256=r95iu1-3tgIUP7sWuAbLuq
 fusion_bench/method/mixture_of_experts/mixtral_merging.py,sha256=-n1CLP1o08VyMSfaTq42kRutbw-cFDSCWHTu0iNh6ok,4237
 fusion_bench/method/mixture_of_experts/mixtral_upcycling.py,sha256=tQYAeS8MLFEfH3zDFfNZrML7lRnpGLN-HquQvjPtHNw,11208
 fusion_bench/method/opcm/__init__.py,sha256=0QcltOnjIYV1XEPDEagChLixLAhjiBnYwfWK00am29k,202
-fusion_bench/method/opcm/opcm.py,sha256
-fusion_bench/method/opcm/task_arithmetic.py,sha256=
-fusion_bench/method/opcm/ties_merging.py,sha256
+fusion_bench/method/opcm/opcm.py,sha256=-sqfK5q_-yr_3YWigmXKVYRP1J7swHOR9eGMMzu1Dgw,11445
+fusion_bench/method/opcm/task_arithmetic.py,sha256=YvtsWkjtnk7E3C4_xNr--uQWjQhoDZZB-klSx81_tGw,4824
+fusion_bench/method/opcm/ties_merging.py,sha256=-N3i7eMbhK95qyJsmmNMKNmPCkgGHGFa423a52cgi6g,6868
 fusion_bench/method/opcm/utils.py,sha256=_q7yy3ENNFUh1qUd5J5DThRL4J1tIxEcknCO2AKmeYM,2102
 fusion_bench/method/opcm/weight_average.py,sha256=JfQoIU5J1jvrNKpO9k_t4Zj0y8PtteIfyoSQWx1yg2k,4379
 fusion_bench/method/pruning/__init__.py,sha256=3gtmay2bkdIAEGjpAhbY2ztMZOZLKhiJcKV3mCe2H5w,252
@@ -154,7 +160,7 @@ fusion_bench/method/tall_mask/utils.py,sha256=Wlp8WcPwR_lCaBIZ9rgG6ewLfSzz3G7kPk
 fusion_bench/method/task_arithmetic/__init__.py,sha256=pSx_NV5Ra_6UXpyYWCi6ANQoAnEtymZt_X1dDN9wT4Y,96
 fusion_bench/method/task_arithmetic/task_arithmetic.py,sha256=1D0uuNtqyA1VS35jh6AnEVsX72HnT02THyerck_lmso,5441
 fusion_bench/method/task_singular_vector/TSVC.py,sha256=yn4SrZNvtA6PoGYJmbmtNeDyDbGnRCgfZ7ZCg914AZU,410
-fusion_bench/method/task_singular_vector/TSVM.py,sha256=
+fusion_bench/method/task_singular_vector/TSVM.py,sha256=H5RzZlQQeF4kZFjuxkz8v3gyVKS3iKPgqNnitKQzbXk,2787
 fusion_bench/method/task_singular_vector/__init__.py,sha256=WMucyl9pu_Ev2kcdrfT4moqMMbzD7hHQVFME5Su5jMA,298
 fusion_bench/method/task_singular_vector/utils/TSVC_utils.py,sha256=FytKbal48EW6iGIA-2zV7QSVbYTVflXr4Mr56q0W75k,2286
 fusion_bench/method/task_singular_vector/utils/TSVM_utils.py,sha256=dsTMQ15zFJ1MPqDOt2TJ01O9Bwq_klyG9xL9hRD2aI0,27521
@@ -253,7 +259,8 @@ fusion_bench/models/surgery/__init__.py,sha256=tcUSi2m9GzGWfvRDQScIbdEbFBS_35gm9
 fusion_bench/models/surgery/surgerymodelwrapper.py,sha256=F8jX88K5zVWC6HsfN-nGNkEiPwNrN11ydyQQ1EZHehM,5133
 fusion_bench/models/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 fusion_bench/models/wrappers/ensemble.py,sha256=wIMZMRyXw5boWAm96c4Tiyebs_HDQovKxpGQ8rLnHUQ,6308
-fusion_bench/models/wrappers/layer_wise_fusion.py,sha256=
+fusion_bench/models/wrappers/layer_wise_fusion.py,sha256=KamNaq4DlyxQrOp1i9aQLgA2WX81YD5NhzAQ5GF6rg0,11188
+fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py,sha256=q5Hc4BtLpAawMbxsWJRL-8OR-x7994Jhr9IyN7vKZ9o,16930
 fusion_bench/models/wrappers/task_wise_fusion.py,sha256=Wn3buQvWw_lihWaKB03_iz34cBPzwBD94kBT6uafWVQ,8404
 fusion_bench/optim/__init__.py,sha256=lemrcuiA6OLjQkpYm-RP-Ox2MgjngN1ywvCo0NgShlM,61
 fusion_bench/optim/exception.py,sha256=fMgo1heiqfGhuI5RIbf30BwWSShn5RQiyeb30QtfTI0,1607
@@ -352,6 +359,7 @@ fusion_bench/utils/plot/token_notebook.py,sha256=bsntXf46Zz_RavTxNiB9c3-KvHw7LFw
 fusion_bench/utils/strenum/__init__.py,sha256=id9ORi1uXrDxhbmVxitJ1KDwLS4H3AAwFpaK5h1cQzw,8531
 fusion_bench/utils/strenum/_name_mangler.py,sha256=o11M5-bURW2RBvRTYXFQIPNeqLzburdoWLIqk8X3ydw,3397
 fusion_bench/utils/strenum/_version.py,sha256=6JQRo9LcvODbCOeVFYQb9HNJ_J9XiG_Zbn8ws2A3BV8,18466
+fusion_bench-0.2.12.dist-info/licenses/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
 fusion_bench_config/README.md,sha256=Lc8YSBJ5oxf9KV5kKDivJ9LRyGuraGQPmBbgbdVA-j4,703
 fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml,sha256=GtK3VuD2FOpFHH_1Hi6tlaYpdLE5Cz0nYKP92Ss9G2Y,1164
 fusion_bench_config/fabric_model_fusion.yaml,sha256=1shmbuC0B9snkFkLErBCiroF-z7UnEHscyEmKBne7Oo,949
@@ -477,10 +485,15 @@ fusion_bench_config/method/classification/clip_finetune.yaml,sha256=yWjcdKYaKvy5
 fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml,sha256=XsHzr_5NoUZs0Us3eVwP3lUYXYvyJwGEEG9aDI_Z0rU,740
 fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml,sha256=eNoqcY1iMbs0Y5kKi_ya3rmQQMHqU7ht3EU7G_xmwN0,746
 fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml,sha256=WgTJj28FlVjR0_mCGJC5B8aJa9yezI3QusoXXHOrFoU,739
+fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml,sha256=eGUCntXzDtW0tYX1vij7BHgDWzWq6sz2yFipVZj6z9E,849
+fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml,sha256=DUYOU5A8MQw2cTqbraIDMFC7ciO8RXE2qXgVEEUudLM,891
+fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml,sha256=olDW_p5gyyaynwbGAQgm2ZicYAx9n3i4FprxPecuUsU,923
+fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml,sha256=KLO3C1BdeB6FBKHT0xG4V0OFk7ib2SeMScKeaN5BlsU,863
 fusion_bench_config/method/dare/simple_average.yaml,sha256=oTFSCHul86NTjTtJYK5pNr3tuxW7XxNI-y6fL9Yo4VI,113
 fusion_bench_config/method/dare/task_arithmetic.yaml,sha256=Cvsam89yquamn_GkITT6q8qFKN_Yb5nv8p-XgvnVrgU,134
 fusion_bench_config/method/dare/ties_merging.yaml,sha256=50mPiRkzLN7gxaIs56sPWkAUSvqvdxjQJ8eVl1yUGOg,418
 fusion_bench_config/method/dawe/dawe_for_clip.yaml,sha256=8-Z_kwwGCy1AO4brW-R_pe8oJ0yqoD4WCLI9ZtJ4KOo,1026
+fusion_bench_config/method/doge_ta/doge_ta.yaml,sha256=6R9NRuWmj0oapJ_raMB6R6rZPMckt2JtMLrTQ6HhrFc,77
 fusion_bench_config/method/ensemble/max_model_predictor.yaml,sha256=fsWuNJwr1ohVB2aJ5L2fsiDLztm5GieE9JS99w--two,56
 fusion_bench_config/method/ensemble/simple_ensemble.yaml,sha256=bw9FabjhQYNbttsiMgTVd-Z4KIowf050Uy97vKtm2ys,55
 fusion_bench_config/method/ensemble/weighted_ensemble.yaml,sha256=U_wQXtogtgiqOTszHUgcGNfrKlXD6JrR_HjqNwAkkKo,262
@@ -679,7 +692,8 @@ fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml,sha256=aX0rWw
 fusion_bench_config/modelpool/Seq2SeqLMPool/_template.yaml,sha256=mRx-Xx4s6_IBoJJRogIBW4egmqW0wi1kGVWp_YwYVvQ,233
 fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml,sha256=6Rgfq3cjCRWbAL8Bb-Dkvl9eJP4FKmqewBpokajwYWU,335
 fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml,sha256=1vaVb059Wh3XMD8MhXD9p5a0zx8mi9HovOcS0k51uK8,1699
-fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml,sha256=
+fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml,sha256=dwBb3wPfyxH6cx6txBd31OOlrfCvPkM-nIN46FJer-I,1790
+fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml,sha256=2BBuK1uyKL_9uo3X3bScjZiK-PtIiE_7RHj4onK_3R0,1725
 fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml,sha256=2YBIzqYGluOT2r6dOFpUYE4Cbdd2XoHAUps-kCDxVPQ,185
 fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml,sha256=W1y3fKY9UTTRyv7nqbIO5DESlQVfNsWlhkHJMUYh7B4,1824
 fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml,sha256=JUzGOLANW92Y_rljOOZKmwBQvWrJsko_ziayurzHSTY,880
@@ -724,9 +738,8 @@ fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397
 fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml,sha256=2AqMiNCRRunLIrssHvFzu1lUzOaQn8uOHM9yjrQq-_A,109
 fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml,sha256=iQMj2VpDTe_D8OfCo94w5Ud2MON-EGa0DzVr6UmphrA,436
 fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml,sha256=i5Bn8bLl2cgqvrgtIGmoovUfSMehk_m-6C2wwcx5JMU,435
-fusion_bench-0.2.
-fusion_bench-0.2.
-fusion_bench-0.2.
-fusion_bench-0.2.
-fusion_bench-0.2.
-fusion_bench-0.2.10.dist-info/RECORD,,
+fusion_bench-0.2.12.dist-info/METADATA,sha256=V0KZSil6pMjhZVA3x0wUrW-eskY5DsyclRkiuh8sfec,20085
+fusion_bench-0.2.12.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+fusion_bench-0.2.12.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
+fusion_bench-0.2.12.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
+fusion_bench-0.2.12.dist-info/RECORD,,
```
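Each RECORD row has the form `path,sha256=<digest>,<size>`, where the digest is an unpadded urlsafe-base64 SHA-256 of the file (per PEP 376/427); the bare `sha256=` entries on the removed lines above are an artifact of this diff rendering, not of the wheel itself. A small sketch of recomputing such an entry:

```python
import base64
import hashlib
from pathlib import Path


def record_entry(path: str) -> str:
    """Build a RECORD-style row: 'path,sha256=<urlsafe-b64 digest>,<size>'."""
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"


# e.g. record_entry("fusion_bench/__init__.py") should reproduce the row
# "fusion_bench/__init__.py,sha256=68dF-zPvb8E2MgYnmgIJsxIHJBy1MApKeOrRZvQEVlg,421"
```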
fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml (new file, +38 lines):

```yaml
# Reference: Jinluan Yang, et al. Mitigating the Backdoor Effect for Multi-Task Model Merging via Safety-Aware Subspace. ICLR 2025.

name: clip_post_defense_AWM

# batch size per gpu
# if you have multiple gpus, the total batch size will be `batch_size * num_gpus`
batch_size: 16
num_workers: 8

optimizer: adam
lr: 1e-3

scaling_factor: 0.3

# new: adversarial perturbation options for the AWM defense
adv_lr: 1e-4
trigger_norm: 1000
adv_weight: 0.01

max_steps: 2000
save_interval: 500
initial_logits: 0
temperature: 0.5

# "discrete" or "continuous"; this is the mask applied for evaluation, not during training.
# The performance of the final model is expected to be similar either way.
eval_mask_type: continuous

mask_checkpoint: null
# if `clamp_weights` is true, the weights will be clamped to [0, 1]
clamp_weights: false

# arguments of `functional_call`
tie_weights: true
strict: false

cache_dir: outputs
```
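FusionBench method configs like this one are Hydra/OmegaConf YAML files selected through the config groups shown in the file list (e.g. a method override such as `concrete_subspace/clip_post_defense_AWM`); a minimal sketch of inspecting the file directly, assuming the repository layout listed above:

```python
from omegaconf import OmegaConf

# Load the new AWM defense config straight from the config tree.
cfg = OmegaConf.load(
    "fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml"
)
print(cfg.name)            # clip_post_defense_AWM
print(cfg.max_steps)       # 2000
print(cfg.eval_mask_type)  # continuous
```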