fusion-bench 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +5 -0
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/method/DOGE_TA/DOGE_TA.py +364 -0
- fusion_bench/method/DOGE_TA/__init__.py +2 -0
- fusion_bench/method/DOGE_TA/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/DOGE_TA/layer_wise_adamerging.py +250 -0
- fusion_bench/method/__init__.py +22 -0
- fusion_bench/method/classification/continual_clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/__init__.py +8 -0
- fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
- fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
- fusion_bench/method/isotropic_merging/__init__.py +15 -0
- fusion_bench/method/isotropic_merging/iso.py +114 -0
- fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
- fusion_bench/method/task_singular_vector/TSVM.py +22 -2
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +531 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/RECORD +30 -13
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/DOGE_TA/DOGE_TA.yaml +4 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,531 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import functools
|
|
3
|
+
import logging
|
|
4
|
+
from collections import OrderedDict
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from typing import Any, Callable, Dict, Iterator, List, Optional # noqa: F401
|
|
7
|
+
|
|
8
|
+
import lightning as L
|
|
9
|
+
import torch
|
|
10
|
+
from torch import Tensor, nn
|
|
11
|
+
from torch.func import functional_call
|
|
12
|
+
|
|
13
|
+
from fusion_bench.utils.state_dict_arithmetic import state_dict_add
|
|
14
|
+
from fusion_bench.utils.type import StateDictType
|
|
15
|
+
|
|
16
|
+
__all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
|
|
17
|
+
|
|
18
|
+
log = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def del_attr(obj, names: List[str]):
    """
    Delete a nested attribute addressed by a path of attribute names.

    Every name except the last is resolved with ``getattr``; the last one is
    removed from the object reached, so ``del_attr(m, ["a", "b"])`` is
    equivalent to ``del m.a.b``.

    Args:
        obj (object): Root object the path starts from.
        names (list): Attribute path, one attribute name per element.
    """
    leaf_name = names[-1]
    owner = obj
    for parent_name in names[:-1]:
        owner = getattr(owner, parent_name)
    delattr(owner, leaf_name)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def set_attr(obj, names: List[str], val):
    """
    Assign ``val`` to a nested attribute addressed by a path of names.

    All names but the last are resolved with ``getattr``; the final name is
    then set on the object reached, so ``set_attr(m, ["a", "b"], v)`` is
    equivalent to ``m.a.b = v``.

    Args:
        obj (object): Root object the path starts from.
        names (list): Attribute path, one attribute name per element.
        val (object): Value to store at the end of the path.
    """
    owner = obj
    for parent_name in names[:-1]:
        owner = getattr(owner, parent_name)
    setattr(owner, names[-1], val)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def get_attr(obj, names: List[str]):
    """
    Read a nested attribute addressed by a path of attribute names.

    ``get_attr(m, ["a", "b"])`` is equivalent to ``m.a.b``.

    Args:
        obj (object): Root object the path starts from.
        names (list): Attribute path, one attribute name per element. Must be
            non-empty (an empty path raises ``IndexError``).

    Returns:
        object: The attribute found at the end of the path.
    """
    first_name, remaining_names = names[0], names[1:]
    return functools.reduce(getattr, remaining_names, getattr(obj, first_name))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def get_layer_wise_weights(
    num_models: int,
    num_layers: int,
    init_values: float = None,
    dtype: torch.dtype = torch.float32,
):
    """
    Build a constant-filled matrix of layer-wise merging weights.

    Args:
        num_models (int): Number of models to fuse (rows); must be >= 1.
        num_layers (int): Number of layers per model (columns); must be >= 1.
        init_values (float, optional): Fill value for every entry. When
            ``None`` the uniform value ``1.0 / num_models`` is used.
        dtype (torch.dtype): dtype of the returned tensor; should match the
            model dtype.

    Returns:
        Tensor: A ``(num_models, num_layers)`` tensor filled with the value.
    """
    assert num_models >= 1, f"num_models must be >= 1, got {num_models}"
    assert num_layers >= 1, f"num_layers must be >= 1, got {num_layers}"
    fill_value = 1.0 / num_models if init_values is None else init_values
    shape = (num_models, num_layers)
    return torch.full(shape, fill_value, dtype=dtype)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]):
    """
    Compute a weighted sum of same-shaped tensors for a single layer.

    Each tensor is moved to ``layer_wise_weight``'s device before scaling.

    Args:
        layer_wise_weight (Tensor): 1-D tensor of shape ``(num_models,)`` with
            one coefficient per tensor.
        tensors (List[Tensor]): One tensor per model, all holding the same
            layer's weights.

    Returns:
        Tensor: ``sum_i layer_wise_weight[i] * tensors[i]``.
    """
    assert len(layer_wise_weight) == len(
        tensors
    ), f"layer_wise_weight.shape={layer_wise_weight.shape}, len(tensors)={len(tensors)}"
    target_device = layer_wise_weight.device
    scaled_terms = (
        coefficient * tensor.to(target_device)
        for coefficient, tensor in zip(layer_wise_weight, tensors)
    )
    return sum(scaled_terms)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def fuse_weights(
    layer_wise_weight: Tensor, state_dicts: List[StateDictType]
) -> StateDictType:
    """
    Fuse several state dicts key by key using per-layer model weights.

    Column ``i`` of ``layer_wise_weight`` weights the ``i``-th key of the first
    state dict (in its iteration order); every state dict must contain that key.

    Args:
        layer_wise_weight (Tensor): Shape ``(num_models, num_layers)`` where
            ``num_layers`` equals the number of entries in ``state_dicts[0]``.
        state_dicts (List[StateDict]): One state dict per model.

    Returns:
        A dict mapping each key to its fused tensor.
    """
    num_models = len(state_dicts)
    reference_state_dict = state_dicts[0]
    num_layers = len(reference_state_dict)
    expected_shape = (num_models, num_layers)
    assert (
        layer_wise_weight.shape == expected_shape
    ), f"layer_wise_weight.shape={layer_wise_weight.shape}, expected (num_models, num_layers): ({num_models}, {num_layers})"

    fused_state_dict = {}
    for column, key in enumerate(reference_state_dict.keys()):
        per_model_tensors = [state_dict[key] for state_dict in state_dicts]
        fused_state_dict[key] = _fuse_weights(
            layer_wise_weight[:, column], per_model_tensors
        )
    return fused_state_dict
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class LayerWiseMergedModel(nn.Module):
    """
    Layer-wise merged model whose task vectors are refined DOGE-TA style.

    At construction time this wrapper:

    1. builds one task vector per fine-tuned model (fine-tuned minus
       pretrained), restricted to state-dict keys that contain ``"encoder"``
       and ``"weight"`` but not ``"layer_norm"``;
    2. for each such layer, builds a projector ``U U^T`` onto the leading
       left singular vectors of the concatenated per-task SVD bases;
    3. optimizes an additive correction ``delta`` per task vector with Adam,
       projecting out the gradient component lying in that subspace;
    4. keeps ``task_vector + delta`` only where the original task vector is in
       its top-``topk`` fraction by magnitude;
    5. optionally prunes the resulting 2-D tensors and stores them sparse.

    The learnable ``merge_weight`` (shape ``(num_models, num_layers)``) scales
    each refined task vector layer by layer in :meth:`merge_weights`.
    """

    # Result of the latest `merge_weights` call; `None` until the first merge.
    _merged_state_dict: StateDictType = None

    def __init__(
        self,
        layer_wise_weight: Tensor,
        pretrained_model: nn.Module,
        finetuned_models: List[nn.Module],
        clamp_weights: bool = True,
        tie_weights: bool = False,
        strict: bool = True,
        sparsity_ratio: Optional[float] = None,
        normalized_merging_weights: bool = False,
        num_delta_steps: int = 400,
        delta_lr: float = 1e-4,
        topk: float = 30,
        lamda_scale: float = 0.07,
    ):
        R"""
        Wrap a pretrained model and merge refined task vectors into it layer-wise.

        Args:
            layer_wise_weight (Tensor): Tensor of shape (num_models, num_layers);
                only the first ``len(filtered_keys)`` columns are used.
            pretrained_model (nn.Module): The pretrained model to merge into.
            finetuned_models (List[nn.Module]): Fine-tuned models with the same
                architecture as ``pretrained_model``; used to compute task vectors.
            clamp_weights (bool, optional): Clamp the layer-wise weights to [0, 1]
                before merging. Defaults to True.
            tie_weights (bool, optional): Forwarded to ``functional_call``.
                Defaults to False.
            strict (bool, optional): Forwarded to ``functional_call``.
                Defaults to True.
            sparsity_ratio (float, optional): If given, every 2-D refined task
                vector tensor is magnitude-pruned with this ratio and stored as a
                sparse tensor. A high sparsity level can reduce memory usage.
            normalized_merging_weights (bool, optional): Apply a softmax over
                models to the layer-wise weights of each layer, so they sum to 1.
                Defaults to False.
            num_delta_steps (int, optional): Adam steps per layer when optimizing
                the task-vector corrections. Defaults to 400.
            delta_lr (float, optional): Adam learning rate for the corrections.
                Defaults to 1e-4.
            topk (float, optional): Fraction of task-vector entries to keep; values
                > 1 are interpreted as percentages. Defaults to 30 (i.e. 30%).
            lamda_scale (float, optional): Numerator of the per-layer coefficient
                ``lamda_scale / ||task_vector_layer||``. Defaults to 0.07.

        Raises:
            ValueError: If ``finetuned_models`` is empty.
        """
        super().__init__()
        finetuned_models = list(finetuned_models)
        if not finetuned_models:
            raise ValueError("`finetuned_models` must contain at least one model.")

        # Always define `_fabric`: nn.Module raises AttributeError for attributes
        # that were never assigned, and the code below compares it with None.
        self._fabric = None
        if torch.cuda.is_available():
            self._fabric = L.Fabric(devices=1)
            self._fabric.launch()
        self.clamp_weights = clamp_weights
        self.tie_weights = tie_weights
        self.strict = strict
        self.sparsity_ratio = sparsity_ratio
        # NOTE: attribute keeps its historical (misspelled) name so that any
        # external code reading it keeps working.
        self.nromalized_merging_weights = normalized_merging_weights

        pretrained_sd = pretrained_model.state_dict(keep_vars=True)
        filtered_keys = [
            k
            for k in pretrained_sd.keys()
            if "encoder" in k and "layer_norm" not in k and "weight" in k
        ]
        self.merge_weight = nn.Parameter(
            layer_wise_weight[:, : len(filtered_keys)], requires_grad=True
        )

        for m in finetuned_models:
            m.requires_grad_(False)
        self.pretrained_model = pretrained_model.requires_grad_(False)

        task_vectors = []
        for model in finetuned_models:
            model_sd = model.state_dict(keep_vars=True)
            task_vectors.append(
                self._maybe_to_device(
                    {k: model_sd[k] - pretrained_sd[k] for k in filtered_keys}
                )
            )

        projection = self._compute_projections(task_vectors)
        self.lamdas = self.compute_layer_lamdas(task_vectors, scale=lamda_scale)
        self.delta = self._optimize_deltas(
            task_vectors,
            projection,
            self.lamdas,
            num_steps=num_delta_steps,
            lr=delta_lr,
        )
        del projection
        self.lamdas = [
            {key: value.cpu() for key, value in lamdas.items()}
            for lamdas in self.lamdas
        ]
        task_vectors = [
            {k: v.cpu() for k, v in task_vector.items()} for task_vector in task_vectors
        ]
        self.task_vectors = self._maybe_to_device(
            self._refine_task_vectors(task_vectors, topk)
        )

        # If `sparsity_ratio` is given, prune the refined task vectors.
        if sparsity_ratio is not None:
            self._prune_task_vectors(sparsity_ratio)

    def _maybe_to_device(self, obj):
        """Move a (nested collection of) tensor(s) to the Fabric device, if any."""
        if self._fabric is None:
            return obj
        return self._fabric.to_device(obj)

    def _backward_loss(self, loss: Tensor) -> None:
        """Backpropagate `loss` through Fabric when available, else plain autograd."""
        if self._fabric is not None:
            self._fabric.backward(loss)
        else:
            loss.backward()

    def _compute_projections(
        self, task_vectors: List[Dict[str, Tensor]]
    ) -> Dict[str, Tensor]:
        """
        Build, for each layer, a projector onto a subspace shared by all tasks.

        For every task, the first ``rank // num_tasks`` left singular vectors of
        its layer are placed side by side. A second SVD of that matrix is then
        taken, and ``U U^T`` of its first ``rank // num_tasks`` left singular
        vectors is returned. Here ``rank`` is ``min(rows, cols)``.
        """
        num_tasks = len(task_vectors)
        projection = {}
        for layer_name in task_vectors[0].keys():
            sum_u = None
            reduced_index_s = 0
            for i, vector in enumerate(task_vectors):
                layer_vector = vector[layer_name]
                u, s, _ = torch.linalg.svd(layer_vector, full_matrices=False)
                if i == 0:
                    log.info(f"Computed SVD for {layer_name}...")
                    sum_u = torch.zeros_like(u, device=layer_vector.device)
                reduced_index_s = s.shape[0] // num_tasks
                # place the first `reduced_index_s` columns of u in task i's slot
                sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                    :, :reduced_index_s
                ]
            u_u, _, _ = torch.linalg.svd(sum_u, full_matrices=False)
            basis = u_u[:, :reduced_index_s]
            projection[layer_name] = torch.matmul(basis, basis.T)
        return projection

    def _optimize_deltas(
        self,
        task_vectors: List[Dict[str, Tensor]],
        projection: Dict[str, Tensor],
        lamdas: List[Dict[str, Tensor]],
        num_steps: int,
        lr: float,
    ) -> List[Dict[str, Tensor]]:
        """
        Learn one additive correction per task vector, one layer at a time.

        Returns the corrections detached and moved to CPU.
        """
        deltas = self._maybe_to_device(
            [
                {k: torch.zeros_like(v, requires_grad=True) for k, v in tv.items()}
                for tv in task_vectors
            ]
        )
        for layer_name in task_vectors[0].keys():
            layer_params = [delta[layer_name] for delta in deltas]
            optimizer = torch.optim.Adam(layer_params, lr=lr)
            layer_vectors = torch.stack([vec[layer_name] for vec in task_vectors])
            layer_lamdas = torch.stack([lam[layer_name] for lam in lamdas])
            layer_proj = projection[layer_name]
            for step in range(num_steps):
                optimizer.zero_grad()
                loss = self.taskvector_loss(
                    layer_vectors, torch.stack(layer_params), layer_lamdas
                )
                # `.item()` forces a device sync; only pay for it when logging.
                if log.isEnabledFor(logging.DEBUG):
                    log.debug(
                        f"Epoch: {step}, Layer: {layer_name}, Loss: {loss.item()}"
                    )
                self._backward_loss(loss)
                # Remove the gradient component inside the projector's subspace.
                for param in layer_params:
                    param.grad.sub_(layer_proj @ param.grad)
                optimizer.step()
            for param in layer_params:
                param.grad = None
        return [
            {key: param.detach().cpu() for key, param in delta.items()}
            for delta in deltas
        ]

    def _refine_task_vectors(
        self, task_vectors: List[Dict[str, Tensor]], topk: float
    ) -> List[Dict[str, Tensor]]:
        """
        Add each learned delta to its task vector and apply a top-k mask.

        The mask is computed from the original task vector's magnitudes. It is
        then applied to ``task_vector + delta``.
        """
        refined = []
        for task_vector, delta in zip(task_vectors, self.delta):
            flat_vector = self.state_dict_to_vector(task_vector)
            flat_delta = self.state_dict_to_vector(delta)
            vector_mask = self.topk_values_mask(flat_vector, K=topk)
            refined.append(
                self.vector_to_state_dict(
                    (flat_vector + flat_delta) * vector_mask, self.delta[0]
                )
            )
        return refined

    def _prune_task_vectors(self, sparsity_ratio: float) -> None:
        """Magnitude-prune every 2-D refined task-vector tensor and store it sparse."""
        from fusion_bench.method.pruning.prune_utils import (
            unstructured_magnitude_prune_,
        )

        # `self.task_vectors` is a list of dicts (not an nn.Module), so prune
        # entry by entry and write the result back into the same dict.
        for task_idx, task_vector in enumerate(self.task_vectors):
            for name, param in list(task_vector.items()):
                if param.dim() != 2:
                    continue
                log.info(f"pruning {name} of task vector {task_idx}")
                pruned_param = unstructured_magnitude_prune_(
                    param.data.clone(), torch.abs, sparsity_ratio=sparsity_ratio
                )
                task_vector[name] = pruned_param.to_sparse()

    def topk_values_mask(self, M, K):
        """
        Boolean mask selecting the top-``K`` fraction of entries by magnitude.

        Args:
            M (Tensor): 1-D vector or 2-D matrix (masked row-wise).
            K (float): Fraction to keep; values > 1 are treated as percentages.
                Ties at the threshold are kept, so slightly more entries than
                requested may be selected.

        Returns:
            Tensor: Boolean mask with the same shape as ``M``.
        """
        if K > 1:
            K /= 100

        is_vector = M.dim() == 1
        if is_vector:
            M = M.unsqueeze(0)

        _, d = M.shape
        # 1-based ascending rank of the smallest entry that is still kept.
        k = d - int(d * K)
        if k < 1:
            # Keeping every entry; `kthvalue` rejects k == 0.
            mask = torch.ones_like(M, dtype=torch.bool)
        else:
            magnitude = M.abs()
            kth_values, _ = magnitude.kthvalue(k, dim=1, keepdim=True)
            mask = magnitude >= kth_values

        # Squeeze only the dimension added above so 1-element inputs keep shape.
        return mask.squeeze(0) if is_vector else mask

    def state_dict_to_vector(self, state_dict, remove_keys=None):
        """
        Flatten a state dict into one vector, ordered by sorted key.

        Args:
            state_dict (dict): The state dictionary to convert.
            remove_keys (list, optional): Keys to exclude from the vector.

        Returns:
            Tensor: Concatenation of the flattened tensors in sorted-key order.
        """
        remove_keys = [] if remove_keys is None else remove_keys
        kept_items = sorted(
            (key, value) for key, value in state_dict.items() if key not in remove_keys
        )
        return nn.utils.parameters_to_vector(
            [value.reshape(-1) for _, value in kept_items]
        )

    def vector_to_state_dict(self, vector, state_dict, remove_keys=None):
        """
        Split a flat vector back into a state dict shaped like ``state_dict``.

        This is the inverse of :meth:`state_dict_to_vector`: keys are taken in
        sorted order and each entry is a view of ``vector`` reshaped like the
        reference tensor.

        Args:
            vector (Tensor): The flat vector to split.
            state_dict (dict): Reference state dict defining keys and shapes.
            remove_keys (list, optional): Keys excluded from the vector.

        Returns:
            OrderedDict: Sorted-key state dict whose values view ``vector``.
        """
        remove_keys = [] if remove_keys is None else remove_keys
        reference = {k: v for k, v in state_dict.items() if k not in remove_keys}

        sorted_reference_dict = OrderedDict()
        offset = 0
        for key, ref in sorted(reference.items()):
            numel = ref.numel()
            sorted_reference_dict[key] = vector[offset : offset + numel].view_as(ref)
            offset += numel

        # Add back the encoder and decoder embedding weights (tied models).
        if "transformer.shared.weight" in sorted_reference_dict:
            for key in remove_keys:
                sorted_reference_dict[key] = sorted_reference_dict[
                    "transformer.shared.weight"
                ]
        return sorted_reference_dict

    def taskvector_loss(self, layer_vectors, layer_delta, layer_lamdas) -> torch.Tensor:
        """
        Data-free interference loss for a single layer.

        With ``M = sum_i lamda_i * v_i`` and ``D = sum_i lamda_i * delta_i``,
        this returns ``sum_j sum_rows (sum_cols v_j * (v_j - M - D))^2``.

        Args:
            layer_vectors (Tensor): Stacked task vectors, shape (num_tasks, ...).
            layer_delta (Tensor): Stacked per-task corrections with the same shape
                as ``layer_vectors``. A single shared correction (without the
                leading task dimension) also works: it broadcasts to
                ``(sum_i lamda_i) * delta``.
            layer_lamdas (Tensor): Per-task coefficients, shape (num_tasks,).
        """
        scale = layer_lamdas.view(-1, 1, 1)
        sum_over_num_vectors = (layer_vectors * scale).sum(dim=0)
        # Sum over the task dimension so D has the same shape as M.
        sum_over_delta = (layer_delta * scale).sum(dim=0)

        total_loss = 0.0
        for v_j in layer_vectors:
            part1 = -v_j * sum_over_num_vectors
            part2 = -v_j * sum_over_delta
            part3 = v_j * v_j
            expression = part1 + part2 + part3
            total_loss = total_loss + expression.sum(dim=1).pow(2).sum()
        return total_loss

    def compute_layer_lamdas(
        self, vectors: List[StateDictType], scale: float = 0.07
    ) -> List[Dict[str, Tensor]]:
        """
        Per-task, per-layer coefficients ``scale / ||vector_layer||``.

        NOTE(review): a layer whose task vector is exactly zero yields an
        infinite coefficient; confirm this cannot occur for the intended models.
        """
        return [
            {layer_name: scale / torch.norm(vec[layer_name]) for layer_name in vec.keys()}
            for vec in vectors
        ]

    @property
    def forward_model(self):
        """`functional_call` on the pretrained model bound to the merged state dict."""
        return functools.partial(
            functional_call,
            self.pretrained_model,
            self._merged_state_dict,
            tie_weights=self.tie_weights,
            strict=self.strict,
        )

    def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        """Merge the weights, load them into the pretrained model and return it."""
        self.merge_weights(task_vector_mask=task_vector_mask)
        self.pretrained_model.load_state_dict(self._merged_state_dict)
        return self.pretrained_model

    def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        """
        Merge the refined task vectors into the pretrained weights.

        Call this after each update step. ``task_vector_mask`` is accepted for
        interface compatibility but is not used here.
        """
        if self.clamp_weights:
            layer_wise_weight = self.merge_weight.clamp(0, 1)
        else:
            layer_wise_weight = self.merge_weight
        if self.nromalized_merging_weights:
            # normalize the weights for each layer, so that the sum of weights
            # across models for each layer is 1.
            layer_wise_weight = layer_wise_weight.softmax(dim=0)

        state_dict = self.pretrained_model.state_dict(keep_vars=True)
        # shape of layer_wise_weight: (num_models, num_layers); column j pairs
        # with the j-th (sorted-key) entry of each refined task vector.
        for weight, task_vector in zip(layer_wise_weight, self.task_vectors):
            for w, (name, param) in zip(weight, list(task_vector.items())):
                state_dict[name] = state_dict[name] + param * w
        self._merged_state_dict = state_dict

        return state_dict

    def forward(self, *args, **kwargs):
        """Run the pretrained model with the merged weights (merging lazily)."""
        if self._merged_state_dict is None:
            self.merge_weights()
        return self.forward_model(args=args, kwargs=kwargs)
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def merge_weights(module: nn.Module):
    """
    Call ``merge_weights()`` on every `LayerWiseMergedModel` reachable from ``module``.

    The traversal is depth-first in child order and does not descend into a
    `LayerWiseMergedModel` once one is found.

    Args:
        module (nn.Module): Root module to search.
    """
    pending = [module]
    while pending:
        current = pending.pop()
        if isinstance(current, LayerWiseMergedModel):
            current.merge_weights()
            continue
        # reversed so that popping from the end preserves child order
        pending.extend(reversed(list(current.children())))
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
def merge_and_unload(module: nn.Module):
    """
    Replace every `LayerWiseMergedModel` inside ``module`` by its merged model.

    If ``module`` itself is a `LayerWiseMergedModel`, its merged, unloaded
    model is returned directly. Otherwise, children are processed recursively.
    Each direct child that is a `LayerWiseMergedModel` is swapped in place for
    its merged model.

    Args:
        module (nn.Module): The module to process.

    Returns:
        nn.Module: The updated module with merged weights.
    """
    if isinstance(module, LayerWiseMergedModel):
        return module.merge_and_unload()

    for child_name, child in module.named_children():
        child_is_wrapper = isinstance(child, LayerWiseMergedModel)
        replacement = merge_and_unload(child)
        if child_is_wrapper:
            setattr(module, child_name, replacement)
    return module
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def fix_other_parts(module: nn.Module):
    """
    Freeze every parameter except the ``merge_weight`` of merged-model wrappers.

    Args:
        module (nn.Module): The module to process (modified in place).

    Returns:
        nn.Module: The same module, for chaining.
    """
    module.requires_grad_(False)
    wrappers = (
        candidate
        for candidate in module.modules()
        if isinstance(candidate, LayerWiseMergedModel)
    )
    for wrapper in wrappers:
        wrapper.merge_weight.requires_grad_(True)
    return module
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
fusion_bench/__init__.py,sha256=68dF-zPvb8E2MgYnmgIJsxIHJBy1MApKeOrRZvQEVlg,421
|
|
2
2
|
fusion_bench/__main__.py,sha256=weUjxpP3ULnDgUxCehdbmoCM9cqfkhDhGB85tAF5qoE,81
|
|
3
3
|
fusion_bench/compat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
|
-
fusion_bench/compat/method/__init__.py,sha256=
|
|
4
|
+
fusion_bench/compat/method/__init__.py,sha256=97izLAf4JssNAoOXR4MYffFxb3OEwpHeQeSlL_ihMKI,5566
|
|
5
5
|
fusion_bench/compat/method/base_algorithm.py,sha256=63_AQDj1eJOO6RyTSGXVC6G2DsG8yg9E4pT3RJXgP3A,1952
|
|
6
6
|
fusion_bench/compat/modelpool/AutoModelForSeq2SeqLM.py,sha256=m68BRGy4P-P9lLB10oXOBI-p58a-0FOPcrJ4r4MU32k,1100
|
|
7
7
|
fusion_bench/compat/modelpool/__init__.py,sha256=KD8Ddr9D7rJ5YdHEQsTuNmQ0bgQfqF4l3WNMtHmRHD8,4687
|
|
@@ -15,7 +15,7 @@ fusion_bench/constants/__init__.py,sha256=Pyc4dLbl6oNduOCdnpeXQ9LDyVoIrkdl9eZ_l2
|
|
|
15
15
|
fusion_bench/constants/paths.py,sha256=DVZyQ9FLhkyUdw6ARpXUCAMf_B8hFyJ6UNI-oYly3pE,591
|
|
16
16
|
fusion_bench/dataset/__init__.py,sha256=OJiYmcqz0Vm5O7mE4PB5QFJeL_KjrsseQTRsQATGTm4,1050
|
|
17
17
|
fusion_bench/dataset/clip_dataset.py,sha256=XLpCOiXlLEP3DffAlBn4P2PpUenbEFl-Yk9MNy6nbbI,2790
|
|
18
|
-
fusion_bench/dataset/fer2013.py,sha256=
|
|
18
|
+
fusion_bench/dataset/fer2013.py,sha256=Lub_xVhHfqaiPprvOsDVspJNioh1FjSrkhn3gL_UXDA,404
|
|
19
19
|
fusion_bench/dataset/gpt2_glue.py,sha256=Qq1ZkEIQsTjj8tImvkZDNlduocSYwlEfVrDReZqDWdw,8761
|
|
20
20
|
fusion_bench/dataset/gsm8k.py,sha256=CmANZ0A89PfPwVu_myKhXk1D9IwypOpjH3iqDo1KxcQ,2233
|
|
21
21
|
fusion_bench/dataset/image_dataset.py,sha256=MSZE_UESyRRQDwnkm2KpyIARUg9SWcwqnH4fDNstzS4,1870
|
|
@@ -41,12 +41,16 @@ fusion_bench/dataset/llama/stanford_shp.py,sha256=6ueXKnFXIBBobacU1h5WxGLZrSOtBk
|
|
|
41
41
|
fusion_bench/dataset/llama/ultrachat.py,sha256=Go7WvrDAYnm184fdazHGRYLbSY6Xd7jrESyQeUJtOww,1736
|
|
42
42
|
fusion_bench/dataset/llama/wikitext.py,sha256=9ZHR-nMfXRumd3o-PIj3n7B83YlVeqpGkZ2zJs2B-9Y,2883
|
|
43
43
|
fusion_bench/dataset/llama/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
44
|
-
fusion_bench/method/__init__.py,sha256=
|
|
44
|
+
fusion_bench/method/__init__.py,sha256=QGJzdOpZxonu_WUNXSFQIiMy4OHsgqmcU5Bs6OB_RT0,7040
|
|
45
45
|
fusion_bench/method/base_algorithm.py,sha256=5dutGZfPqNhO8F8FOlo3UFR91TZu2Xj7O0pTB40JvWo,1135
|
|
46
46
|
fusion_bench/method/dummy.py,sha256=hb1y6LR_geRZ5eRgGwt5zJUcHYorCeIbs5i76CvurUc,1031
|
|
47
47
|
fusion_bench/method/ensemble.py,sha256=rGxvJTeorfcBuE_e0XO-0-MAc9un7ZCC46ikKGuAcN4,3077
|
|
48
48
|
fusion_bench/method/model_recombination.py,sha256=2tviqmYSPOL0_Ktv8_gt_YzQ4tyCANHxXquUot_3Cgo,5360
|
|
49
49
|
fusion_bench/method/simple_average.py,sha256=2ghcL1E-eLbIYDCHYCoR9WtiYSb1GvFAH163OTTTEEI,4481
|
|
50
|
+
fusion_bench/method/DOGE_TA/DOGE_TA.py,sha256=veNjBfq65fB7oqQL66zAuA339WCY5mG-mefkVteg2-k,13785
|
|
51
|
+
fusion_bench/method/DOGE_TA/__init__.py,sha256=OTukCLUlbCUTDqGBtgBZop7eYFDfU2wjG4PkP4fXN4Q,59
|
|
52
|
+
fusion_bench/method/DOGE_TA/clip_layer_wise_adamerging.py,sha256=YdQ4trHohW6QzWC2enYvXA44WHxvzmoH_6sMrPn6z60,1305
|
|
53
|
+
fusion_bench/method/DOGE_TA/layer_wise_adamerging.py,sha256=rLk3Nep5d6wMUNCp6q7pC7L0pfBvUwGBIuiGM7CQOf4,9780
|
|
50
54
|
fusion_bench/method/ada_svd/__init__.py,sha256=4XzQbbvE9HI3NtEmEFvo8iC3ds_85vJXe7P7qJfL7kk,77
|
|
51
55
|
fusion_bench/method/ada_svd/clip_vision.py,sha256=QrT6cSwgVEGxXEpVhkvKQVQaoRW5P9V52Y3_8NX0f-o,12556
|
|
52
56
|
fusion_bench/method/adamerging/__init__.py,sha256=nt0saBT_3bqghk-pINQ-XCWm9UWwSZllu4R1sDuAJAA,376
|
|
@@ -65,10 +69,12 @@ fusion_bench/method/analysis/task_vector_cos_similarity.py,sha256=pL-XsWTo258yZT
|
|
|
65
69
|
fusion_bench/method/analysis/task_vector_violin_plot.py,sha256=ie8hPl6QsVz9MQ6C2OEpzIBxQnmVKNf1FPc5bThmQGM,7606
|
|
66
70
|
fusion_bench/method/classification/__init__.py,sha256=emB06UOMDHK5pfQ1WuvLG9Fm0aEEtZxSjpVw8fVE0fM,167
|
|
67
71
|
fusion_bench/method/classification/clip_finetune.py,sha256=DlV1isp8vz6jwXNYQ6zbblAoUfnssL-WBpDeaXI5BVw,15727
|
|
68
|
-
fusion_bench/method/classification/continual_clip_finetune.py,sha256=
|
|
69
|
-
fusion_bench/method/concrete_subspace/__init__.py,sha256=
|
|
72
|
+
fusion_bench/method/classification/continual_clip_finetune.py,sha256=OLhZKS-6aCnafevZkZYcNMKTWDDj3DATB27eZl_i8EY,11530
|
|
73
|
+
fusion_bench/method/concrete_subspace/__init__.py,sha256=jJoFcjnQe-jvccsm9DuCXna378m9XBT9vV1fEZbdfR0,464
|
|
70
74
|
fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py,sha256=90_0HkOIl0XQG89xMa0UiBhrwfV2YqfLxlS04AouR3o,24755
|
|
71
75
|
fusion_bench/method/concrete_subspace/clip_concrete_task_arithmetic.py,sha256=Nx-3AiAeIt5zmcC21Ta2_-4cAQg9hOWvThurXNZzA-w,10580
|
|
76
|
+
fusion_bench/method/concrete_subspace/clip_post_defense.py,sha256=h-c0ioxDopg7pUoRjxx3epqQxVKZAZWz8s7yHjM88mg,32355
|
|
77
|
+
fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py,sha256=eEKKUBgHufYTBaWWxkIKDF0lkuLI2bBgNHVr1JqT41c,35694
|
|
72
78
|
fusion_bench/method/dare/__init__.py,sha256=63Xwkawyl_Ooy4xFxoDlP6wf-rgEWNqPuWTT9-6Ku5o,156
|
|
73
79
|
fusion_bench/method/dare/simple_average.py,sha256=jR08PokPIr5PWSZbGVOp3IApgKvxAIovg3vnB2KiTwk,906
|
|
74
80
|
fusion_bench/method/dare/task_arithmetic.py,sha256=Seno_2BhuogdRxXOni8alnHG-fdW15_OWoAvMoBoJj0,2780
|
|
@@ -85,6 +91,9 @@ fusion_bench/method/fisher_merging/__init__.py,sha256=KWsjrtxKkPYwcUA5rB_6UNIqve
|
|
|
85
91
|
fusion_bench/method/fisher_merging/clip_fisher_merging.py,sha256=QCutGqjkfW3OWETPZsCChqLRAhvfJp4QKD9TGSpTyV0,7635
|
|
86
92
|
fusion_bench/method/fisher_merging/fisher_merging.py,sha256=CPU-tJiDv9FCIBYl7Pn0zA5cdRB1Md5kWchRDlJgly0,20456
|
|
87
93
|
fusion_bench/method/fisher_merging/gpt2_fisher_merging.py,sha256=LZmz41jZ5dSsAHxfOUpr3u2rlCgUPTDR7xMsIlQM-jc,7576
|
|
94
|
+
fusion_bench/method/isotropic_merging/__init__.py,sha256=0mxrl1UIjeFAPQcPcZtbgoCJO-DMW_49GKAhgcG-vEA,585
|
|
95
|
+
fusion_bench/method/isotropic_merging/iso.py,sha256=MwKqfk0oyxqtdOzeSx_9jFXX1a4Rd0WcEPsYvQhBSCg,3773
|
|
96
|
+
fusion_bench/method/isotropic_merging/iso_utils.py,sha256=7L8PYUIJROwHJQmhFY-tdEhkLAnzVKXr-ae55FQ1QSo,6928
|
|
88
97
|
fusion_bench/method/linear/__init__.py,sha256=ChfkoOEAb-rUKwpowFPel-a1hRfS8gCrbnWD-jlRbe4,283
|
|
89
98
|
fusion_bench/method/linear/expo.py,sha256=LCHTWlsPm1Mjhrq0mfpWLVC7skkI9ZksGduy3TxULoU,3939
|
|
90
99
|
fusion_bench/method/linear/linear_interpolation.py,sha256=IONw9BPiRJouY8bE9Abfyz7qVI_1B1n8KGZa0f7Pza8,2157
|
|
@@ -151,7 +160,7 @@ fusion_bench/method/tall_mask/utils.py,sha256=Wlp8WcPwR_lCaBIZ9rgG6ewLfSzz3G7kPk
|
|
|
151
160
|
fusion_bench/method/task_arithmetic/__init__.py,sha256=pSx_NV5Ra_6UXpyYWCi6ANQoAnEtymZt_X1dDN9wT4Y,96
|
|
152
161
|
fusion_bench/method/task_arithmetic/task_arithmetic.py,sha256=1D0uuNtqyA1VS35jh6AnEVsX72HnT02THyerck_lmso,5441
|
|
153
162
|
fusion_bench/method/task_singular_vector/TSVC.py,sha256=yn4SrZNvtA6PoGYJmbmtNeDyDbGnRCgfZ7ZCg914AZU,410
|
|
154
|
-
fusion_bench/method/task_singular_vector/TSVM.py,sha256=
|
|
163
|
+
fusion_bench/method/task_singular_vector/TSVM.py,sha256=H5RzZlQQeF4kZFjuxkz8v3gyVKS3iKPgqNnitKQzbXk,2787
|
|
155
164
|
fusion_bench/method/task_singular_vector/__init__.py,sha256=WMucyl9pu_Ev2kcdrfT4moqMMbzD7hHQVFME5Su5jMA,298
|
|
156
165
|
fusion_bench/method/task_singular_vector/utils/TSVC_utils.py,sha256=FytKbal48EW6iGIA-2zV7QSVbYTVflXr4Mr56q0W75k,2286
|
|
157
166
|
fusion_bench/method/task_singular_vector/utils/TSVM_utils.py,sha256=dsTMQ15zFJ1MPqDOt2TJ01O9Bwq_klyG9xL9hRD2aI0,27521
|
|
@@ -251,6 +260,7 @@ fusion_bench/models/surgery/surgerymodelwrapper.py,sha256=F8jX88K5zVWC6HsfN-nGNk
|
|
|
251
260
|
fusion_bench/models/wrappers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
252
261
|
fusion_bench/models/wrappers/ensemble.py,sha256=wIMZMRyXw5boWAm96c4Tiyebs_HDQovKxpGQ8rLnHUQ,6308
|
|
253
262
|
fusion_bench/models/wrappers/layer_wise_fusion.py,sha256=ZizBGQtSLKOzMLFAhrMNMcv6ZNdvABTyO7M1-DGHh3c,12316
|
|
263
|
+
fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py,sha256=k335dxzq3ezuYkDVOv4ePi128NVyiHVCW6zyuDRTg30,20689
|
|
254
264
|
fusion_bench/models/wrappers/task_wise_fusion.py,sha256=Wn3buQvWw_lihWaKB03_iz34cBPzwBD94kBT6uafWVQ,8404
|
|
255
265
|
fusion_bench/optim/__init__.py,sha256=lemrcuiA6OLjQkpYm-RP-Ox2MgjngN1ywvCo0NgShlM,61
|
|
256
266
|
fusion_bench/optim/exception.py,sha256=fMgo1heiqfGhuI5RIbf30BwWSShn5RQiyeb30QtfTI0,1607
|
|
@@ -462,6 +472,7 @@ fusion_bench_config/method/pwe_moe_ls_for_clip.yaml,sha256=brs9zYeuXfFnnCoRrSaAY
|
|
|
462
472
|
fusion_bench_config/method/simple_average.yaml,sha256=GtMNvt0-qWOevRX2V6fjiYUO2BwDvMw-EcxRMS_PhZQ,53
|
|
463
473
|
fusion_bench_config/method/task_arithmetic.yaml,sha256=TbpAeTwIX48PFOkZU-Ihuu6U9Y5XHZJGDu7vHLt5FjU,74
|
|
464
474
|
fusion_bench_config/method/ties_merging.yaml,sha256=N-XyOTEW0JRtyRJizpHqtb1GEIogUU22XSG76QvIvnw,292
|
|
475
|
+
fusion_bench_config/method/DOGE_TA/DOGE_TA.yaml,sha256=6R9NRuWmj0oapJ_raMB6R6rZPMckt2JtMLrTQ6HhrFc,77
|
|
465
476
|
fusion_bench_config/method/ada_svd/clip_vision.yaml,sha256=KDpDpzuNVqqyyqJcL0q-Ml2A7IUqn_-2dOZXs8zHKlU,184
|
|
466
477
|
fusion_bench_config/method/adamerging/clip.yaml,sha256=fBG7jBBepygKpCbM3fmUeVAr2zzx0g8C21rGGfnEPkA,730
|
|
467
478
|
fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml,sha256=7FPPMf6lcOD2dlNUbb5JyF3pqJ3D2jmvbWAbW9WGn0Y,546
|
|
@@ -474,6 +485,10 @@ fusion_bench_config/method/classification/clip_finetune.yaml,sha256=yWjcdKYaKvy5
|
|
|
474
485
|
fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml,sha256=XsHzr_5NoUZs0Us3eVwP3lUYXYvyJwGEEG9aDI_Z0rU,740
|
|
475
486
|
fusion_bench_config/method/concrete_subspace/clip_concrete_task_arithmetic.yaml,sha256=eNoqcY1iMbs0Y5kKi_ya3rmQQMHqU7ht3EU7G_xmwN0,746
|
|
476
487
|
fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml,sha256=WgTJj28FlVjR0_mCGJC5B8aJa9yezI3QusoXXHOrFoU,739
|
|
488
|
+
fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml,sha256=eGUCntXzDtW0tYX1vij7BHgDWzWq6sz2yFipVZj6z9E,849
|
|
489
|
+
fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml,sha256=DUYOU5A8MQw2cTqbraIDMFC7ciO8RXE2qXgVEEUudLM,891
|
|
490
|
+
fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml,sha256=olDW_p5gyyaynwbGAQgm2ZicYAx9n3i4FprxPecuUsU,923
|
|
491
|
+
fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml,sha256=KLO3C1BdeB6FBKHT0xG4V0OFk7ib2SeMScKeaN5BlsU,863
|
|
477
492
|
fusion_bench_config/method/dare/simple_average.yaml,sha256=oTFSCHul86NTjTtJYK5pNr3tuxW7XxNI-y6fL9Yo4VI,113
|
|
478
493
|
fusion_bench_config/method/dare/task_arithmetic.yaml,sha256=Cvsam89yquamn_GkITT6q8qFKN_Yb5nv8p-XgvnVrgU,134
|
|
479
494
|
fusion_bench_config/method/dare/ties_merging.yaml,sha256=50mPiRkzLN7gxaIs56sPWkAUSvqvdxjQJ8eVl1yUGOg,418
|
|
@@ -484,6 +499,8 @@ fusion_bench_config/method/ensemble/weighted_ensemble.yaml,sha256=U_wQXtogtgiqOT
|
|
|
484
499
|
fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml,sha256=rl7kfVvdo2pG-DnglQUbjzkyBqnq1FpfoSDSjFtdLwk,633
|
|
485
500
|
fusion_bench_config/method/fisher_merging/fisher_merging.yaml,sha256=B1wrv9mhaOID4KcAUEMZNxlvY3tR3Q3UGualFslvx-Y,475
|
|
486
501
|
fusion_bench_config/method/fisher_merging/gpt2_fisher_merging.yaml,sha256=AE7XZqRDj4__J_ipEcjPs7qTB2J3xLQyFRlq1W4iHFE,563
|
|
502
|
+
fusion_bench_config/method/isotropic_merging/iso_c.yaml,sha256=Lh_OtTaUJ08--h85fUr2asF85xLe1NMCK8fVAhHOzdQ,82
|
|
503
|
+
fusion_bench_config/method/isotropic_merging/iso_cts.yaml,sha256=x5vZo__kO8njl4_gFdXnOt15X_qFLv6-diSWHOR4clw,111
|
|
487
504
|
fusion_bench_config/method/linear/expo.yaml,sha256=St3NW6cKVRV3vCn8y0gxQ8k66VTdtsLTEWQTbO9wQ0Y,420
|
|
488
505
|
fusion_bench_config/method/linear/linear_interpolation.yaml,sha256=IQgltk5REITSx8xLuLP11ByPbuMgy7dHz_BrxIgwOas,67
|
|
489
506
|
fusion_bench_config/method/linear/llama_expo.yaml,sha256=SEsC-l5gugY0vlsQkTJqzVgWJnMjFzWuTz814UKbFeM,624
|
|
@@ -515,7 +532,7 @@ fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml,sha256
|
|
|
515
532
|
fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml,sha256=w1OWb38nW08K_hvrRMsCwmRxHWLGQfSSXg5nTiYaP8E,635
|
|
516
533
|
fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml,sha256=J6vYIwqzh95-B3ekDias3FnCrVr4sig4zxpWyvz8hZ0,613
|
|
517
534
|
fusion_bench_config/method/surgery/adamerging_surgery.yaml,sha256=Ne9JlJFgsRYcygBNCOBSN1ygBcLkE6I-8yusfTxyg-Y,826
|
|
518
|
-
fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml,sha256=
|
|
535
|
+
fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml,sha256=CLONjN9TXQ0OQwZHaje0q3WJWxR3LD1b5q5KrWJfZIA,169
|
|
519
536
|
fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml,sha256=mK09Ohsvj0Q6suj5qJM4DyCzRy192QBt4wjHS6W29IY,197
|
|
520
537
|
fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml,sha256=jiAco7M1XO0aekHFZKLKlXL_jRoCA8bgGD44Z7iB208,1001
|
|
521
538
|
fusion_bench_config/method/wemoe/weight_ensembling_moe.yaml,sha256=OEv5yhyUCe5lXeT2PyXC49yrHXEM7i8SZDw6IQRDtAE,620
|
|
@@ -719,9 +736,9 @@ fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_sun397
|
|
|
719
736
|
fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-single-task_svhn.yaml,sha256=2AqMiNCRRunLIrssHvFzu1lUzOaQn8uOHM9yjrQq-_A,109
|
|
720
737
|
fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml,sha256=iQMj2VpDTe_D8OfCo94w5Ud2MON-EGa0DzVr6UmphrA,436
|
|
721
738
|
fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml,sha256=i5Bn8bLl2cgqvrgtIGmoovUfSMehk_m-6C2wwcx5JMU,435
|
|
722
|
-
fusion_bench-0.2.
|
|
723
|
-
fusion_bench-0.2.
|
|
724
|
-
fusion_bench-0.2.
|
|
725
|
-
fusion_bench-0.2.
|
|
726
|
-
fusion_bench-0.2.
|
|
727
|
-
fusion_bench-0.2.
|
|
739
|
+
fusion_bench-0.2.11.dist-info/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
|
|
740
|
+
fusion_bench-0.2.11.dist-info/METADATA,sha256=AYdGcKXZ6BeHCv1piGgpK1yktQqVga-PjUDxS4RYwog,16780
|
|
741
|
+
fusion_bench-0.2.11.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
|
|
742
|
+
fusion_bench-0.2.11.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
|
|
743
|
+
fusion_bench-0.2.11.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
|
|
744
|
+
fusion_bench-0.2.11.dist-info/RECORD,,
|