fusion-bench 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. fusion_bench/compat/method/__init__.py +5 -0
  2. fusion_bench/dataset/fer2013.py +0 -1
  3. fusion_bench/method/__init__.py +10 -0
  4. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  5. fusion_bench/method/concrete_subspace/__init__.py +8 -0
  6. fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
  7. fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
  8. fusion_bench/method/doge_ta/__init__.py +2 -0
  9. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +46 -0
  10. fusion_bench/method/doge_ta/doge_ta.py +364 -0
  11. fusion_bench/method/doge_ta/layer_wise_adamerging.py +250 -0
  12. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  13. fusion_bench/method/isotropic_merging/iso.py +2 -2
  14. fusion_bench/method/opcm/opcm.py +93 -84
  15. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  16. fusion_bench/method/opcm/ties_merging.py +71 -52
  17. fusion_bench/method/task_singular_vector/TSVM.py +3 -3
  18. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  19. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +416 -0
  20. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/METADATA +15 -2
  21. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/RECORD +32 -19
  22. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/WHEEL +1 -1
  23. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
  24. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
  25. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
  26. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
  27. fusion_bench_config/method/doge_ta/doge_ta.yaml +4 -0
  28. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -8
  29. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +68 -0
  30. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/entry_points.txt +0 -0
  31. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info/licenses}/LICENSE +0 -0
  32. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2 @@
1
+ # flake8: noqa F401
2
+ from .doge_ta import DOGE_TA_Algorithm
@@ -0,0 +1,46 @@
1
+ """
2
+ Example Usage:
3
+
4
+ ```bash
5
+ fusion_bench \
6
+ method=adamerging \
7
+ method.name=clip_layer_wise_adamerging \
8
+ method.save_merging_weights=merging_weights.pt \
9
+ modelpool=clip-vit-base-patch32_TA8 \
10
+ taskpool=clip-vit-classification_TA8 \
11
+ fabric.loggers.root_dir=outputs/logs/ViT-B-32 \
12
+ fabric.loggers.name=clip_layer_wise_adamerging_adamerging
13
+ ```
14
+ """
15
+
16
+ import functools
17
+ import logging
18
+
19
+ from torch.utils.data import DataLoader
20
+
21
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
22
+ from fusion_bench.mixins import CLIPClassificationMixin
23
+ from fusion_bench.utils.data import InfiniteDataLoader
24
+
25
+ from .layer_wise_adamerging import LayerWiseAdaMergingAlgorithm
26
+
27
+ log = logging.getLogger(__name__)
28
+
29
+
30
class CLIPLayerWiseAdaMergingAlgorithm(
    CLIPClassificationMixin,
    LayerWiseAdaMergingAlgorithm,
):
    """
    Layer-wise AdaMerging specialized for CLIP vision models.

    Combines the generic layer-wise AdaMerging procedure with the CLIP
    classification utilities (zero-shot heads, shuffled test loaders)
    provided by ``CLIPClassificationMixin``.
    """

    def on_test_time_adaptation_start(self):
        """
        Here we load the CLIP processor and construct the zero-shot
        classification head for each task.
        """
        self.setup_zero_shot_classification_head()

    def get_shuffled_test_loader_iter(self, task: str):
        """
        Return a memoized iterator over the shuffled test set of ``task``.

        The iterator is cached in a per-instance dict instead of decorating
        the method with ``@functools.cache``: a class-level cache on a bound
        method keys on ``self`` and keeps every instance (and its data
        loaders) alive for the lifetime of the class (flake8-bugbear B019).

        Args:
            task (str): The name of the task.

        Returns:
            An iterator yielding shuffled test batches for ``task``.
        """
        # Lazily create the per-instance cache: one loader iterator per task.
        if not hasattr(self, "_shuffled_test_loader_iters"):
            self._shuffled_test_loader_iters = {}
        if task not in self._shuffled_test_loader_iters:
            self._shuffled_test_loader_iters[task] = (
                super().get_shuffled_test_loader_iter(
                    task,
                    batch_size=self.config.batch_size,
                    num_workers=self.config.num_workers,
                )
            )
        return self._shuffled_test_loader_iters[task]
