fusion-bench 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. fusion_bench/compat/method/__init__.py +5 -0
  2. fusion_bench/method/DOGE_TA/DOGE_TA.py +364 -0
  3. fusion_bench/method/DOGE_TA/__init__.py +2 -0
  4. fusion_bench/method/DOGE_TA/clip_layer_wise_adamerging.py +46 -0
  5. fusion_bench/method/DOGE_TA/layer_wise_adamerging.py +250 -0
  6. fusion_bench/method/__init__.py +10 -0
  7. fusion_bench/method/concrete_subspace/__init__.py +8 -0
  8. fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
  9. fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
  10. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  11. fusion_bench/method/isotropic_merging/iso.py +2 -2
  12. fusion_bench/method/task_singular_vector/TSVM.py +3 -3
  13. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +531 -0
  14. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/METADATA +1 -1
  15. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/RECORD +24 -12
  16. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/WHEEL +1 -1
  17. fusion_bench_config/method/DOGE_TA/DOGE_TA.yaml +4 -0
  18. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
  19. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
  20. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
  21. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
  22. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/LICENSE +0 -0
  23. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/entry_points.txt +0 -0
  24. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,6 @@ Reference:
10
10
  from .iso import (
11
11
  ISO_C_Merge,
12
12
  ISO_CTS_Merge,
13
- IsotropicMergingInCommonSubspace,
14
13
  IsotropicMergingInCommonAndTaskSubspace,
14
+ IsotropicMergingInCommonSubspace,
15
15
  )
@@ -6,11 +6,11 @@ from fusion_bench import BaseAlgorithm, BaseModelPool
6
6
  from fusion_bench.mixins import LightningFabricMixin
7
7
  from fusion_bench.utils.state_dict_arithmetic import (
8
8
  state_dict_add,
9
- state_dict_sub,
10
9
  state_dict_mul,
10
+ state_dict_sub,
11
11
  )
12
12
 
13
- from .iso_utils import iso_c, iso_cts, check_parameterNamesMatch
13
+ from .iso_utils import check_parameterNamesMatch, iso_c, iso_cts
14
14
 
15
15
 
16
16
  class IsotropicMergingInCommonSubspace(BaseAlgorithm, LightningFabricMixin):
