fusion-bench 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. fusion_bench/__main__.py +4 -0
  2. fusion_bench/dataset/fer2013.py +1 -0
  3. fusion_bench/method/__init__.py +26 -4
  4. fusion_bench/method/classification/__init__.py +1 -0
  5. fusion_bench/method/classification/clip_finetune.py +1 -3
  6. fusion_bench/method/classification/continual_clip_finetune.py +297 -0
  7. fusion_bench/method/dare/__init__.py +1 -0
  8. fusion_bench/method/dare/task_arithmetic.py +14 -7
  9. fusion_bench/method/dare/ties_merging.py +100 -0
  10. fusion_bench/method/isotropic_merging/__init__.py +15 -0
  11. fusion_bench/method/isotropic_merging/iso.py +114 -0
  12. fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
  13. fusion_bench/method/opcm/__init__.py +4 -0
  14. fusion_bench/method/opcm/opcm.py +277 -0
  15. fusion_bench/method/opcm/task_arithmetic.py +115 -0
  16. fusion_bench/method/opcm/ties_merging.py +156 -0
  17. fusion_bench/method/opcm/utils.py +73 -0
  18. fusion_bench/method/opcm/weight_average.py +120 -0
  19. fusion_bench/method/slerp/slerp.py +1 -1
  20. fusion_bench/method/task_singular_vector/TSVM.py +22 -2
  21. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +91 -93
  22. fusion_bench/method/ties_merging/ties_merging.py +10 -0
  23. fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
  24. fusion_bench/mixins/clip_classification.py +4 -1
  25. fusion_bench/programs/fabric_fusion_program.py +22 -11
  26. fusion_bench/scripts/cli.py +1 -0
  27. fusion_bench/taskpool/base_pool.py +1 -1
  28. fusion_bench/taskpool/clip_vision/taskpool.py +12 -7
  29. fusion_bench/utils/__init__.py +2 -1
  30. fusion_bench/utils/dict.py +43 -0
  31. fusion_bench/utils/expr.py +90 -0
  32. fusion_bench/utils/fabric.py +17 -0
  33. fusion_bench/utils/instantiate.py +7 -1
  34. fusion_bench/utils/json.py +30 -0
  35. fusion_bench/utils/parameters.py +27 -7
  36. fusion_bench/utils/path.py +15 -0
  37. fusion_bench/utils/plot/color_data.py +1726 -0
  38. fusion_bench/utils/rich_utils.py +15 -0
  39. fusion_bench/utils/set.py +8 -0
  40. fusion_bench/utils/tensorboard.py +51 -0
  41. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/METADATA +17 -18
  42. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/RECORD +58 -29
  43. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/WHEEL +1 -1
  44. fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
  45. fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
  46. fusion_bench_config/method/clip_finetune.yaml +2 -2
  47. fusion_bench_config/method/dare/ties_merging.yaml +15 -0
  48. fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
  49. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
  50. fusion_bench_config/method/opcm/opcm.yaml +12 -0
  51. fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
  52. fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
  53. fusion_bench_config/method/opcm/weight_average.yaml +10 -0
  54. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
  55. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
  56. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/LICENSE +0 -0
  57. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/entry_points.txt +0 -0
  58. {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/top_level.txt +0 -0
fusion_bench/method/isotropic_merging/iso_utils.py
@@ -0,0 +1,176 @@
+ import math
+ from typing import List
+
+ import torch
+
+ from fusion_bench.utils import timeit_context
+ from fusion_bench.utils.type import StateDictType
+
+
+ def iso_c(
+     task_vectors: List[StateDictType],
+     accelerator="cuda",
+     exclude_keys: List[str] = None,
+ ) -> StateDictType:
+     exclude_keys = [] if exclude_keys is None else exclude_keys
+
+     with torch.no_grad(), timeit_context("ISO-C Merging"):
+         new_vector = {}
+         for key in task_vectors[0]:
+             print(f"Merging {key}...")
+             original_device = task_vectors[0][key].device
+             tvs = [
+                 task_vector[key].to(device=accelerator, non_blocking=True)
+                 for task_vector in task_vectors
+             ]
+             num_tvs = len(tvs)
+             new_vector[key] = sum(tvs) / num_tvs
+             del tvs  # free memory
+
+             if len(task_vectors[0][key].shape) == 2 and key not in exclude_keys:
+                 # if the key is a 2D matrix, we need to merge the task vectors in the common space
+                 new_vector[key] *= num_tvs
+                 U, S, V = torch.linalg.svd(new_vector[key], full_matrices=False)
+                 S_mean = torch.ones_like(S) * S.mean()
+
+                 new_vector[key] = torch.linalg.multi_dot(
+                     (
+                         U,
+                         torch.diag(S_mean),
+                         V,
+                     )
+                 )
+             new_vector[key] = new_vector[key].to(
+                 device=original_device, non_blocking=True
+             )
+     return new_vector
+
+
+ @torch.no_grad()
+ def iso_cts(
+     task_vectors: List[StateDictType],
+     common_space_fraction: float,
+     accelerator: str = "cuda",
+     exclude_keys: List[str] = None,
+ ):
+     exclude_keys = [] if exclude_keys is None else exclude_keys
+     new_vector = {}
+
+     print("ISO-CTS Merging")
+     for key in task_vectors[0]:
+         shape_ = task_vectors[0][key].shape
+         original_device = task_vectors[0][key].device
+         is_2d_matrix = (len(shape_) == 2) and (key not in exclude_keys)
+         if not is_2d_matrix:
+             print(f"Combining by avg {key}...")
+             for i, task_vector in enumerate(task_vectors):
+                 vec = task_vector[key].to(device=accelerator, non_blocking=True)
+                 if i == 0:
+                     new_vector[key] = vec.clone()
+                 else:
+                     new_vector[key] += (vec - new_vector[key]) / (i + 1)
+
+             # move the new vector to the original device
+             new_vector[key] = new_vector[key].to(
+                 device=original_device, non_blocking=True
+             )
+             continue
+
+         print(f"Computing common space using sum for {key}...")
+         combined_w = sum(
+             [
+                 task_vector[key].to(device=accelerator, non_blocking=True)
+                 for task_vector in task_vectors
+             ]
+         )
+
+         ### Calculate the common space size (making sure that task specific space is equally divisible) ###
+         common_space_index_s = int(min(shape_) * common_space_fraction)
+         _task_specific_total_space_index_s = round(
+             (min(shape_) - common_space_index_s) / len(task_vectors)
+         ) * len(task_vectors)
+         common_space_index_s = min(shape_) - _task_specific_total_space_index_s
+
+         u, s, v = torch.linalg.svd(combined_w, full_matrices=False)
+         common_space_u = u[:, :common_space_index_s]
+         common_space_s = s[:common_space_index_s]
+         common_space_v = v[:common_space_index_s, :]
+         ###################################################################
+
+         ### Calculate task specific space ###
+         n_dims_per_task = int((min(shape_) - common_space_index_s) / len(task_vectors))
+         for i, task_vector in enumerate(task_vectors):
+             w = task_vector[key].to(device=accelerator)
+
+             # calculate the projection onto task specific space to remove the common space
+             w_ts = w - common_space_u @ common_space_u.T @ w
+             u_ts, s_ts, v_ts = torch.linalg.svd(w_ts, full_matrices=False)
+
+             if i == 0:
+                 combined_space_u = torch.zeros_like(u_ts, device=accelerator)
+                 combined_space_s = torch.zeros_like(s_ts, device=accelerator)
+                 combined_space_v = torch.zeros_like(v_ts, device=accelerator)
+
+             combined_space_u[:, i * n_dims_per_task : (i + 1) * n_dims_per_task] = u_ts[
+                 :, :n_dims_per_task
+             ]
+             combined_space_s[i * n_dims_per_task : (i + 1) * n_dims_per_task] = s_ts[
+                 :n_dims_per_task
+             ]
+             combined_space_v[i * n_dims_per_task : (i + 1) * n_dims_per_task, :] = v_ts[
+                 :n_dims_per_task, :
+             ]
+         ###################################################################
+
+         combined_space_u[
+             :,
+             len(task_vectors) * n_dims_per_task : len(task_vectors) * n_dims_per_task
+             + common_space_index_s,
+         ] = common_space_u
+         combined_space_s[
+             len(task_vectors) * n_dims_per_task : len(task_vectors) * n_dims_per_task
+             + common_space_index_s
+         ] = common_space_s
+         combined_space_v[
+             len(task_vectors) * n_dims_per_task : len(task_vectors) * n_dims_per_task
+             + common_space_index_s,
+             :,
+         ] = common_space_v
+
+         ### Orthogonalize combined_space_u and combined_space_v ###
+         u_combined_space_u, s_combined_space_u, v_combined_space_u = torch.linalg.svd(
+             combined_space_u, full_matrices=False
+         )
+         u_combined_space_v, s_combined_space_v, v_combined_space_v = torch.linalg.svd(
+             combined_space_v, full_matrices=False
+         )
+         combined_space_u = u_combined_space_u @ v_combined_space_u
+         combined_space_v = u_combined_space_v @ v_combined_space_v
+         ###################################################################
+
+         combined_space_s = torch.ones_like(combined_space_s) * combined_space_s.mean()
+
+         new_vector[key] = torch.linalg.multi_dot(
+             (
+                 combined_space_u,
+                 torch.diag(combined_space_s),
+                 combined_space_v,
+             )
+         )
+         new_vector[key] = new_vector[key].to(device=original_device, non_blocking=True)
+
+     return new_vector
+
+
+ def check_parameterNamesMatch(checkpoints):
+     parameter_names = set(checkpoints[0].keys())
+
+     if len(checkpoints) >= 2:
+         # raise ValueError("Number of models is less than 2.")
+         for checkpoint in checkpoints[1:]:
+             current_parameterNames = set(checkpoint.keys())
+             if current_parameterNames != parameter_names:
+                 raise ValueError(
+                     "Differing parameter names in models. "
+                     f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                 )
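
For orientation, a minimal usage sketch of the two routines above (not part of the package; the checkpoint names and toy shapes are invented, and the import path follows the file list entry fusion_bench/method/isotropic_merging/iso_utils.py):

import torch

from fusion_bench.method.isotropic_merging.iso_utils import iso_c, iso_cts

# Toy state dicts: one pretrained checkpoint and two fine-tuned checkpoints.
pretrained_sd = {"layer.weight": torch.zeros(8, 8)}
finetuned_sds = [{"layer.weight": torch.randn(8, 8)} for _ in range(2)]

# Task vectors are fine-tuned weights minus pretrained weights.
task_vectors = [{k: sd[k] - pretrained_sd[k] for k in sd} for sd in finetuned_sds]

# ISO-C: average the task vectors, then replace the singular-value spectrum of
# every 2D parameter with its mean, making the spectrum isotropic.
merged_tv = iso_c(task_vectors, accelerator="cpu")

# ISO-CTS: additionally reserve a common subspace and split the remaining
# directions equally among the tasks before flattening the spectrum.
merged_tv_cts = iso_cts(task_vectors, common_space_fraction=0.8, accelerator="cpu")

# The merged model adds the merged task vector back onto the pretrained weights.
merged_sd = {k: pretrained_sd[k] + merged_tv[k] for k in merged_tv}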
fusion_bench/method/opcm/__init__.py
@@ -0,0 +1,4 @@
+ from .opcm import OPCMForCLIP
+ from .task_arithmetic import ContinualTaskArithmeticForCLIP
+ from .ties_merging import ContinualTiesMergingForCLIP
+ from .weight_average import ContinualWeightAverageForCLIP
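
These re-exports define the public import surface of the new opcm subpackage; for example, a hypothetical usage line, assuming the package layout above:

from fusion_bench.method.opcm import OPCMForCLIP, ContinualTaskArithmeticForCLIP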
fusion_bench/method/opcm/opcm.py
@@ -0,0 +1,277 @@
+ import os
+ import random
+ import time
+ from collections import defaultdict
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast
+
+ import lightning as L
+ import numpy as np
+ import torch
+ from omegaconf import DictConfig
+ from torch import Tensor, nn
+ from tqdm.auto import tqdm
+ from transformers import CLIPVisionModel
+
+ from fusion_bench import BaseAlgorithm, BaseModelPool
+ from fusion_bench.mixins import LightningFabricMixin
+ from fusion_bench.taskpool import CLIPVisionModelTaskPool
+ from fusion_bench.utils import instantiate
+ from fusion_bench.utils.json import load_from_json, save_to_json
+ from fusion_bench.utils.parameters import state_dict_to_vector
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
+
+ from .utils import frobenius_inner_product, get_task_vector_norm, is_leaf_module, svd
+
+ if TYPE_CHECKING:
+     from torch.utils.tensorboard import SummaryWriter
+
+
+ class OPCMForCLIP(
+     BaseAlgorithm,
+     LightningFabricMixin,
+ ):
+     def __init__(
+         self,
+         alpha: float,
+         shuffle_order: bool = True,
+         seed: Optional[int] = None,
+         save_on_every_step: bool = True,
+         evaluate_on_every_step: bool = False,
+         **kwargs,
+     ):
+         """
+         Continual Model Merging via SVD Projection.
+
+         Args:
+             alpha (float): the scaling factor for the SVD projection.
+             shuffle_order (bool): whether to shuffle the order of the models.
+             seed (Optional[int]): the seed to use.
+             save_on_every_step (bool): whether to save the merged model on every step.
+             evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
+         """
+         self.alpha = alpha
+         self.shuffle_order = shuffle_order
+         self.seed = seed
+         self.save_on_every_step = save_on_every_step
+         self.evaluate_on_every_step = evaluate_on_every_step
+         super().__init__(**kwargs)
+
+     @torch.no_grad()
+     def run(self, modelpool: BaseModelPool):
+         if self.seed is not None:
+             L.seed_everything(self.seed)
+         accelerator = self.fabric.device
+
+         pretrained_model = modelpool.load_pretrained_model()
+
+         model_names = modelpool.model_names
+         if self.shuffle_order:
+             random.shuffle(model_names)
+
+         self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
+         self._test_datasets = deepcopy(self.taskpool._test_datasets)
+         """Configuration for the test datasets"""
+
+         # log the model names
+         if self.log_dir is not None:
+             save_to_json(model_names, Path(self.log_dir) / "model_names.json")
+             tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
+             tensorboard_summarywriter.add_text(
+                 "global/model_names", str(model_names), global_step=0
+             )
+
+         # get the average model
+         merged_model = modelpool.load_model(model_names[0])
+
+         if self.evaluate_on_every_step:
+             self.taskpool._is_setup = False
+             self.taskpool._test_datasets = DictConfig(
+                 {model_names[0]: self._test_datasets[model_names[0]]}
+             )
+             report = self.taskpool.evaluate(deepcopy(merged_model))
+             save_to_json(report, Path(self.log_dir) / "report_0.json")
+
+         self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
+         self.all_task_vector_norm = [self.avg_task_vector_norm]
+         self.fabric.log("model/task_vector_norm", self.avg_task_vector_norm, step=0)
+         self.fabric.log("model/avg_task_vector_norm", self.avg_task_vector_norm, step=0)
+         self.fabric.log(
+             "model/merged_task_vector_norm", self.avg_task_vector_norm, step=0
+         )
+
+         self.previous_lambda_t = 1
+         self.lambda_t = None
+         self.fabric.log("model/lambda_t", self.previous_lambda_t, step=0)
+         self.fabric.log("empirical/lambda_t", 1, step=0)
+
+         if self.save_on_every_step:
+             self.save_merged_model(merged_model, 0)
+
+         for model_idx, model_name in tqdm(
+             enumerate(model_names[1:]), desc="Processing models"
+         ):
+             model_idx += 1
+             task_model = modelpool.load_model(model_name)
+
+             self.all_task_vector_norm.append(
+                 get_task_vector_norm(task_model, pretrained_model)
+             )
+             self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
+             self.fabric.log(
+                 "model/task_vector_norm", self.all_task_vector_norm[-1], step=model_idx
+             )
+             self.fabric.log(
+                 "model/avg_task_vector_norm", self.avg_task_vector_norm, step=model_idx
+             )
+
+             self.lambda_t = 1  # temporary value
+
+             for module_name, module in tqdm(
+                 list(merged_model.named_modules()),
+                 desc=f"Processing {model_name}",
+                 leave=False,
+             ):
+                 if not is_leaf_module(module):
+                     continue
+
+                 if isinstance(module, nn.Linear):
+                     module.weight.data = self.merge_linear_weights(
+                         module.weight,
+                         pretrained_model.get_submodule(module_name).weight,
+                         task_model.get_submodule(module_name).weight,
+                         param_name=".".join([module_name, "weight"]),
+                         alpha=self.alpha,
+                         accelerator=accelerator,
+                     )
+                     if module.bias is not None:
+                         module.bias.data = self.merge_other_parameters(
+                             module.bias,
+                             pretrained_model.get_submodule(module_name).bias,
+                             task_model.get_submodule(module_name).bias,
+                             param_name=".".join([module_name, "bias"]),
+                             accelerator=accelerator,
+                         )
+                 else:
+                     for param_name, param in module.named_parameters():
+                         param.data = self.merge_other_parameters(
+                             merged_W=param,
+                             pretrained_W=pretrained_model.get_submodule(
+                                 module_name
+                             ).get_parameter(param_name),
+                             task_W=task_model.get_submodule(module_name).get_parameter(
+                                 param_name
+                             ),
+                             param_name=".".join([module_name, param_name]),
+                             accelerator=accelerator,
+                         )
+
+             task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
+             self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
+             for param_name, param in merged_model.named_parameters():
+                 param.data = pretrained_model.get_parameter(param_name) + (
+                     param - pretrained_model.get_parameter(param_name)
+                 ) * (self.avg_task_vector_norm / task_vector_norm)
+             self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
+             self.fabric.log(
+                 "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
+             )
+             self.previous_lambda_t = self.lambda_t
+             self.lambda_t = None
+
+             self.fabric.log(
+                 "model/merged_task_vector_norm",
+                 get_task_vector_norm(merged_model, pretrained_model),
+                 step=model_idx,
+             )
+
+             if self.save_on_every_step:
+                 self.save_merged_model(merged_model, model_idx)
+
+             if self.evaluate_on_every_step:
+                 self.taskpool._is_setup = False
+                 self.taskpool._test_datasets = DictConfig(
+                     {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
+                 )
+                 report = self.taskpool.evaluate(deepcopy(merged_model))
+                 save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+
+         return merged_model
+
+     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
+         os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
+         merged_model.save_pretrained(
+             Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
+         )
+
+     def merge_linear_weights(
+         self,
+         merged_W: Tensor,
+         pretrained_W: Tensor,
+         task_W: Tensor,
+         param_name: str,
+         alpha: float,
+         accelerator: str = "cpu",
+     ):
+         original_device = merged_W.device
+         merged_W = merged_W.to(accelerator)
+         pretrained_W = pretrained_W.to(accelerator)
+         task_W = task_W.to(accelerator)
+
+         previous_merged_tv = merged_W - pretrained_W
+         task_tv = task_W - pretrained_W
+
+         u, s, v = svd(previous_merged_tv)
+         rank = s.size(0)
+         split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()
+
+         projected_task_tv = u.T @ task_tv @ v
+         projected_task_tv.diag().fill_(0)
+
+         projected_task_tv[:split_rank, :split_rank] = 0
+
+         cleaned_task_tv = u @ projected_task_tv @ v.T
+
+         previous_lambda_t = self.previous_lambda_t
+         lambda_t = self.lambda_t
+         new_merged_W = (
+             pretrained_W
+             + (previous_lambda_t * previous_merged_tv + cleaned_task_tv) / lambda_t
+         )
+         return new_merged_W.to(original_device)
+
+     def merge_other_parameters(
+         self,
+         merged_W: Tensor,
+         pretrained_W: Tensor,
+         task_W: Tensor,
+         param_name: str,
+         accelerator: str = "cpu",
+     ):
+         original_device = merged_W.device
+         merged_W = merged_W.to(accelerator)
+         pretrained_W = pretrained_W.to(accelerator)
+         task_W = task_W.to(accelerator)
+
+         previous_merged_tv = merged_W - pretrained_W
+         task_tv = task_W - pretrained_W
+
+         previous_lambda_t = self.previous_lambda_t
+         lambda_t = self.lambda_t
+
+         new_merged_W = (
+             pretrained_W + (previous_lambda_t * previous_merged_tv + task_tv) / lambda_t
+         )
+         return new_merged_W.to(original_device)
+
+     def compute_lambda_t(
+         self, previous_merged_tv: Tensor, task_tv: Tensor, previous_lambda_t: float
+     ):
+         previous_merged_tv = torch.flatten(previous_merged_tv)
+         task_tv = torch.flatten(task_tv)
+
+         lambda_t = torch.linalg.vector_norm(
+             previous_lambda_t * previous_merged_tv + task_tv
+         ) / torch.linalg.vector_norm(previous_merged_tv)
+         return lambda_t.item()
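
The heart of merge_linear_weights is the SVD projection, which removes from the incoming task vector the components that overlap with the dominant singular directions of the running merge. Below is a self-contained sketch of that step, illustrative only and not package code; it assumes the package's svd helper follows the torch.linalg.svd convention of returning U, S, Vh, and the lambda values are supplied by hand rather than tracked across steps as the algorithm does:

import torch

def opcm_linear_merge(merged_W, pretrained_W, task_W, alpha, previous_lambda_t, lambda_t):
    # Task vectors: the running merge and the incoming task, relative to the base.
    previous_merged_tv = merged_W - pretrained_W
    task_tv = task_W - pretrained_W

    # SVD of the previous merged task vector; v holds the right singular
    # vectors as columns (v = Vh.T).
    u, s, vh = torch.linalg.svd(previous_merged_tv, full_matrices=False)
    v = vh.T

    # Smallest rank at which the cumulative spectrum exceeds the alpha threshold.
    split_rank = int((s.cumsum(0) / s.sum() > alpha).float().argmax().item())

    # Express the incoming task vector in the singular basis and zero the
    # top-left block: its projection onto the dominant directions of the
    # previous merge.
    projected = u.T @ task_tv @ v
    projected[:split_rank, :split_rank] = 0
    cleaned_task_tv = u @ projected @ v.T

    # Rescale the accumulated task vector and add it back to the base weights.
    return pretrained_W + (previous_lambda_t * previous_merged_tv + cleaned_task_tv) / lambda_t

# Toy usage with random 16x16 weights:
base = torch.zeros(16, 16)
merged = opcm_linear_merge(
    merged_W=base + torch.randn(16, 16),
    pretrained_W=base,
    task_W=base + torch.randn(16, 16),
    alpha=0.5,
    previous_lambda_t=1.0,
    lambda_t=1.4,
)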
fusion_bench/method/opcm/task_arithmetic.py
@@ -0,0 +1,115 @@
+ import os
+ import random
+ import time
+ from collections import defaultdict
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast
+
+ import lightning as L
+ import numpy as np
+ import torch
+ from omegaconf import DictConfig
+ from torch import Tensor, nn
+ from tqdm.auto import tqdm
+ from transformers import CLIPVisionModel
+
+ from fusion_bench import BaseAlgorithm, BaseModelPool
+ from fusion_bench.mixins import LightningFabricMixin
+ from fusion_bench.taskpool import CLIPVisionModelTaskPool
+ from fusion_bench.utils.json import load_from_json, save_to_json
+ from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
+
+ if TYPE_CHECKING:
+     from torch.utils.tensorboard import SummaryWriter
+
+
+ class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
+     def __init__(
+         self,
+         scaling_factor: float,
+         shuffle_order: bool = True,
+         seed: Optional[int] = None,
+         save_on_every_step: bool = True,
+         evaluate_on_every_step: bool = False,
+         **kwargs,
+     ):
+         """
+         Continual Model Merging via Task Arithmetic.
+
+         Args:
+             scaling_factor (float): the scaling factor to use.
+             shuffle_order (bool): whether to shuffle the order of the models.
+             seed (Optional[int]): the seed to use.
+             save_on_every_step (bool): whether to save the merged model on every step.
+             evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
+         """
+         self.scaling_factor = scaling_factor
+         self.shuffle_order = shuffle_order
+         self.seed = seed
+         self.save_on_every_step = save_on_every_step
+         self.evaluate_on_every_step = evaluate_on_every_step
+         super().__init__(**kwargs)
+
+     @torch.no_grad()
+     def run(self, modelpool: BaseModelPool):
+         if self.seed is not None:
+             L.seed_everything(self.seed)
+
+         model_names = modelpool.model_names
+         if self.shuffle_order:
+             random.shuffle(model_names)
+
+         self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
+         self._test_datasets = deepcopy(self.taskpool._test_datasets)
+         """Configuration for the test datasets"""
+
+         # log the model names
+         if self.log_dir is not None:
+             save_to_json(model_names, Path(self.log_dir) / "model_names.json")
+             tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
+             tensorboard_summarywriter.add_text(
+                 "global/model_names", str(model_names), global_step=0
+             )
+
+         # get the average model
+         pretrained_model = modelpool.load_pretrained_model()
+         merged_model = deepcopy(pretrained_model)
+
+         for model_idx, model_name in tqdm(
+             enumerate(model_names), desc="Processing models"
+         ):
+             task_model = modelpool.load_model(model_name)
+
+             for param_name, param in task_model.named_parameters():
+                 if not param.requires_grad:
+                     continue
+
+                 task_param = param
+                 merged_param = merged_model.get_parameter(param_name)
+                 pretrained_param = pretrained_model.get_parameter(param_name)
+
+                 new_param = merged_param + self.scaling_factor * (
+                     task_param - pretrained_param
+                 )
+                 merged_model.get_parameter(param_name).data = new_param
+
+             if self.save_on_every_step:
+                 self.save_merged_model(merged_model, model_idx)
+
+             if self.evaluate_on_every_step:
+                 self.taskpool._is_setup = False
+                 self.taskpool._test_datasets = DictConfig(
+                     {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
+                 )
+                 report = self.taskpool.evaluate(deepcopy(merged_model))
+                 save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+
+         return merged_model
+
+     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
+         os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
+         torch.save(
+             merged_model.state_dict(),
+             Path(self.log_dir) / "checkpoints" / f"model_{step}.pth",
+         )
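
The update rule in run above reduces to plain task arithmetic applied incrementally. A minimal sketch on raw state dicts (hypothetical helper, not package code) makes the accumulation explicit:

import torch

def continual_task_arithmetic(pretrained_sd, task_sds, scaling_factor):
    merged = {k: v.clone() for k, v in pretrained_sd.items()}
    for task_sd in task_sds:
        for k in merged:
            # Each step adds the scaled task vector of the next model, so after
            # T steps: merged = pretrained + scaling_factor * sum_t (task_t - pretrained).
            merged[k] += scaling_factor * (task_sd[k] - pretrained_sd[k])
    return merged

# Toy usage:
pretrained = {"w": torch.zeros(3, 3)}
tasks = [{"w": torch.randn(3, 3)} for _ in range(4)]
merged = continual_task_arithmetic(pretrained, tasks, scaling_factor=0.3)

Since each step adds a scaled task vector to the running sum, the final weights equal pretrained plus scaling_factor times the sum of all task vectors, independent of the (possibly shuffled) model order; only the intermediate checkpoints and per-step reports depend on it.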