fusion-bench 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +12 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/clip_finetune.py +6 -4
  7. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  8. fusion_bench/method/dop/__init__.py +1 -0
  9. fusion_bench/method/dop/dop.py +366 -0
  10. fusion_bench/method/dop/min_norm_solvers.py +227 -0
  11. fusion_bench/method/dop/utils.py +73 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  15. fusion_bench/mixins/__init__.py +2 -0
  16. fusion_bench/mixins/pyinstrument.py +174 -0
  17. fusion_bench/mixins/simple_profiler.py +106 -23
  18. fusion_bench/modelpool/__init__.py +2 -0
  19. fusion_bench/modelpool/base_pool.py +77 -14
  20. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  21. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  22. fusion_bench/models/__init__.py +35 -9
  23. fusion_bench/optim/__init__.py +40 -2
  24. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  25. fusion_bench/optim/muon.py +339 -0
  26. fusion_bench/programs/__init__.py +2 -0
  27. fusion_bench/programs/fabric_fusion_program.py +2 -2
  28. fusion_bench/programs/fusion_program.py +271 -0
  29. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  30. fusion_bench/utils/__init__.py +167 -21
  31. fusion_bench/utils/lazy_imports.py +91 -12
  32. fusion_bench/utils/lazy_state_dict.py +55 -5
  33. fusion_bench/utils/misc.py +104 -13
  34. fusion_bench/utils/packages.py +4 -0
  35. fusion_bench/utils/path.py +7 -0
  36. fusion_bench/utils/pylogger.py +6 -0
  37. fusion_bench/utils/rich_utils.py +1 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  39. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/METADATA +8 -2
  40. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/RECORD +75 -56
  41. fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
  42. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  43. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  44. fusion_bench_config/method/depth_upscaling.yaml +9 -0
  45. fusion_bench_config/method/dop/dop.yaml +30 -0
  46. fusion_bench_config/method/dummy.yaml +6 -0
  47. fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
  48. fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
  49. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
  50. fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
  51. fusion_bench_config/method/linear/weighted_average.yaml +3 -0
  52. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +1 -1
  53. fusion_bench_config/method/model_recombination.yaml +8 -0
  54. fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
  55. fusion_bench_config/method/opcm/opcm.yaml +5 -0
  56. fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
  57. fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
  58. fusion_bench_config/method/opcm/weight_average.yaml +5 -0
  59. fusion_bench_config/method/simple_average.yaml +9 -0
  60. fusion_bench_config/method/slerp/slerp.yaml +9 -0
  61. fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
  62. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
  63. fusion_bench_config/method/task_arithmetic.yaml +9 -0
  64. fusion_bench_config/method/ties_merging.yaml +3 -0
  65. fusion_bench_config/model_fusion.yaml +45 -0
  66. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  72. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/WHEEL +0 -0
  73. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/entry_points.txt +0 -0
  74. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/licenses/LICENSE +0 -0
  75. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/top_level.txt +0 -0
fusion_bench/method/dop/dop.py
@@ -0,0 +1,366 @@
+"""
+Continual Model Merging without Data: Dual Projections for Balancing Stability and Plasticity. NeurIPS, 2025.
+
+Example:
+
+    fusion_bench \
+        method=dop/dop \
+        modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
+        taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+"""
+
+import logging
+import os
+import random
+import time
+from copy import deepcopy
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Tuple, cast
+
+import lightning as L
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.autograd import Variable
+from tqdm.auto import tqdm
+from transformers import CLIPVisionModel
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.taskpool import CLIPVisionModelTaskPool
+from fusion_bench.utils import seed_everything_by_time
+from fusion_bench.utils.json import save_to_json
+
+from .min_norm_solvers import MinNormSolver, gradient_normalizers
+from .utils import is_leaf_module, svd
+
+log = logging.getLogger(__name__)
+
+
+@auto_register_config
+class ContinualDOPForCLIP(BaseAlgorithm, LightningFabricMixin):
+
+    def __init__(
+        self,
+        seed: Optional[int] = None,
+        shuffle_order: bool = False,
+        save_on_every_step: bool = True,
+        evaluate_on_every_step: bool = False,
+        lr: float = 1e-4,
+        num_steps: int = 200,
+        mgda: bool = True,
+        ema: bool = True,
+        ema_beta: float = 0.99,
+        alpha: Optional[float] = None,
+        svd_epsilon: float = 1.0,
+        svd_proj_space: str = "uv",
+        **kwargs,
+    ):
+        self.lr = lr
+        self.num_steps = num_steps
+        self.mgda = mgda
+        self.ema = ema
+        self.ema_beta = ema_beta
+        self.alpha = alpha
+        self.svd_epsilon = svd_epsilon
+        self.svd_proj_space = svd_proj_space
+        self.seed = seed
+        self.shuffle_order = shuffle_order
+        self.save_on_every_step = save_on_every_step
+        self.evaluate_on_every_step = evaluate_on_every_step
+
+        assert (
+            0 <= self.svd_epsilon <= 1
+        ), "svd_epsilon should be in the range [0, 1]"
+        assert (
+            self.alpha is not None and 0 <= self.alpha <= 1
+        ), "alpha should be set and in the range [0, 1]"
+        super().__init__(**kwargs)
+
+    def print_params(self, pretrained_model):
+        total_params = 0
+        linear_params = 0
+        linear_weight_params = 0
+        for module_name, module in pretrained_model.named_modules():
+            if not is_leaf_module(module):
+                continue
+            if isinstance(module, nn.Linear):
+                linear_params += sum(p.numel() for n, p in module.named_parameters())
+                linear_weight_params += sum(
+                    p.numel() for n, p in module.named_parameters() if "weight" in n
+                )
+            total_params += sum(p.numel() for p in module.parameters())
+
+        linear_ratio = linear_params / total_params * 100
+        linear_weight_ratio = linear_weight_params / total_params * 100
+        print(f"Total Parameters: {total_params}")
+        print(f"Linear Parameters: {linear_params}")
+        print(f"Linear Weight Parameters: {linear_weight_params}")
+        print(f"Linear Ratio: {linear_ratio:.2f}%")
+        print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
+
+    def run(self, modelpool: BaseModelPool):
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        else:
+            seed_everything_by_time(self.fabric)
+
+        # get the model names, shuffle if needed
+        # the model names will be saved to the log directory as `model_names.json`
+        model_names = modelpool.model_names
+        if self.shuffle_order:
+            random.shuffle(model_names)
+        if self.log_dir is not None:
+            save_to_json(model_names, os.path.join(self.log_dir, "model_names.json"))
+
+        if self.evaluate_on_every_step:
+            """Configuration for the test datasets"""
+            self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
+            self._test_datasets = deepcopy(self.taskpool._test_datasets)
+
+        pretrained_model = modelpool.load_pretrained_model()
+
+        merged_model = None
+        for model_idx, model_name in enumerate(model_names):
+            print(
+                f"--------- Optimizing model {model_idx + 1}/{len(model_names)}: {model_name} ---------"
+            )
+            if model_idx == 0:
+                merged_model = modelpool.load_model(model_names[0])
+            else:
+                merged_model = self._layer_wise_optimize(
+                    model_names=["merged", model_name],
+                    pretrained_model=deepcopy(pretrained_model),
+                    finetuned_models={
+                        "merged": merged_model,
+                        model_name: modelpool.load_model(model_name),
+                    },
+                    model_idx=model_idx,
+                )
+
+            if self.save_on_every_step:
+                self.save_merged_model(merged_model, model_idx)
+
+            if self.evaluate_on_every_step:
+                self.taskpool._is_setup = False
+                self.taskpool._test_datasets = DictConfig(
+                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
+                )
+                report = self.taskpool.evaluate(deepcopy(merged_model))
+                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+
+        return merged_model
+
+    def _layer_wise_optimize(
+        self,
+        model_names: List[str],
+        pretrained_model: nn.Module,
+        finetuned_models: Dict[str, nn.Module],
+        model_idx: int,
+    ):
+        time_cost = []
+        for module_name, module in pretrained_model.named_modules():
+            if not is_leaf_module(module):
+                continue
+
+            if isinstance(module, nn.Linear):
+                if module.weight.requires_grad:
+                    start_time = time.time()
+                    merged_weight = self._optimize_weight(
+                        module.weight,
+                        {
+                            model_name: finetuned_models[model_name]
+                            .get_submodule(module_name)
+                            .weight
+                            for model_name in model_names
+                        },
+                        module_name,
+                        model_idx,
+                    )
+                    end_time = time.time()
+                    time_cost.append(end_time - start_time)
+                    module.weight.data = merged_weight.data
+                else:
+                    module.weight.data = simple_average(
+                        [
+                            finetuned_models[model_name]
+                            .get_submodule(module_name)
+                            .weight
+                            for model_name in model_names
+                        ]
+                    )
+                if module.bias is not None:
+                    module.bias.data = simple_average(
+                        [
+                            finetuned_models[model_name].get_submodule(module_name).bias
+                            for model_name in model_names
+                        ]
+                    )
+            else:
+                simple_average(
+                    [
+                        finetuned_models[model_name].get_submodule(module_name)
+                        for model_name in model_names
+                    ],
+                    base_module=module,
+                )
+
+        return pretrained_model
+
+    def _optimize_weight(
+        self,
+        pretrained_weight: Tensor,
+        finetuned_weights: Dict[str, Tensor],
+        module_name: str,
+        model_idx: int,
+    ):
+        assert (
+            self.fabric.world_size == 1
+        ), "This algorithm is not currently supported in distributed training"
+
+        pretrained_weight = self.fabric.to_device(pretrained_weight.detach())
+        finetuned_weights = {
+            model_name: self.fabric.to_device(finetuned_weight.detach())
+            for model_name, finetuned_weight in finetuned_weights.items()
+        }
+
+        merged_weight = self.fabric.to_device(
+            nn.Parameter(
+                simple_average(
+                    [
+                        finetuned_weight.detach()
+                        for finetuned_weight in finetuned_weights.values()
+                    ]
+                ),
+                requires_grad=True,
+            )
+        )
+
+        # Compute SVD of the difference between the finetuned and pretrained weights
+        proj_u_dict = {}
+        proj_v_dict = {}
+        proj_s_dict = {}
+        for i, finetuned_weight in enumerate(finetuned_weights.values()):
+            finetuned_tv = finetuned_weight - pretrained_weight
+            u, s, v = svd(finetuned_tv, full_matrices=True)
+            epsilon = 1.0 if self.svd_epsilon > 1.0 else self.svd_epsilon
+            cumsum_ratio = s.cumsum(dim=0) / s.sum()
+            split_rank = torch.searchsorted(cumsum_ratio, epsilon).item()
+            u_main = u[:, :split_rank]
+            v_main = v[:, :split_rank]
+            s_main = s[:split_rank]
+            proj_u_dict[i] = u_main
+            proj_v_dict[i] = v_main
+            proj_s_dict[i] = s_main
+
+        if self.mgda:
+            if self.ema:
+                ema_sol = [self.alpha, 1 - self.alpha]
+            # This is the multiple-gradient descent algorithm (MGDA) optimization
+            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
+            all_losses = [[], []]
+            all_alphas = [[], []]
+            for step_idx in tqdm(
+                range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            ):
+                # Scaling the loss functions based on the algorithm choice
+                loss_data = {}
+                grads = {}
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss_data[i] = float(loss_i.data)
+
+                    all_losses[i].append(float(loss_i.data))
+
+                    optimizer.zero_grad()
+                    loss_i.backward()
+                    grads[i] = Variable(
+                        merged_weight.grad.data.clone(), requires_grad=False
+                    )
+
+                # Normalize all gradients
+                gn = gradient_normalizers(
+                    grads=grads, losses=loss_data, normalization_type="loss"
+                )
+                for i, _ in enumerate(finetuned_weights.values()):
+                    grads[i] = grads[i] / float(gn[i])
+
+                # Frank-Wolfe iteration to compute scales.
+                sol, min_norm = MinNormSolver.find_min_norm_element(
+                    [[grads[i]] for i in range(len(finetuned_weights))]
+                )
+
+                if self.ema:
+                    ema_sol = [
+                        self.ema_beta * ema_sol[i] + (1 - self.ema_beta) * float(sol[i])
+                        for i in range(len(sol))
+                    ]
+                    sol = ema_sol
+                    all_alphas[0].append(ema_sol[0])
+                    all_alphas[1].append(ema_sol[1])
+
+                # Scaled back-propagation
+                loss = 0
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    # Compute gradients of each loss function wrt parameters
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss += float(sol[i]) * loss_i
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        else:
+            # This is a naive weighted optimization
+            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
+            for step_idx in tqdm(
+                range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            ):
+                loss = 0
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss += self.alpha * loss_i if i == 0 else (1 - self.alpha) * loss_i
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        return merged_weight.detach().cpu()
+
+    def cal_loss_i(self, delta_tv, proj_s, proj_u, proj_v):
+        proj_delta_1 = torch.diag(proj_s) @ proj_u.T @ delta_tv
+        proj_delta_2 = delta_tv @ proj_v @ torch.diag(proj_s)
+        loss_i_u = torch.linalg.matrix_norm(proj_delta_1, ord="fro") ** 2
+        loss_i_v = torch.linalg.matrix_norm(proj_delta_2, ord="fro") ** 2
+        if self.svd_proj_space == "uv":
+            loss_i = loss_i_u + loss_i_v
+        elif self.svd_proj_space == "u":
+            loss_i = loss_i_u
+        elif self.svd_proj_space == "v":
+            loss_i = loss_i_v
+        else:
+            raise ValueError("Invalid svd_proj_space")
+
+        return loss_i
+
+    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
+        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
+        merged_model.save_pretrained(
+            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
+        )
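The core of the objective above is `cal_loss_i`: given a candidate update delta_w = W_merged - W_i and the leading SVD subspaces (U, S, V) of task i's task vector, it penalizes the energy of delta_w inside those subspaces, ||diag(S) U^T delta_w||_F^2 + ||delta_w V diag(S)||_F^2. The following minimal sketch reproduces that computation in isolation with plain torch; the toy shapes and names are illustrative, not part of the package:

    import torch

    def dual_projection_loss(delta_w, u, s, v):
        # Energy of delta_w inside the row space (U) and column space (V)
        # of a task vector, weighted by its singular values.
        proj_left = torch.diag(s) @ u.T @ delta_w   # projection onto U
        proj_right = delta_w @ v @ torch.diag(s)    # projection onto V
        return (
            torch.linalg.matrix_norm(proj_left, ord="fro") ** 2
            + torch.linalg.matrix_norm(proj_right, ord="fro") ** 2
        )

    # Toy example: a random "task vector" and a candidate update.
    task_vector = torch.randn(16, 16)
    u, s, vh = torch.linalg.svd(task_vector, full_matrices=True)
    v = vh.T
    rank = 4  # keep the top-4 singular directions, as `svd_epsilon` would
    delta_w = torch.randn(16, 16)
    print(dual_projection_loss(delta_w, u[:, :rank], s[:rank], v[:, :rank]))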
fusion_bench/method/dop/min_norm_solvers.py
@@ -0,0 +1,227 @@
+# This code is from
+# Multi-Task Learning as Multi-Objective Optimization
+# Ozan Sener, Vladlen Koltun
+# Neural Information Processing Systems (NeurIPS) 2018
+# https://github.com/intel-isl/MultiObjectiveOptimization
+from typing import Union
+
+import numpy as np
+import torch
+
+
+def np_sum(x: Union[torch.Tensor, np.ndarray]) -> float:
+    if isinstance(x, torch.Tensor):
+        return x.sum().item()
+    return np.sum(x)
+
+
+def to_numpy(x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
+    if isinstance(x, torch.Tensor):
+        return x.detach().cpu().numpy()
+    return x
+
+
+class MinNormSolver:
+    MAX_ITER = 250
+    STOP_CRIT = 1e-5
+
+    def _min_norm_element_from2(v1v1, v1v2, v2v2):
+        """
+        Analytical solution for min_{c} |c x_1 + (1-c) x_2|_2^2
+        d is the distance (objective) optimized
+        v1v1 = <x1, x1>
+        v1v2 = <x1, x2>
+        v2v2 = <x2, x2>
+        """
+        if v1v2 >= v1v1:
+            # Case: Fig 1, third column
+            gamma = 0.999
+            cost = v1v1
+            return gamma, cost
+        if v1v2 >= v2v2:
+            # Case: Fig 1, first column
+            gamma = 0.001
+            cost = v2v2
+            return gamma, cost
+        # Case: Fig 1, second column
+        gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
+        cost = v2v2 + gamma * (v1v2 - v2v2)
+        return gamma, cost
+
+    def _min_norm_2d(vecs, dps):
+        R"""
+        Find the minimum norm solution as a combination of two points.
+        This is correct only in 2D,
+        i.e. min_c |\sum c_i x_i|_2^2 s.t. \sum c_i = 1, 1 >= c_i >= 0 for all i, c_i + c_j = 1.0 for some i, j
+        """
+        dmin = 1e8
+        for i in range(len(vecs)):
+            for j in range(i + 1, len(vecs)):
+                if (i, j) not in dps:
+                    dps[(i, j)] = 0.0
+                    for k in range(len(vecs[i])):
+                        dps[(i, j)] += (
+                            torch.mul(vecs[i][k], vecs[j][k]).sum().data.cpu()
+                        )
+                    dps[(j, i)] = dps[(i, j)]
+                if (i, i) not in dps:
+                    dps[(i, i)] = 0.0
+                    for k in range(len(vecs[i])):
+                        dps[(i, i)] += (
+                            torch.mul(vecs[i][k], vecs[i][k]).sum().data.cpu()
+                        )
+                if (j, j) not in dps:
+                    dps[(j, j)] = 0.0
+                    for k in range(len(vecs[j])):
+                        dps[(j, j)] += (
+                            torch.mul(vecs[j][k], vecs[j][k]).sum().data.cpu()
+                        )
+                c, d = MinNormSolver._min_norm_element_from2(
+                    dps[(i, i)], dps[(i, j)], dps[(j, j)]
+                )
+                if d < dmin:
+                    dmin = d
+                    sol = [(i, j), c, d]
+        return sol, dps
+
+    def _projection2simplex(y):
+        R"""
+        Given y, solves argmin_z |y - z|_2 s.t. \sum z = 1, 1 >= z_i >= 0 for all i
+        """
+        m = len(y)
+        sorted_y = np.flip(np.sort(y), axis=0)
+        tmpsum = 0.0
+        tmax_f = (np.sum(y) - 1.0) / m
+        for i in range(m - 1):
+            tmpsum += sorted_y[i]
+            tmax = (tmpsum - 1) / (i + 1.0)
+            if tmax > sorted_y[i + 1]:
+                tmax_f = tmax
+                break
+        return np.maximum(y - tmax_f, np.zeros(y.shape))
+
+    def _next_point(cur_val, grad, n):
+        proj_grad = grad - (np.sum(grad) / n)
+        tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
+        tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])
+
+        skippers = np_sum(tm1 < 1e-7) + np_sum(tm2 < 1e-7)
+        t = 1
+        if len(tm1[tm1 > 1e-7]) > 0:
+            t = np.min(to_numpy(tm1[tm1 > 1e-7]))
+        if len(tm2[tm2 > 1e-7]) > 0:
+            t = min(t, np.min(to_numpy(tm2[tm2 > 1e-7])))
+
+        next_point = proj_grad * t + to_numpy(cur_val)
+        next_point = MinNormSolver._projection2simplex(next_point)
+        return next_point
+
+    def find_min_norm_element(vecs):
+        R"""
+        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+        as min |u|_2 s.t. u = \sum c_i vecs[i] and \sum c_i = 1.
+        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 s.t. u = c x_i + (1-c) x_j, the solution lies in (0, d_{i,j}).
+        Hence, we find the best 2-task solution and then run projected gradient descent until convergence.
+        """
+        # Solution lying at the combination of two points
+        dps = {}
+        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+        n = len(vecs)
+        sol_vec = np.zeros(n)
+        sol_vec[init_sol[0][0]] = init_sol[1]
+        sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+        if n < 3:
+            # This is optimal for n=2, so return the solution
+            return sol_vec, init_sol[2]
+
+        iter_count = 0
+
+        grad_mat = np.zeros((n, n))
+        for i in range(n):
+            for j in range(n):
+                grad_mat[i, j] = dps[(i, j)]
+
+        while iter_count < MinNormSolver.MAX_ITER:
+            grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
+            new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
+            # Re-compute the inner products for line search
+            v1v1 = 0.0
+            v1v2 = 0.0
+            v2v2 = 0.0
+            for i in range(n):
+                for j in range(n):
+                    v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
+                    v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
+                    v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
+            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+            new_sol_vec = nc * sol_vec + (1 - nc) * new_point
+            change = new_sol_vec - sol_vec
+            if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                return sol_vec, nd
+            sol_vec = new_sol_vec
+            iter_count += 1  # advance the counter so the loop terminates
+        return sol_vec, nd
+
+    def find_min_norm_element_FW(vecs):
+        R"""
+        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+        as min |u|_2 s.t. u = \sum c_i vecs[i] and \sum c_i = 1.
+        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 s.t. u = c x_i + (1-c) x_j, the solution lies in (0, d_{i,j}).
+        Hence, we find the best 2-task solution and then run Frank-Wolfe until convergence.
+        """
+        # Solution lying at the combination of two points
+        dps = {}
+        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+        n = len(vecs)
+        sol_vec = np.zeros(n)
+        sol_vec[init_sol[0][0]] = init_sol[1]
+        sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+        if n < 3:
+            # This is optimal for n=2, so return the solution
+            return sol_vec, init_sol[2]
+
+        iter_count = 0
+
+        grad_mat = np.zeros((n, n))
+        for i in range(n):
+            for j in range(n):
+                grad_mat[i, j] = dps[(i, j)]
+
+        while iter_count < MinNormSolver.MAX_ITER:
+            t_iter = np.argmin(np.dot(grad_mat, sol_vec))
+
+            v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
+            v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
+            v2v2 = grad_mat[t_iter, t_iter]
+
+            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+            new_sol_vec = nc * sol_vec
+            new_sol_vec[t_iter] += 1 - nc
+
+            change = new_sol_vec - sol_vec
+            if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                return sol_vec, nd
+            sol_vec = new_sol_vec
+            iter_count += 1  # advance the counter so the loop terminates
+        return sol_vec, nd
+
+
+def gradient_normalizers(grads, losses, normalization_type):
+    gn = {}
+    if normalization_type == "l2":
+        for t in grads:
+            gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]]))
+    elif normalization_type == "loss":
+        for t in grads:
+            gn[t] = losses[t]
+    elif normalization_type == "loss+":
+        for t in grads:
+            gn[t] = losses[t] * np.sqrt(
+                np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]])
+            )
+    elif normalization_type == "none":
+        for t in grads:
+            gn[t] = 1.0
+    else:
+        raise ValueError(f"Invalid normalization type: {normalization_type}")
+    return gn
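For reference, the analytical two-point solution in `_min_norm_element_from2` follows from minimizing a scalar quadratic; the derivation below is standard (Sener & Koltun, 2018) and maps directly onto the code, with v1v1 = <x1, x1>, v1v2 = <x1, x2>, v2v2 = <x2, x2>:

    f(c) = \| c\,x_1 + (1-c)\,x_2 \|_2^2
         = c^2 \langle x_1, x_1 \rangle + 2c(1-c)\,\langle x_1, x_2 \rangle + (1-c)^2 \langle x_2, x_2 \rangle,
    \qquad c \in [0, 1].

    f'(c^*) = 0 \;\Longrightarrow\;
    c^* = \frac{\langle x_2, x_2 \rangle - \langle x_1, x_2 \rangle}
               {\langle x_1, x_1 \rangle - 2\,\langle x_1, x_2 \rangle + \langle x_2, x_2 \rangle}

This c^* is exactly the interior-case `gamma`. The two early returns handle v1v2 >= v1v1 and v1v2 >= v2v2, where the unconstrained minimizer falls outside [0, 1] and is clipped to 0.999 or 0.001. At the interior optimum the residual u^* = c^* x_1 + (1 - c^*) x_2 is orthogonal to x_1 - x_2, so the cost simplifies to <u^*, x_2> = v2v2 + gamma * (v1v2 - v2v2), matching the `cost` expression in the code.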
fusion_bench/method/dop/utils.py
@@ -0,0 +1,73 @@
+from typing import Tuple
+
+import torch
+from torch import Tensor, nn
+
+from fusion_bench.utils.parameters import state_dict_to_vector
+from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
+
+
+def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform Singular Value Decomposition (SVD) on a tensor.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool): Whether to compute the full-sized U and V matrices.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    u, s, vh = torch.linalg.svd(
+        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
+    )
+    v = vh.T
+    return u, s, v
+
+
+def svd(
+    w: Tensor, full_matrices=True, accelerator=None
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform SVD on a tensor, optionally moving it to a specified accelerator
+    for the decomposition. The factors are returned on the original device.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool): Whether to compute the full-sized U and V matrices.
+        accelerator (str): The device to perform the computation on.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    if accelerator is None:
+        return _svd(w, full_matrices=full_matrices)
+    original_device = w.device
+    w = w.to(accelerator)
+    u, s, v = _svd(w, full_matrices=full_matrices)
+    return u.to(original_device), s.to(original_device), v.to(original_device)
+
+
+def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
+    return torch.trace(w1.T @ w2)
+
+
+def is_leaf_module(module: nn.Module) -> bool:
+    return len(list(module.children())) == 0
+
+
+def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
+    """
+    Get the norm of the task vector, i.e. the difference between the task
+    model and the pretrained model, flattened into a single vector.
+
+    Args:
+        model (nn.Module): The task (fine-tuned) model.
+        pretrained_model (nn.Module): The pretrained model.
+
+    Returns:
+        Tensor: The norm of the task vector.
+    """
+    return torch.linalg.norm(
+        state_dict_to_vector(
+            state_dict_sub(model.state_dict(), pretrained_model.state_dict())
+        )
+    )
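The helpers above are generic: `svd` with an `accelerator` argument only moves the tensor for the decomposition and hands the factors back on the original device. A small usage sketch, assuming the wheel is installed and, for the second branch, that a GPU is available (the shapes are arbitrary):

    import torch
    from fusion_bench.method.dop.utils import svd

    w = torch.randn(64, 32)
    u, s, v = svd(w, full_matrices=False)  # plain CPU SVD; note v, not vh
    # Reconstruction check: w ~= u @ diag(s) @ v.T
    print(torch.allclose(u @ torch.diag(s) @ v.T, w, atol=1e-5))

    if torch.cuda.is_available():
        u, s, v = svd(w, full_matrices=False, accelerator="cuda")
        print(u.device)  # factors come back on the original (CPU) device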
fusion_bench/method/opcm/opcm.py
@@ -87,6 +87,7 @@ class OPCMForCLIP(
         # get the average model
         with self.profile("loading model"):
             merged_model = modelpool.load_model(model_names[0])
+            assert merged_model is not None, "Failed to load the first model"
 
         if self.evaluate_on_every_step:
             with self.profile("evaluating model"):
fusion_bench/method/pwe_moe/module.py
@@ -13,8 +13,6 @@ import torch.func
 from torch import Tensor, nn
 from torch.nn import functional as F
 
-from fusion_bench.utils import join_list
-
 log = logging.getLogger(__name__)
 
 
fusion_bench/method/tall_mask/task_arithmetic.py
@@ -15,7 +15,7 @@ from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
     state_dict_binary_mask,
     state_dict_diff_abs,
-    state_dict_hadmard_product,
+    state_dict_hadamard_product,
     state_dict_mul,
     state_dict_sub,
     state_dict_sum,
@@ -111,7 +111,7 @@ class TallMaskTaskArithmeticAlgorithm(
 
         with self.profile("compress and retrieve"):
             for model_name in modelpool.model_names:
-                retrieved_task_vector = state_dict_hadmard_product(
+                retrieved_task_vector = state_dict_hadamard_product(
                     tall_masks[model_name], multi_task_vector
                 )
                 retrieved_state_dict = state_dict_add(
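These last two hunks are a spelling fix: `state_dict_hadmard_product` is renamed to `state_dict_hadamard_product` in `fusion_bench.utils.state_dict_arithmetic`, and its call site in the TALL-mask retrieval step is updated to match. The operation is an entry-wise (Hadamard) product over matching state-dict keys, here used to apply a binary mask to the multi-task vector. A rough sketch of the semantics (an illustration, not the package's implementation, which may handle missing keys or dtypes differently):

    from typing import Dict
    import torch

    def hadamard_product_sketch(
        a: Dict[str, torch.Tensor], b: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        # Multiply tensors key by key, e.g. a binary TALL mask
        # against a multi-task vector.
        return {k: a[k] * b[k] for k in a}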