@@ -0,0 +1,364 @@
1
+ R"""
2
+ This script contains the general implementation of Modeling Multi-Task Model Merging as Adaptive Projective Gradient Descent.
3
+
4
+ https://arxiv.org/abs/2501.01230
5
+
6
+ Example Usage:
7
+
8
+ ```bash
9
+ fusion_bench \
10
+ method=doge_ta/doge_ta \
11
+ modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
12
+ taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
13
+
14
+ fusion_bench \
15
+ method=adamerging \
16
+ method.name=clip_layer_wise_adamerging_doge_ta \
17
+ modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8 \
18
+ taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
19
+ ```
20
+ """
21
+
22
+ import copy
23
+ import logging
24
+ import time
25
+ from collections import OrderedDict
26
+ from copy import deepcopy
27
+ from functools import reduce
28
+ from typing import Dict, List, Mapping, TypeVar, Union # noqa: F401
29
+
30
+ import lightning as L
31
+ import torch
32
+ from torch import nn
33
+
34
+ from fusion_bench.method.base_algorithm import BaseAlgorithm
35
+ from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
36
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
37
+ from fusion_bench.modelpool import BaseModelPool
38
+ from fusion_bench.utils.state_dict_arithmetic import (
39
+ state_dict_add,
40
+ state_dict_mul,
41
+ state_dict_sub,
42
+ )
43
+ from fusion_bench.utils.type import StateDictType
44
+
45
+ log = logging.getLogger(__name__)
46
+
47
+
48
class DOGE_TA_Algorithm(
    BaseAlgorithm,
    SimpleProfilerMixin,
    LightningFabricMixin,
):
    """
    Task Arithmetic Algorithm for model fusion with learnable delta (DOGE-TA).

    This class extends the Task Arithmetic method to include a learnable delta
    for task vectors, optimized to maximize cosine similarity among the task vectors.

    Reference: "Modeling Multi-Task Model Merging as Adaptive Projective
    Gradient Descent", https://arxiv.org/abs/2501.01230

    Attributes:
        subspace: Divisor controlling the size of the shared subspace kept from
            the SVD of the stacked singular directions (larger -> smaller subspace).
        K: Top-K percentage (if > 1) or fraction of task-vector entries kept by magnitude.
        lamda: Coefficient used to derive per-layer scaling factors (lamda / ||layer||).
        delta (StateDictType): A learnable parameter to adjust task vectors,
            initialized lazily as zeros in :meth:`optimize_delta`.
    """

    # Maps constructor arguments to config keys (see BaseAlgorithm).
    _config_mapping = BaseAlgorithm._config_mapping | {
        "subspace": "subspace",
        "K": "K",
        "lamda": "lamda",
    }

    def __init__(self, subspace, K, lamda):
        """
        Args:
            subspace: Subspace-size divisor for the shared gradient projection.
            K: Top-K percentage/fraction of entries to keep per task vector.
            lamda: Normalization coefficient for the layer-wise scaling factors.
        """
        self.delta = None  # Initialize delta as None; will be set during run
        self.subspace = subspace
        self.K = K
        self.lamda = lamda
        super().__init__()

    @property
    def device(self) -> torch.device:
        # Device is delegated to the Lightning Fabric instance.
        return self.fabric.device

    @torch.no_grad()
    def compute_task_vectors(
        self, modelpool: BaseModelPool, pretrained_model: nn.Module
    ) -> List[StateDictType]:
        """
        Computes task vectors for each model in the model pool relative to the pretrained model.

        Only 2-D encoder weight matrices are kept (keys containing "encoder" and
        "weight" but not "layer_norm"), since the later SVD-based projection
        operates on matrices.

        Args:
            modelpool (BaseModelPool): Pool providing the fine-tuned models.
            pretrained_model (nn.Module): The common pretrained base model.

        Returns:
            List[StateDictType]: One filtered ``{key: finetuned - pretrained}``
            dict per model in the pool.
        """
        task_vectors = []
        pretrained_sd = pretrained_model.state_dict(keep_vars=True)
        filtered_keys = [
            k
            for k in pretrained_sd.keys()
            if ("encoder" in k and "layer_norm" not in k and "weight" in k)
        ]  # Flan T5: "layer_norm" not in k and ("q.weight" in k or "v.weight" in k)

        for model_name in modelpool.model_names:
            model = modelpool.load_model(model_name)
            model_sd = model.state_dict(keep_vars=True)

            filtered_task_vector = {
                k: (model_sd[k] - pretrained_sd[k]) for k in filtered_keys
            }
            task_vectors.append(filtered_task_vector)

        return task_vectors

    def taskvector_loss(
        self,
        layer_vectors: torch.Tensor,
        layer_delta: torch.Tensor,
        layer_lamdas: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the loss based on delta and task vectors for a specific layer.

        Args:
            layer_vectors: Stacked per-task weight deltas for one layer,
                shape ``(num_tasks, rows, cols)``.
            layer_delta: The learnable delta for this layer, shape ``(rows, cols)``.
            layer_lamdas: Per-task scaling factors for this layer, shape ``(num_tasks,)``.

        Returns:
            torch.Tensor: Scalar loss accumulated over all task vectors.
        """
        total_loss = 0.0

        # Scale each task vector by its lamda and sum over tasks.
        layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
        sum_over_num_vectors = layer_vectors_scale.sum(dim=0)

        # The shared delta, broadcast to each task and scaled the same way.
        layer_delta_scale = layer_delta.unsqueeze(0) * layer_lamdas.view(-1, 1, 1)
        sum_over_delta = layer_delta_scale.sum(dim=0)

        # Iterate through each vector and calculate the loss one by one
        for v_j in layer_vectors:
            part1 = -v_j * sum_over_num_vectors
            part2 = -v_j * sum_over_delta
            part3 = v_j * v_j

            expression = part1 + part2 + part3
            # Row-wise sum, squared, then summed: penalizes misalignment of
            # v_j with the (scaled) merged direction.
            layer_loss = expression.sum(dim=1).pow(2).sum()

            # Cumulative total loss
            total_loss += layer_loss
        return total_loss

    @torch.enable_grad()
    def optimize_delta(self, task_vectors: List[StateDictType]) -> None:
        """
        Optimizes the delta based on the loss of task vectors.

        The delta is optimized layer-by-layer with Adam (lr=1e-4, 400 steps
        per layer — both hard-coded). After each backward pass the gradient
        component lying in the shared subspace (``self.projection``) is
        subtracted, so delta only moves orthogonally to that subspace.

        Side effects: initializes ``self.delta`` on first call; reads
        ``self.lamdas`` and ``self.projection`` (set up in :meth:`run`).
        """
        if self.delta is None:
            # Zero-initialized learnable delta, one parameter per layer.
            self.delta = {
                k: nn.Parameter(torch.zeros_like(v, device=self.device).detach())
                for k, v in task_vectors[0].items()
            }

        optimizer = torch.optim.Adam(self.delta.values(), lr=1e-4)
        # NOTE(review): torch.cuda.memory_allocated assumes a CUDA runtime is
        # available — confirm behavior on CPU-only setups.
        initial_mem = torch.cuda.memory_allocated()
        start_time = time.time()
        for layer_name in task_vectors[0].keys():
            layer_vectors = torch.stack([vec[layer_name] for vec in task_vectors]).to(
                self.device
            )
            layer_lamdas = torch.stack(
                [lamdas[layer_name] for lamdas in self.lamdas]
            ).to(self.device)
            for _ in range(400):
                optimizer.zero_grad()
                loss = self.taskvector_loss(
                    layer_vectors, self.delta[layer_name], layer_lamdas
                )
                self.fabric.backward(loss)
                # Project the gradient onto the shared subspace, then remove
                # that component so the update stays orthogonal to it.
                grad_proj = (
                    self.projection[layer_name] @ self.delta[layer_name].grad.detach()
                )
                self.delta[layer_name].grad.data = self.delta[
                    layer_name
                ].grad.data.sub_(grad_proj)
                optimizer.step()
                # Drop the gradient tensor to free memory between layers.
                self.delta[layer_name].grad = None
        end_time = time.time()
        print(f"Running time: {end_time - start_time} s")
        final_mem = torch.cuda.memory_allocated()
        print(f"Memory usage: {(final_mem - initial_mem) / (1024 ** 2)} MB")
        print("Optimization completed.")

    @torch.no_grad()
    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
        """
        Runs the Algorithm with learnable delta to fuse models in the given model pool.

        Pipeline: compute task vectors -> build per-layer shared-subspace
        projections from SVDs -> optimize delta -> top-K sparsify and merge
        the adjusted task vectors into the pretrained model.

        Args:
            modelpool (Union[BaseModelPool, Dict[str, nn.Module]]): The pool of models to fuse.

        Returns:
            nn.Module: The pre-trained model with the merged task vectors after optimizing delta.
        """
        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

        log.info("Fusing models using DOGE_TA with learnable delta.")
        with self.profile("load model"):
            pretrained_model = modelpool.load_model("_pretrained_")

        task_vectors = self.compute_task_vectors(modelpool, pretrained_model)

        self.lamdas = self.compute_layer_lamdas(task_vectors)
        self.projection = {}
        for layer_name in task_vectors[0].keys():
            # Collect the leading singular directions of every task's delta
            # for this layer into a single basis matrix (sum_u).
            for i, vector in enumerate(task_vectors):
                layer_vector = vector[layer_name].to(self.device)
                u, s, v = torch.linalg.svd(layer_vector, full_matrices=False)
                if i == 0:
                    print(f"Computed SVD for {layer_name}...")
                    sum_u = torch.zeros_like(u, device=layer_vector.device)
                    sum_s = torch.zeros_like(s, device=layer_vector.device)
                    sum_v = torch.zeros_like(v, device=layer_vector.device)

                # Each task contributes an equal share of singular directions.
                reduced_index_s = int(s.shape[0] / len(task_vectors))

                # select only the first reduced_index_s columns of u and place them
                sum_u[:, i * reduced_index_s : (i + 1) * reduced_index_s] = u[
                    :, :reduced_index_s
                ]
                sum_s[i * reduced_index_s : (i + 1) * reduced_index_s] = s[
                    :reduced_index_s
                ]
                # select only the first reduced_index_s rows of v and place them
                sum_v[i * reduced_index_s : (i + 1) * reduced_index_s, :] = v[
                    :reduced_index_s, :
                ]
            # Re-orthonormalize the collected directions and build the
            # projector U @ U^T onto the shared subspace (size controlled by
            # the `subspace` divisor).
            u_u, s_u, v_u = torch.linalg.svd(sum_u, full_matrices=False)
            layer_proj = torch.matmul(
                u_u[:, : int(s.shape[0] / self.config.subspace)],
                u_u[:, : int(s.shape[0] / self.config.subspace)].T,
            )
            self.projection[layer_name] = layer_proj

        self.optimize_delta(task_vectors)

        # Free the projections and move everything to CPU for the merge phase.
        del self.projection
        self.delta = {key: param.detach().cpu() for key, param in self.delta.items()}
        self.lamdas = [
            {key: param.cpu() for key, param in lamdas.items()}
            for lamdas in self.lamdas
        ]
        task_vectors = [
            {k: v.cpu() for k, v in task_vector.items()} for task_vector in task_vectors
        ]
        flat_vectors = []
        vector_masks = []
        for idx, task_vector in enumerate(task_vectors):
            flat_vector = self.state_dict_to_vector(task_vector)
            vector_mask = self.topk_values_mask(flat_vector, K=self.config.K)
            flat_vectors.append(flat_vector)
            vector_masks.append(vector_mask)
        flat_delta = self.state_dict_to_vector(self.delta)

        # Add the learned delta to each task vector, keep only the top-K
        # entries, and convert back to state dicts.
        adjusted_vectors = [
            self.vector_to_state_dict(
                (flat_vector + flat_delta) * vector_mask, self.delta
            )
            for flat_vector, vector_mask in zip(flat_vectors, vector_masks)
        ]

        # NOTE: task_vectors[0] is overwritten in place to hold the merged
        # (lamda-weighted) task vector.
        for layer_name in adjusted_vectors[0].keys():
            layer_vectors = torch.stack(
                [vec[layer_name] for vec in adjusted_vectors], dim=0
            )
            layer_lamdas = torch.stack(
                [lamdas[layer_name] for lamdas in self.lamdas], dim=0
            )
            layer_vectors_scale = layer_vectors * layer_lamdas.view(-1, 1, 1)
            task_vectors[0][layer_name] = layer_vectors_scale.sum(dim=0)

        # Apply the merged task vector on top of the pretrained weights;
        # untouched keys are carried over unchanged.
        final_state_dict = {}
        pretrained_sd = pretrained_model.state_dict(keep_vars=True)
        for k, v in pretrained_sd.items():
            if k in task_vectors[0]:
                final_state_dict[k] = v + task_vectors[0][k]
            else:
                final_state_dict[k] = v

        pretrained_model.load_state_dict(final_state_dict)

        self.print_profile_summary()
        return pretrained_model

    def compute_lamdas(self, vectors: List[StateDictType]) -> List[torch.Tensor]:
        """
        Compute one global scaling factor per task vector: lamda / ||vector||.

        Args:
            vectors: One state dict of deltas per task.

        Returns:
            List[torch.Tensor]: One scalar factor per task.
        """
        lamdas = []
        for vec in vectors:
            norm_vec = torch.norm(
                torch.cat([param.flatten() for param in vec.values()])
            )
            # norm_vec = sum([torch.norm(param) for param in vec.values()])
            lamdas.append(self.config.lamda / norm_vec)
        print(lamdas)
        return lamdas

    def compute_layer_lamdas(
        self, vectors: List[StateDictType]
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Compute per-layer scaling factors for each task vector:
        lamda / ||layer||.

        Args:
            vectors: One state dict of deltas per task.

        Returns:
            List[Dict[str, torch.Tensor]]: One ``{layer_name: factor}`` dict per task.
        """
        lamdas = []
        for vec in vectors:
            tmp = {}
            for layer_name in vec.keys():
                norm_vec = torch.norm(vec[layer_name])
                tmp[layer_name] = self.config.lamda / norm_vec
            lamdas.append(tmp)
        return lamdas

    def topk_values_mask(self, M: torch.Tensor, K) -> torch.Tensor:
        """
        Build a boolean mask keeping the top-K fraction of entries of ``M``
        by absolute value (row-wise).

        Args:
            M: 1-D or 2-D tensor of values.
            K: Fraction in (0, 1], or a percentage if K > 1 (divided by 100).

        Returns:
            torch.Tensor: Boolean mask with the same shape as the input.
        """
        if K > 1:
            # Interpret K as a percentage.
            K /= 100

        original_shape = M.shape
        if M.dim() == 1:
            M = M.unsqueeze(0)

        n, d = M.shape
        k = int(d * K)
        k = d - k  # Keep top k elements instead of bottom k elements

        # Find the k-th smallest element by magnitude for each row
        kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
        # Create a mask tensor with True for the top k elements in each row
        mask = M.abs() >= kth_values
        final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask

        return final_mask

    def state_dict_to_vector(self, state_dict, remove_keys=[]):
        """
        Convert a state dictionary to a vector, removing specified keys.

        Keys are sorted before flattening so the layout matches
        :meth:`vector_to_state_dict`.

        Args:
            state_dict (dict): The state dictionary to convert.
            remove_keys (list): List of keys to remove from the state dictionary.
                NOTE(review): mutable default argument — harmless here since it
                is never mutated, but a None-sentinel would be safer.

        Returns:
            Tensor: A vector representation of the state dictionary.
        """
        shared_state_dict = copy.deepcopy(state_dict)
        for key in remove_keys:
            if key in shared_state_dict:
                del shared_state_dict[key]
        sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
        return nn.utils.parameters_to_vector(
            [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
        )

    def vector_to_state_dict(self, vector, state_dict, remove_keys=[]):
        """
        Convert a vector back to a state dictionary, removing specified keys.

        Must be called with the same key set (and therefore the same sorted
        order) used by :meth:`state_dict_to_vector`.

        Args:
            vector (Tensor): The vector to convert.
            state_dict (dict): The reference state dictionary.
            remove_keys (list): List of keys to remove from the state dictionary.
                NOTE(review): mutable default argument — never mutated here.

        Returns:
            dict: A state dictionary representation of the vector.
        """
        # create a reference dict to define the order of the vector
        reference_dict = copy.deepcopy(state_dict)
        for key in remove_keys:
            if key in reference_dict:
                del reference_dict[key]
        sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))

        # create a shared state dict using the reference dict
        nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())

        # add back the encoder and decoder embedding weights.
        if "transformer.shared.weight" in sorted_reference_dict:
            for key in remove_keys:
                sorted_reference_dict[key] = sorted_reference_dict[
                    "transformer.shared.weight"
                ]
        return sorted_reference_dict