@@ -9,19 +9,19 @@ fusion_bench \
9
9
  ```
10
10
  """
11
11
 
12
- from typing import List, Optional, Union, Iterable
12
+ from typing import Iterable, List, Optional, Union
13
13
 
14
14
  import torch
15
- from torch import Tensor, nn
16
15
  from omegaconf import ListConfig
16
+ from torch import Tensor, nn
17
17
 
18
18
  from fusion_bench import BaseAlgorithm
19
19
  from fusion_bench.mixins import LightningFabricMixin
20
20
  from fusion_bench.utils import timeit_context
21
21
  from fusion_bench.utils.state_dict_arithmetic import (
22
22
  state_dict_add,
23
- state_dict_sub,
24
23
  state_dict_mul,
24
+ state_dict_sub,
25
25
  )
26
26
  from fusion_bench.utils.type import StateDictType
27
27
 
@@ -0,0 +1,531 @@
1
+ import copy
2
+ import functools
3
+ import logging
4
+ from collections import OrderedDict
5
+ from copy import deepcopy
6
+ from typing import Any, Callable, Dict, Iterator, List, Optional # noqa: F401
7
+
8
+ import lightning as L
9
+ import torch
10
+ from torch import Tensor, nn
11
+ from torch.func import functional_call
12
+
13
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_add
14
+ from fusion_bench.utils.type import StateDictType
15
+
16
# Names exported by `from ... import *`. The module-level helpers
# `merge_weights`, `merge_and_unload` and `fix_other_parts` are not listed here.
__all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]

# Module-level logger.
log = logging.getLogger(__name__)
19
+
20
+
21
def del_attr(obj, names: List[str]):
    """
    Delete the attribute at the end of a nested attribute path.

    Every name except the last is resolved with ``getattr``; the last one is
    removed from the object reached that way.

    Args:
        obj (object): Root object to start from.
        names (list): Attribute path, e.g. ``["encoder", "layer", "weight"]``.
    """
    owner = obj
    for attr_name in names[:-1]:
        owner = getattr(owner, attr_name)
    delattr(owner, names[-1])
33
+
34
+
35
def set_attr(obj, names: List[str], val):
    """
    Assign ``val`` to the attribute at the end of a nested attribute path.

    Every name except the last is resolved with ``getattr``; the last one is
    set on the object reached that way.

    Args:
        obj (object): Root object to start from.
        names (list): Attribute path, e.g. ``["encoder", "layer", "weight"]``.
        val (object): Value to assign.
    """
    owner = obj
    for attr_name in names[:-1]:
        owner = getattr(owner, attr_name)
    setattr(owner, names[-1], val)
48
+
49
+
50
def get_attr(obj, names: List[str]):
    """
    Resolve a nested attribute path on ``obj``.

    Args:
        obj (object): Root object to start from.
        names (list): Attribute path, e.g. ``["encoder", "layer", "weight"]``.

    Returns:
        object: The value found at the end of the path.
    """
    owner = obj
    for attr_name in names[:-1]:
        owner = getattr(owner, attr_name)
    return getattr(owner, names[-1])
65
+
66
+
67
def get_layer_wise_weights(
    num_models: int,
    num_layers: int,
    init_values: Optional[float] = None,
    dtype: torch.dtype = torch.float32,
):
    """
    Create a constant matrix of layer-wise merging weights.

    Args:
        num_models (int): Number of models being fused (rows). Must be >= 1.
        num_layers (int): Number of layers per model (columns). Must be >= 1.
        init_values (float, optional): Value used for every entry. When omitted,
            ``1.0 / num_models`` is used so each layer's weights sum to 1.
        dtype (torch.dtype): dtype of the result; should match the model dtype.

    Returns:
        Tensor: Tensor of shape ``(num_models, num_layers)`` filled with the chosen value.
    """
    assert num_models >= 1, f"num_models must be >= 1, got {num_models}"
    assert num_layers >= 1, f"num_layers must be >= 1, got {num_layers}"
    fill_value = 1.0 / num_models if init_values is None else init_values
    shape = (num_models, num_layers)
    return torch.full(shape, fill_value, dtype=dtype)
90
+
91
+
92
def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]):
    """
    Compute the weighted sum of a list of tensors.

    Args:
        layer_wise_weight (Tensor): One coefficient per entry of ``tensors``,
            shape ``(num_models,)``.
        tensors (List[Tensor]): Tensors to combine (one per model). Each is moved
            to ``layer_wise_weight``'s device before weighting.

    Returns:
        Tensor: ``sum_i layer_wise_weight[i] * tensors[i]``, with the (broadcast)
        shape of the input tensors.
    """
    assert len(layer_wise_weight) == len(
        tensors
    ), f"layer_wise_weight.shape={layer_wise_weight.shape}, len(tensors)={len(tensors)}"
    target_device = layer_wise_weight.device
    weighted_terms = (
        coefficient * tensor.to(target_device)
        for coefficient, tensor in zip(layer_wise_weight, tensors)
    )
    return sum(weighted_terms)
110
+
111
+
112
def fuse_weights(
    layer_wise_weight: Tensor, state_dicts: List[StateDictType]
) -> StateDictType:
    """
    Fuse several state dicts parameter-by-parameter with layer-wise coefficients.

    Column ``i`` of ``layer_wise_weight`` holds the per-model coefficients for the
    ``i``-th key of ``state_dicts[0]`` (in that dict's iteration order).

    Args:
        layer_wise_weight (Tensor): Coefficients of shape ``(num_models, num_layers)``.
        state_dicts (List[StateDict]): One state dict per model; all must contain
            the keys of the first one.

    Returns:
        A dict mapping each key of ``state_dicts[0]`` to its fused tensor.
    """
    reference_keys = list(state_dicts[0].keys())
    num_models = len(state_dicts)
    num_layers = len(reference_keys)
    assert layer_wise_weight.shape == (
        num_models,
        num_layers,
    ), f"layer_wise_weight.shape={layer_wise_weight.shape}, expected (num_models, num_layers): ({num_models}, {num_layers})"

    fused: StateDictType = {}
    for column, key in enumerate(reference_keys):
        per_model = [sd[key] for sd in state_dicts]
        fused[key] = _fuse_weights(layer_wise_weight[:, column], per_model)
    return fused
137
+
138
+
139
class LayerWiseMergedModel(nn.Module):
    """
    Layer-wise merged model used by the DOGE-TA method.

    At construction time this wrapper

    1. computes task vectors (finetuned minus pretrained) for every pretrained
       parameter whose name contains ``"encoder"`` and ``"weight"`` but not
       ``"layer_norm"``;
    2. builds, for every such layer, a projection matrix from the leading left
       singular vectors of all task vectors;
    3. optimizes an additive correction (``self.delta``) per task and layer with
       Adam, subtracting the projected component from each gradient before every
       step;
    4. adds the optimized deltas to the task vectors and zeroes all entries outside
       each task vector's top-magnitude entries (see :meth:`topk_values_mask`),
       storing the result in ``self.task_vectors``.

    Merging then computes ``pretrained + sum_i w_i * task_vector_i`` per layer,
    where the coefficients ``w`` come from the learnable ``self.merge_weight``.
    """

    # Cached result of the last `merge_weights` call; None until weights are merged.
    _merged_state_dict: StateDictType = None

    def __init__(
        self,
        layer_wise_weight: Tensor,
        pretrained_model: nn.Module,
        finetuned_models: List[nn.Module],
        clamp_weights: bool = True,
        tie_weights: bool = False,
        strict: bool = True,
        sparsity_ratio: Optional[float] = None,
        normalized_merging_weights: bool = False,
    ):
        R"""
        Wrap a pretrained model and a list of finetuned models, and merge the
        weights of the finetuned models into the pretrained model using layer-wise
        fusion.

        Args:
            layer_wise_weight (Tensor): A tensor of shape (num_models, num_layers) with the
                initial layer-wise merging coefficients. Only the first
                ``len(filtered_keys)`` columns are used.
            pretrained_model (nn.Module): The pretrained model to merge the weights into.
            finetuned_models (List[nn.Module]): Finetuned models with the same architecture
                as the pretrained model. They are used to compute the task vectors.
            clamp_weights (bool, optional): If True, the layer-wise weights are clamped to
                [0, 1] before merging. Defaults to True.
            tie_weights (bool, optional): Passed as ``tie_weights`` to ``functional_call``.
                Defaults to False.
            strict (bool, optional): Passed as ``strict`` to ``functional_call``.
                Defaults to True.
            sparsity_ratio (float, optional): If provided, every 2-D entry of the final
                task vectors is magnitude-pruned and stored as a sparse tensor. A high
                sparsity level can reduce memory usage during merging.
            normalized_merging_weights (bool, optional): If True, a softmax over models is
                applied to the layer-wise weights of each layer, so that the weights
                across models sum to 1. Defaults to False.
        """
        super().__init__()
        # Lightning Fabric is only used when a GPU is available. The attribute is
        # always initialised so that the CPU path does not raise AttributeError.
        self._fabric = None
        if torch.cuda.is_available():
            self._fabric = L.Fabric(devices=1)
            self._fabric.launch()
        self.clamp_weights = clamp_weights
        self.tie_weights = tie_weights
        self.strict = strict
        self.sparsity_ratio = sparsity_ratio
        # NOTE: misspelled attribute name kept for backward compatibility.
        self.nromalized_merging_weights = normalized_merging_weights

        pretrained_sd = pretrained_model.state_dict(keep_vars=True)
        # Only encoder weights (excluding layer norms) are merged layer-wise.
        filtered_keys = [
            k
            for k in pretrained_sd.keys()
            if ("encoder" in k and "layer_norm" not in k and "weight" in k)
        ]
        # One learnable coefficient per (model, filtered layer).
        self.merge_weight = nn.Parameter(
            layer_wise_weight[:, : len(filtered_keys)], requires_grad=True
        )
        for m in finetuned_models:
            m.requires_grad_(False)
        self.pretrained_model = pretrained_model.requires_grad_(False)

        task_vectors = []
        for model in finetuned_models:
            model_sd = model.state_dict(keep_vars=True)
            filtered_task_vector = {
                k: (model_sd[k] - pretrained_sd[k]) for k in filtered_keys
            }
            if self._fabric is not None:
                filtered_task_vector = self._fabric.to_device(filtered_task_vector)
            task_vectors.append(filtered_task_vector)

        projection = self._compute_projections(task_vectors)

        # Per-task, per-layer corrections, initialised to zero and optimized below.
        self.delta = [
            {
                k: torch.zeros_like(v).clone().requires_grad_()
                for k, v in task_vector.items()
            }
            for task_vector in task_vectors
        ]
        if self._fabric is not None:
            self.delta = self._fabric.to_device(self.delta)
        self.lamdas = self.compute_layer_lamdas(task_vectors)

        self._optimize_deltas(task_vectors, projection)
        # The projections are only needed during optimization; release them early.
        del projection

        # Move deltas, scaling factors and task vectors to CPU before flattening.
        self.delta = [
            {key: param.detach().cpu() for key, param in delta.items()}
            for delta in self.delta
        ]
        self.lamdas = [
            {key: param.cpu() for key, param in lamdas.items()}
            for lamdas in self.lamdas
        ]
        task_vectors = [
            {k: v.cpu() for k, v in task_vector.items()} for task_vector in task_vectors
        ]
        self.task_vectors = self._build_task_vectors(task_vectors)
        if self._fabric is not None:
            self.task_vectors = self._fabric.to_device(self.task_vectors)

        # If `sparsity_ratio` is given, prune the task vectors.
        if sparsity_ratio is not None:
            self._prune_task_vectors(sparsity_ratio)

    def _compute_projections(
        self, task_vectors: List[Dict[str, Tensor]]
    ) -> Dict[str, Tensor]:
        """
        Build one projection matrix per layer from the task vectors' SVDs.

        For each layer, the leading ``rank // num_tasks`` left singular vectors of
        every task vector are placed side by side. The projection onto the span of
        the first ``rank // num_tasks`` left singular vectors of that concatenation
        is returned as ``basis @ basis.T``.
        """
        projection = {}
        num_tasks = len(task_vectors)
        for layer_name in task_vectors[0].keys():
            sum_u = None
            reduced_index_s = 0
            for i, vector in enumerate(task_vectors):
                layer_vector = vector[layer_name]
                u, s, _ = torch.linalg.svd(layer_vector, full_matrices=False)
                if i == 0:
                    log.info("Computed SVD for %s...", layer_name)
                    sum_u = torch.zeros_like(u, device=layer_vector.device)
                reduced_index_s = s.shape[0] // num_tasks
                # Place the first `reduced_index_s` left singular vectors of task i
                # into the i-th column slot of `sum_u`.
                start, end = i * reduced_index_s, (i + 1) * reduced_index_s
                sum_u[:, start:end] = u[:, :reduced_index_s]
            u_u, _, _ = torch.linalg.svd(sum_u, full_matrices=False)
            basis = u_u[:, :reduced_index_s]
            projection[layer_name] = basis @ basis.T
        return projection

    def _optimize_deltas(
        self,
        task_vectors: List[Dict[str, Tensor]],
        projection: Dict[str, Tensor],
        num_steps: int = 400,
        lr: float = 1e-4,
    ) -> None:
        """
        Optimize ``self.delta`` layer by layer with projected gradients.

        For each layer, an Adam optimizer over that layer's deltas (one per task)
        minimizes :meth:`taskvector_loss`. Before every step, ``projection @ grad``
        is subtracted from each delta's gradient. All delta gradients are cleared
        after each layer is finished.
        """
        for layer_name in task_vectors[0].keys():
            log.info("Optimizing deltas for %s", layer_name)
            layer_params = [delta[layer_name] for delta in self.delta]
            optimizer = torch.optim.Adam(layer_params, lr=lr)
            layer_vectors = torch.stack([vec[layer_name] for vec in task_vectors])
            layer_lamdas = torch.stack([lamdas[layer_name] for lamdas in self.lamdas])
            layer_projection = projection[layer_name]
            for step in range(num_steps):
                optimizer.zero_grad()
                layer_delta = torch.stack(layer_params)
                loss = self.taskvector_loss(layer_vectors, layer_delta, layer_lamdas)
                # `loss.item()` forces a device sync, so only call it when it is logged.
                if log.isEnabledFor(logging.DEBUG):
                    log.debug(
                        "Epoch: %d, Layer: %s, Loss: %s", step, layer_name, loss.item()
                    )
                if self._fabric is not None:
                    self._fabric.backward(loss)
                else:
                    loss.backward()
                # Remove the gradient component that lies in the projected subspace.
                for param in layer_params:
                    grad_proj = layer_projection @ param.grad.detach()
                    param.grad.data.sub_(grad_proj)
                optimizer.step()
            for delta in self.delta:
                for param in delta.values():
                    param.grad = None

    def _build_task_vectors(
        self, task_vectors: List[Dict[str, Tensor]], topk: float = 30
    ) -> List[Dict[str, Tensor]]:
        """
        Add each task's optimized delta to its task vector, then keep only the
        entries selected by ``topk_values_mask(task_vector, K=topk)``.

        Returns:
            One state dict per task, keyed in sorted order (see :meth:`vector_to_state_dict`).
        """
        flat_deltas = [self.state_dict_to_vector(delta) for delta in self.delta]
        masked_task_vectors = []
        for task_vector, flat_delta in zip(task_vectors, flat_deltas):
            flat_vector = self.state_dict_to_vector(task_vector)
            vector_mask = self.topk_values_mask(flat_vector, K=topk)
            masked_task_vectors.append(
                self.vector_to_state_dict(
                    (flat_vector + flat_delta) * vector_mask, self.delta[0]
                )
            )
        return masked_task_vectors

    def _prune_task_vectors(self, sparsity_ratio: float) -> None:
        """
        Magnitude-prune every 2-D tensor of ``self.task_vectors`` in place and
        store it as a sparse tensor.
        """
        from fusion_bench.method.pruning.prune_utils import (
            unstructured_magnitude_prune_,
        )

        for task_vector in self.task_vectors:
            # Iterate over a snapshot because entries are replaced inside the loop.
            for name, param in list(task_vector.items()):
                if param.dim() != 2:
                    continue
                log.info("pruning %s", name)
                pruned_param = unstructured_magnitude_prune_(
                    param.data.clone(), torch.abs, sparsity_ratio=sparsity_ratio
                )
                task_vector[name] = pruned_param.to_sparse()

    def topk_values_mask(self, M, K):
        """
        Return a boolean mask keeping the largest-magnitude entries of each row of ``M``.

        Args:
            M (Tensor): A 1-D or 2-D tensor. A 1-D tensor is treated as a single row.
            K (float): Fraction of entries to keep per row. Values greater than 1 are
                interpreted as percentages (e.g. ``30`` -> ``0.3``).

        Returns:
            Tensor: Boolean mask that is True where ``|M|`` is at least the
            ``(d - int(d * K))``-th smallest magnitude of its row. Ties at the
            threshold are kept. For a 1-D input the mask is squeezed back to 1-D.
        """
        if K > 1:
            K /= 100

        original_shape = M.shape
        if M.dim() == 1:
            M = M.unsqueeze(0)

        n, d = M.shape
        k = int(d * K)
        k = d - k  # Keep top k elements instead of bottom k elements

        # Find the k-th smallest element by magnitude for each row
        kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
        # Create a mask tensor with True for the top k elements in each row
        mask = M.abs() >= kth_values
        final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask

        return final_mask

    def state_dict_to_vector(self, state_dict, remove_keys=None):
        """
        Flatten a state dict into a single vector, in sorted-key order.

        Args:
            state_dict (dict): The state dictionary to convert. It is not modified.
            remove_keys (list, optional): Keys to leave out of the vector.

        Returns:
            Tensor: Concatenation of the flattened tensors, ordered by key.
        """
        excluded = set(remove_keys) if remove_keys is not None else set()
        # No copy is needed: `parameters_to_vector` concatenates into a new tensor.
        kept_items = sorted(
            ((key, value) for key, value in state_dict.items() if key not in excluded),
            key=lambda item: item[0],
        )
        return nn.utils.parameters_to_vector(
            [value.reshape(-1) for _, value in kept_items]
        )

    def vector_to_state_dict(self, vector, state_dict, remove_keys=None):
        """
        Unflatten a vector into a new state dict shaped like ``state_dict``.

        Args:
            vector (Tensor): The flat vector, in the order used by :meth:`state_dict_to_vector`.
            state_dict (dict): Reference dict giving keys and shapes. It is not modified.
            remove_keys (list, optional): Keys excluded from the vector.

        Returns:
            OrderedDict: A state dict with sorted keys whose tensors hold slices of ``vector``.
        """
        if remove_keys is None:
            remove_keys = []
        # Deep copy because `vector_to_parameters` writes into the reference tensors.
        reference_dict = copy.deepcopy(state_dict)
        for key in remove_keys:
            if key in reference_dict:
                del reference_dict[key]
        sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

        nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

        # Add back the encoder and decoder embedding weights.
        if "transformer.shared.weight" in sorted_reference_dict:
            for key in remove_keys:
                sorted_reference_dict[key] = sorted_reference_dict[
                    "transformer.shared.weight"
                ]
        return sorted_reference_dict

    def taskvector_loss(self, layer_vectors, layer_delta, layer_lamdas) -> torch.Tensor:
        """
        Compute the delta-optimization objective for a single layer.

        Args:
            layer_vectors (Tensor): Stacked task vectors of the layer, shape (num_tasks, ...).
            layer_delta (Tensor): Stacked deltas of the layer, same shape as ``layer_vectors``.
            layer_lamdas (Tensor): Per-task scaling factors, shape (num_tasks,).

        Returns:
            Tensor: Scalar loss summed over all task vectors ``v_j``.
        """
        total_loss = 0.0

        layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
        sum_over_num_vectors = layer_vectors_scale.sum(dim=0)

        # NOTE(review): `layer_delta.unsqueeze(0)` has shape (1, T, ...) and broadcasts
        # against `layer_lamdas.view(-1, 1, 1)` along the task axis. After `.sum(dim=0)`,
        # `sum_over_delta` therefore keeps a leading task axis (one scaled delta per task)
        # instead of being summed over tasks like `sum_over_num_vectors`. Confirm this is
        # intended before changing it, because changing it alters the optimized objective.
        layer_delta_scale = layer_delta.unsqueeze(0) * layer_lamdas.view(-1, 1, 1)
        sum_over_delta = layer_delta_scale.sum(dim=0)

        # Iterate through each vector and calculate the loss one by one
        for v_j in layer_vectors:
            part1 = -v_j * sum_over_num_vectors
            part2 = -v_j * sum_over_delta
            part3 = v_j * v_j

            expression = part1 + part2 + part3
            layer_loss = expression.sum(dim=1).pow(2).sum()

            # Cumulative total loss
            total_loss += layer_loss
        return total_loss

    def compute_layer_lamdas(
        self, vectors: List[StateDictType]
    ) -> List[Dict[str, Tensor]]:
        """
        Compute a per-task, per-layer scaling factor ``0.07 / ||v||`` (Frobenius norm).

        Returns:
            One dict per task mapping layer name to a 0-d tensor.
        """
        lamdas = []
        for vec in vectors:
            tmp = {}
            for layer_name in vec.keys():
                norm_vec = torch.norm(vec[layer_name])
                tmp[layer_name] = 0.07 / norm_vec
            lamdas.append(tmp)
        return lamdas

    @property
    def forward_model(self):
        """Return ``functional_call`` bound to the pretrained model and the merged weights."""
        return functools.partial(
            functional_call,
            self.pretrained_model,
            self._merged_state_dict,
            tie_weights=self.tie_weights,
            strict=self.strict,
        )

    def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        """
        Merge the weights, load them into the pretrained model and return that model.

        Args:
            task_vector_mask: Forwarded to :meth:`merge_weights`.
        """
        self.merge_weights(task_vector_mask=task_vector_mask)
        self.pretrained_model.load_state_dict(self._merged_state_dict)
        return self.pretrained_model

    def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
        """
        Merge the task vectors into the pretrained weights and cache the result.

        Call this after each update step of ``merge_weight``.

        Args:
            task_vector_mask: Accepted for interface compatibility. This
                implementation does not use it.

        Returns:
            The merged state dict, which is also stored in ``self._merged_state_dict``.
        """
        if self.clamp_weights:
            layer_wise_weight = self.merge_weight.clamp(0, 1)
        else:
            layer_wise_weight = self.merge_weight
        if self.nromalized_merging_weights:
            # Normalize the weights for each layer, so that the sum of weights
            # across models for each layer is 1.
            layer_wise_weight = layer_wise_weight.softmax(dim=0)

        state_dict = self.pretrained_model.state_dict(keep_vars=True)
        # shape of layer_wise_weight: (num_models, num_layers)
        # NOTE(review): column i is paired with the i-th entry of each task vector.
        # Those entries are in sorted-key order (built by `vector_to_state_dict`),
        # not in the `filtered_keys` order used in `__init__`.
        for weight, task_vector in zip(layer_wise_weight, self.task_vectors):
            for w, (name, param) in zip(weight, task_vector.items()):
                state_dict[name] = state_dict[name] + param * w
        self._merged_state_dict = state_dict

        return state_dict

    def forward(self, *args, **kwargs):
        """Run the pretrained model with the merged weights, merging first if needed."""
        if self._merged_state_dict is None:
            self.merge_weights()
        return self.forward_model(args=args, kwargs=kwargs)
478
+
479
+
480
def merge_weights(module: nn.Module):
    """
    Recompute merged weights for every `LayerWiseMergedModel` in a module tree.

    The search stops descending at the first `LayerWiseMergedModel` on each
    branch; that instance's ``merge_weights`` method is called instead.

    Args:
        module (nn.Module): Root of the module tree to process.
    """
    if not isinstance(module, LayerWiseMergedModel):
        for child in module.children():
            merge_weights(child)
        return
    module.merge_weights()
493
+
494
+
495
def merge_and_unload(module: nn.Module):
    """
    Replace every `LayerWiseMergedModel` in a module tree with its merged model.

    If ``module`` itself is a `LayerWiseMergedModel`, its merged pretrained model
    is returned. Otherwise each direct child that is a `LayerWiseMergedModel` is
    swapped in place for its merged model, and other children are processed
    recursively.

    Args:
        module (nn.Module): Root of the module tree to process.

    Returns:
        nn.Module: The module with merged weights.
    """
    if isinstance(module, LayerWiseMergedModel):
        return module.merge_and_unload()
    for child_name, child in module.named_children():
        is_wrapper = isinstance(child, LayerWiseMergedModel)
        replacement = merge_and_unload(child)
        if is_wrapper:
            setattr(module, child_name, replacement)
    return module
514
+
515
+
516
def fix_other_parts(module: nn.Module):
    """
    Freeze all parameters of ``module`` except the merge weights.

    Every parameter gets ``requires_grad=False``. Then the ``merge_weight`` of
    each `LayerWiseMergedModel` among ``module.modules()``, including ``module``
    itself, is re-enabled.

    Args:
        module (nn.Module): Root module to update in place.

    Returns:
        nn.Module: ``module`` itself.
    """
    module.requires_grad_(False)
    merged_models = [
        candidate
        for candidate in module.modules()
        if isinstance(candidate, LayerWiseMergedModel)
    ]
    for merged in merged_models:
        merged.merge_weight.requires_grad_(True)
    return module
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: fusion_bench
3
- Version: 0.2.10
3
+ Version: 0.2.11
4
4
  Summary: A Comprehensive Benchmark of Deep Model Fusion
5
5
  Author-email: Anke Tang <tang.anke@foxmail.com>
6
6
  License: MIT License