@@ -0,0 +1,250 @@
1
+ import logging
2
+ import os
3
+ from abc import abstractmethod
4
+ from typing import TYPE_CHECKING, Any, List, Mapping, TypeVar, Union, cast # noqa: F401
5
+
6
+ import torch
7
+ from lightning.fabric.utilities.rank_zero import rank_zero_only
8
+ from omegaconf import DictConfig
9
+ from torch import Tensor, nn
10
+ from torch.utils.data import DataLoader
11
+ from tqdm.autonotebook import tqdm
12
+
13
+ from fusion_bench.compat.method import ModelFusionAlgorithm
14
+ from fusion_bench.compat.modelpool import ModelPool
15
+ from fusion_bench.method.adamerging.entropy_loss import entropy_loss
16
+ from fusion_bench.method.adamerging.utils import get_memory_usage
17
+ from fusion_bench.mixins.lightning_fabric import LightningFabricMixin
18
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
19
+ from fusion_bench.models.wrappers.layer_wise_fusion_doge_ta import (
20
+ LayerWiseMergedModel,
21
+ get_layer_wise_weights,
22
+ )
23
+ from fusion_bench.utils.data import load_tensor_from_file
24
+ from fusion_bench.utils.type import TorchModelType
25
+
26
+ if TYPE_CHECKING:
27
+ from fusion_bench.programs.fabric_fusion_program import FabricModelFusionProgram
28
+
29
+ log = logging.getLogger(__name__)
30
+
31
+
32
class LayerWiseAdaMergingAlgorithm(
    ModelFusionAlgorithm,
    LightningFabricMixin,
    SimpleProfilerMixin,
):
    _program: "FabricModelFusionProgram"
    """The program that this algorithm is running on."""

    # NOTE(review): the string below is not the class docstring (it follows
    # other class-body statements); kept verbatim for fidelity.
    """
    Implements the Layer-Wise AdaMerging Algorithm.

    This class merges the layers of a pretrained model with those of several fine-tuned models.
    The merging is controlled by layer-wise weights, which can be initialized based on a provided configuration or loaded from a file.
    """

    def __init__(self, algorithm_config: DictConfig):
        """
        Initialize the LayerWiseAdaMergingAlgorithm with the given configuration.

        Args:
            algorithm_config (DictConfig): The configuration for the algorithm.
        """
        super().__init__(algorithm_config)

    def construct_layer_wise_merged_model(self, modelpool: "ModelPool"):
        """
        Constructs a wrapped layer-wise merged model from model pool.

        This method creates a new wrapped model by merging the layers of a pretrained model with those of several fine-tuned models.
        The merging is controlled by layer-wise weights, which is a `torch.Tensor` of the shape `(num_models, num_layers)`.
        The merging weights can be initialized based on a provided configuration or loaded from a file.

        Args:
            modelpool (ModelPool): An object containing the pretrained model and fine-tuned models to be merged.

        Returns:
            LayerWiseMergedModel: An instance of the merged model with layer-wise weights applied.
        """
        pretrained_model = modelpool.load_model("_pretrained_")
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        # initialize layer-wise weights using the provided configuration `init_values` or load from file if `weights` is provided
        if self.config.weights is None:
            layer_wise_weight = get_layer_wise_weights(
                num_models=len(modelpool.model_names),
                # One weight per trainable parameter tensor of the pretrained model.
                num_layers=len(
                    tuple(
                        filter(lambda p: p.requires_grad, pretrained_model.parameters())
                    )
                ),
                init_values=self.config.init_values,
            )
        else:
            if isinstance(self.config.weights, str):
                # self.config.weights is a path to a saved tensor
                layer_wise_weight = load_tensor_from_file(self.config.weights)
            else:
                raise ValueError(f"Unsupported weights format: {self.config.weights}")

        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
        )
        print(f"{layer_wise_weight.size()=}, {layer_wise_weight.numel()=}")
        return module

    @rank_zero_only
    def save_merging_weights(self, file_path: str, merging_weights: torch.Tensor):
        """
        Save the merging weights to a file. Runs only on the global-zero rank.

        Args:
            file_path (str): The path to save the merging weights.
            merging_weights (torch.Tensor): The merging weights to save.
        """
        if self.fabric.is_global_zero and self.config.get(
            "save_merging_weights", False
        ):
            if isinstance(file_path, str) and not file_path.startswith(("/", ".")):
                # if the file path is not absolute or relative to current working directory, save it in the log directory
                save_path = os.path.join(self.log_dir, file_path)
            else:
                save_path = file_path
            log.info(f"saving merging weights to {save_path}.")
            if os.path.dirname(save_path):
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save(merging_weights.detach().cpu(), save_path)

    def run(self, modelpool: ModelPool, **kwargs):
        """
        Run the Layer-Wise AdaMerging Algorithm.

        This method constructs the wrapped model and performs test-time adaptation if necessary.
        When pre-computed weights are supplied via ``config.weights``, adaptation is skipped.

        Args:
            modelpool (ModelPool): The model pool containing the pretrained and fine-tuned models.

        Returns:
            LayerWiseMergedModel: The merged model after test-time adaptation.
        """
        log.info("Fusing models using layer-wise adaptive merging.")
        self.modelpool = modelpool
        self.log_hyperparams(self.config)

        with self.profile("construct the wrapped model"):
            module = self.construct_layer_wise_merged_model(modelpool)

        if self.config.weights is not None:
            # skip the test-time adaptation
            return module.merge_and_unload()
        else:
            with self.profile("test-time adaptation"):
                module = self.test_time_adaptation(module)
            if self.config.get("save_merging_weights", False):
                self.save_merging_weights(
                    self.config.save_merging_weights, module.merge_weight
                )
            return module.merge_and_unload()

    def on_test_time_adaptation_start(self):
        """
        Something to do before the test-time adaptation starts. Such as setting up the task-specific heads.
        """
        pass

    @abstractmethod
    def get_shuffled_test_loader_iter(self, task: str) -> DataLoader:
        """
        Loader of test dataset for test-time adaptation. labels are not needed.

        Args:
            task (str): The name of the task.

        Returns:
            DataLoader: The data loader for the test dataset.
        """
        pass

    @abstractmethod
    def compute_logits(self, module, images: Tensor, task: str) -> Tensor:
        """
        Compute the logits for the given images and task.

        Args:
            module: The model module.
            images (Tensor): The input images.
            task (str): The name of the task.

        Returns:
            Tensor: The computed logits.
        """
        pass

    def test_time_adaptation(self, module: "LayerWiseMergedModel[TorchModelType]"):
        """
        Perform test-time adaptation on the merged model.

        This method adapts the merging weights during test-time to improve performance,
        minimizing the entropy of predictions on unlabeled test batches from each task.

        Args:
            module (LayerWiseMergedModel): The merged model.

        Returns:
            LayerWiseMergedModel: The adapted merged model.
        """
        self.on_test_time_adaptation_start()

        # configure optimizer
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam([module.merge_weight], lr=self.config.lr)
            print(f"{optimizer=}")
            module, optimizer = self.fabric.setup(module, optimizer)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        # Materialize the merged weights once before the first forward pass.
        module.merge_weights()
        for step_idx in (
            pbar := tqdm(
                # Single step in debug mode to keep iteration fast.
                range(self.config.max_steps if not self.is_debug_mode else 1),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "AdaMerging Test-time adaptation",
                dynamic_ncols=True,
            )
        ):
            # default behavior for first-order optimizers
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, batch[0], task)
                    loss = entropy_loss(logits)
                with self.profile("backward pass"):
                    # retain_graph: gradients from every task accumulate on the
                    # shared merge weights before a single optimizer step.
                    self.fabric.backward(loss, retain_graph=True)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()
            with self.profile("merging weights"):
                # Re-merge with the updated weights for the next step.
                module.merge_weights()

            # `loss` here is the last task's loss of this step.
            metrics = {
                "train/loss": loss.item(),
                "train/weight_max": module.merge_weight.max().item(),
                "train/weight_min": module.merge_weight.min().item(),
                "train/weight_mean": module.merge_weight.mean().item(),
            }
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

        log.info(get_memory_usage(f"after adamerging, the memory usage of GPU is:"))
        self.print_profile_summary()
        return module
@@ -10,6 +10,6 @@ Reference:
10
10
  from .iso import (
11
11
  ISO_C_Merge,
12
12
  ISO_CTS_Merge,
13
- IsotropicMergingInCommonSubspace,
14
13
  IsotropicMergingInCommonAndTaskSubspace,
14
+ IsotropicMergingInCommonSubspace,
15
15
  )
@@ -6,11 +6,11 @@ from fusion_bench import BaseAlgorithm, BaseModelPool
6
6
  from fusion_bench.mixins import LightningFabricMixin
7
7
  from fusion_bench.utils.state_dict_arithmetic import (
8
8
  state_dict_add,
9
- state_dict_sub,
10
9
  state_dict_mul,
10
+ state_dict_sub,
11
11
  )
12
12
 
13
- from .iso_utils import iso_c, iso_cts, check_parameterNamesMatch
13
+ from .iso_utils import check_parameterNamesMatch, iso_c, iso_cts
14
14
 
15
15
 
16
16
  class IsotropicMergingInCommonSubspace(BaseAlgorithm, LightningFabricMixin):