fusion-bench 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +12 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/clip_finetune.py +6 -4
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/dop/__init__.py +1 -0
- fusion_bench/method/dop/dop.py +366 -0
- fusion_bench/method/dop/min_norm_solvers.py +227 -0
- fusion_bench/method/dop/utils.py +73 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +55 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +1 -0
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/METADATA +8 -2
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/RECORD +75 -56
- fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/method/depth_upscaling.yaml +9 -0
- fusion_bench_config/method/dop/dop.yaml +30 -0
- fusion_bench_config/method/dummy.yaml +6 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
- fusion_bench_config/method/linear/weighted_average.yaml +3 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +1 -1
- fusion_bench_config/method/model_recombination.yaml +8 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
- fusion_bench_config/method/opcm/opcm.yaml +5 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
- fusion_bench_config/method/opcm/weight_average.yaml +5 -0
- fusion_bench_config/method/simple_average.yaml +9 -0
- fusion_bench_config/method/slerp/slerp.yaml +9 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
- fusion_bench_config/method/task_arithmetic.yaml +9 -0
- fusion_bench_config/method/ties_merging.yaml +3 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/top_level.txt +0 -0
fusion_bench/method/dop/dop.py
ADDED
@@ -0,0 +1,366 @@
+"""
+Continual Model Merging without Data: Dual Projections for Balancing Stability and Plasticity. NeurIPS, 2025.
+
+
+Example:
+
+fusion_bench \
+    method=dop/dop \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only \
+    taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8
+"""
+
+import logging
+import os
+import random
+from copy import deepcopy
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Tuple, cast
+
+import lightning as L
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch import Tensor, nn
+from torch.autograd import Variable
+from tqdm.auto import tqdm
+from transformers import CLIPVisionModel
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.method.simple_average import simple_average
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.taskpool import CLIPVisionModelTaskPool
+from fusion_bench.utils import seed_everything_by_time
+from fusion_bench.utils.json import save_to_json
+
+from .min_norm_solvers import MinNormSolver, gradient_normalizers
+from .utils import is_leaf_module, svd
+
+log = logging.getLogger(__name__)
+
+
+@auto_register_config
+class ContinualDOPForCLIP(BaseAlgorithm, LightningFabricMixin):
+
+    def __init__(
+        self,
+        seed: Optional[int] = None,
+        shuffle_order: bool = False,
+        save_on_every_step: bool = True,
+        evaluate_on_every_step: bool = False,
+        lr: float = 1e-4,
+        num_steps: int = 200,
+        mgda: bool = True,
+        ema: bool = True,
+        ema_beta: float = 0.99,
+        alpha: float = None,
+        svd_epsilon: float = 1.0,
+        svd_proj_space: str = "uv",
+        **kwargs,
+    ):
+        self.lr = lr
+        self.num_steps = num_steps
+        self.mgda = mgda
+        self.ema = ema
+        self.ema_beta = ema_beta
+        self.alpha = alpha
+        self.svd_epsilon = svd_epsilon
+        self.svd_proj_space = svd_proj_space
+        self.seed = seed
+        self.shuffle_order = shuffle_order
+        self.save_on_every_step = save_on_every_step
+        self.evaluate_on_every_step = evaluate_on_every_step
+
+        assert (
+            self.svd_epsilon >= 0 and self.svd_epsilon <= 1
+        ), "The svd_epsilon should be in the range of [0, 1]"
+        assert (
+            self.alpha >= 0 and self.alpha <= 1
+        ), "The alpha should be in the range of [0, 1]"
+        super().__init__(**kwargs)
+
+    def print_params(self, pretrained_model):
+        total_params = 0
+        linear_params = 0
+        linear_weight_params = 0
+        for module_name, module in pretrained_model.named_modules():
+            if not is_leaf_module(module):
+                continue
+            if isinstance(module, nn.Linear):
+                linear_params += sum(p.numel() for n, p in module.named_parameters())
+                linear_weight_params += sum(
+                    p.numel() for n, p in module.named_parameters() if "weight" in n
+                )
+            total_params += sum(p.numel() for p in module.parameters())
+
+        linear_ratio = linear_params / total_params * 100
+        linear_weight_ratio = linear_weight_params / total_params * 100
+        print(f"Total Parameters: {total_params}")
+        print(f"Linear Parameters: {linear_params}")
+        print(f"Linear Weight Parameters: {linear_weight_params}")
+        print(f"Linear Ratio: {linear_ratio:.2f}%")
+        print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
+
+    def run(self, modelpool: BaseModelPool):
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+        else:
+            seed_everything_by_time(self.fabric)
+
+        # get the model names, shuffle if needed
+        # the model names will be saved to the log directory as `model_names.json`
+        model_names = modelpool.model_names
+        if self.shuffle_order:
+            random.shuffle(model_names)
+        if self.log_dir is not None:
+            save_to_json(model_names, os.path.join(self.log_dir, "model_names.json"))
+
+        if self.evaluate_on_every_step:
+            """Configuration for the test datasets"""
+            self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
+            self._test_datasets = deepcopy(self.taskpool._test_datasets)
+
+        pretrained_model = modelpool.load_pretrained_model()
+
+        merged_model = None
+        for model_idx, model_name in enumerate(model_names):
+            print(
+                f"--------- Optimizing {model_idx + 1}/{len(model_names)}-th with {model_name} ---------"
+            )
+            if model_idx == 0:
+                merged_model = modelpool.load_model(model_names[0])
+            else:
+                merged_model = self._layer_wise_optimize(
+                    model_names=["merged", model_name],
+                    pretrained_model=deepcopy(pretrained_model),
+                    finetuned_models={
+                        "merged": merged_model,
+                        model_name: modelpool.load_model(model_name),
+                    },
+                    model_idx=model_idx,
+                )
+
+            if self.save_on_every_step:
+                self.save_merged_model(merged_model, model_idx)
+
+            if self.evaluate_on_every_step:
+                self.taskpool._is_setup = False
+                self.taskpool._test_datasets = DictConfig(
+                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
+                )
+                report = self.taskpool.evaluate(deepcopy(merged_model))
+                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
+
+        return merged_model
+
+    def _layer_wise_optimize(
+        self,
+        model_names: List[str],
+        pretrained_model: nn.Module,
+        finetuned_models: Dict[str, nn.Module],
+        model_idx: int,
+    ):
+        time_cost = []
+        for module_name, module in pretrained_model.named_modules():
+            if not is_leaf_module(module):
+                continue
+
+            if isinstance(module, nn.Linear):
+                if module.weight.requires_grad:
+                    import time
+
+                    start_time = time.time()
+                    merged_weight = self._optimize_weight(
+                        module.weight,
+                        {
+                            model_name: finetuned_models[model_name]
+                            .get_submodule(module_name)
+                            .weight
+                            for model_name in model_names
+                        },
+                        module_name,
+                        model_idx,
+                    )
+                    end_time = time.time()
+                    time_cost.append(end_time - start_time)
+                    module.weight.data = merged_weight.data
+                else:
+                    module.weight.data = simple_average(
+                        [
+                            finetuned_models[model_name]
+                            .get_submodule(module_name)
+                            .weight
+                            for model_name in model_names
+                        ]
+                    )
+                if module.bias is not None:
+                    module.bias.data = simple_average(
+                        [
+                            finetuned_models[model_name].get_submodule(module_name).bias
+                            for model_name in model_names
+                        ]
+                    )
+            else:
+                simple_average(
+                    [
+                        finetuned_models[model_name].get_submodule(module_name)
+                        for model_name in model_names
+                    ],
+                    base_module=module,
+                )
+
+        return pretrained_model
+
+    def _optimize_weight(
+        self,
+        pretrained_weight: Tensor,
+        finetuned_weights: Dict[str, Tensor],
+        module_name: str,
+        model_idx: int,
+    ):
+        assert (
+            self.fabric.world_size == 1
+        ), "This algorithm is not currently supported in distributed training"
+
+        pretrained_weight = self.fabric.to_device(pretrained_weight.detach())
+        finetuned_weights = {
+            model_name: self.fabric.to_device(finetuned_weight.detach())
+            for model_name, finetuned_weight in finetuned_weights.items()
+        }
+
+        merged_weight = self.fabric.to_device(
+            nn.Parameter(
+                simple_average(
+                    [
+                        finetuned_weight.detach()
+                        for finetuned_weight in finetuned_weights.values()
+                    ]
+                ),
+                requires_grad=True,
+            )
+        )
+
+        # Compute SVD of the difference between the finetuned and pretrained weights
+        proj_u_dict = {}
+        proj_v_dict = {}
+        proj_s_dict = {}
+        for i, finetuned_weight in enumerate(finetuned_weights.values()):
+            finetuned_tv = finetuned_weight - pretrained_weight
+            u, s, v = svd(finetuned_tv, full_matrices=True)
+            epsilon = 1.0 if self.svd_epsilon > 1.0 else self.svd_epsilon
+            cumsum_ratio = s.cumsum(dim=0) / s.sum()
+            split_rank = torch.searchsorted(cumsum_ratio, epsilon).item()
+            u_main = u[:, :split_rank]
+            v_main = v[:, :split_rank]
+            s_main = s[:split_rank]
+            proj_u_dict[i] = u_main
+            proj_v_dict[i] = v_main
+            proj_s_dict[i] = s_main
+
+        if self.mgda:
+            if self.ema:
+                ema_sol = [self.alpha, 1 - self.alpha]
+            # This is multiple-gradient descent algorithm (MGDA) optimization
+            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
+            all_losses = [[], []]
+            all_alphas = [[], []]
+            for step_idx in tqdm(
+                range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            ):
+                # Scaling the loss functions based on the algorithm choice
+                loss_data = {}
+                grads = {}
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss_data[i] = float(loss_i.data)
+
+                    all_losses[i].append(float(loss_i.data))
+
+                    optimizer.zero_grad()
+                    loss_i.backward()
+                    grads[i] = Variable(
+                        merged_weight.grad.data.clone(), requires_grad=False
+                    )
+
+                # Normalize all gradients
+                gn = gradient_normalizers(
+                    grads=grads, losses=loss_data, normalization_type="loss"
+                )
+                for i, _ in enumerate(finetuned_weights.values()):
+                    grads[i] = grads[i] / float(gn[i])
+
+                # Frank-Wolfe iteration to compute scales.
+                sol, min_norm = MinNormSolver.find_min_norm_element(
+                    [[grads[i]] for i in range(len(finetuned_weights.values()))]
+                )
+
+                if self.ema:
+                    ema_sol = [
+                        self.ema_beta * ema_sol[i] + (1 - self.ema_beta) * float(sol[i])
+                        for i in range(len(sol))
+                    ]
+                    sol = ema_sol
+                    all_alphas[0].append(ema_sol[0])
+                    all_alphas[1].append(ema_sol[1])
+
+                # Scaled back-propagation
+                loss = 0
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    # Comptue gradients of each loss function wrt parameters
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss += float(sol[i]) * loss_i
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        else:
+            # This is a naive weighted optimization
+            optimizer = torch.optim.Adam([merged_weight], lr=self.lr)
+            for step_idx in tqdm(
+                range(self.num_steps), desc=f"Optimizing {module_name} weight"
+            ):
+                loss = 0
+                for i, finetuned_weight in enumerate(finetuned_weights.values()):
+                    proj_u = proj_u_dict[i]
+                    proj_v = proj_v_dict[i]
+                    proj_s = proj_s_dict[i]
+                    delta_tv = merged_weight - finetuned_weight
+                    loss_i = self.cal_loss_i(delta_tv, proj_s, proj_u, proj_v)
+                    loss += self.alpha * loss_i if i == 0 else (1 - self.alpha) * loss_i
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+        return merged_weight.detach().cpu()
+
+    def cal_loss_i(self, delta_tv, proj_s, proj_u, proj_v):
+        proj_delta_1 = torch.diag(proj_s) @ proj_u.T @ delta_tv
+        proj_delta_2 = delta_tv @ proj_v @ torch.diag(proj_s)
+        loss_i_u = torch.linalg.matrix_norm(proj_delta_1, ord="fro") ** 2
+        loss_i_v = torch.linalg.matrix_norm(proj_delta_2, ord="fro") ** 2
+        if self.svd_proj_space == "uv":
+            loss_i = loss_i_u + loss_i_v
+        elif self.svd_proj_space == "u":
+            loss_i = loss_i_u
+        elif self.svd_proj_space == "v":
+            loss_i = loss_i_v
+        else:
+            raise ValueError("Invalid svd_proj_space")
+
+        return loss_i
+
+    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
+        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
+        merged_model.save_pretrained(
+            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
+        )
fusion_bench/method/dop/min_norm_solvers.py
ADDED
@@ -0,0 +1,227 @@
+# This code is from
+# Multi-Task Learning as Multi-Objective Optimization
+# Ozan Sener, Vladlen Koltun
+# Neural Information Processing Systems (NeurIPS) 2018
+# https://github.com/intel-isl/MultiObjectiveOptimization
+from typing import Union
+
+import numpy as np
+import torch
+
+
+def np_sum(x: Union[torch.Tensor, np.ndarray]) -> float:
+    if isinstance(x, torch.Tensor):
+        return x.sum().item()
+    return np.sum(x)
+
+
+def to_numpy(x: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
+    if isinstance(x, torch.Tensor):
+        return x.detach().cpu().numpy()
+    return x
+
+
+class MinNormSolver:
+    MAX_ITER = 250
+    STOP_CRIT = 1e-5
+
+    def _min_norm_element_from2(v1v1, v1v2, v2v2):
+        """
+        Analytical solution for min_{c} |cx_1 + (1-c)x_2|_2^2
+        d is the distance (objective) optimzed
+        v1v1 = <x1,x1>
+        v1v2 = <x1,x2>
+        v2v2 = <x2,x2>
+        """
+        if v1v2 >= v1v1:
+            # Case: Fig 1, third column
+            gamma = 0.999
+            cost = v1v1
+            return gamma, cost
+        if v1v2 >= v2v2:
+            # Case: Fig 1, first column
+            gamma = 0.001
+            cost = v2v2
+            return gamma, cost
+        # Case: Fig 1, second column
+        gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
+        cost = v2v2 + gamma * (v1v2 - v2v2)
+        return gamma, cost
+
+    def _min_norm_2d(vecs, dps):
+        R"""
+        Find the minimum norm solution as combination of two points
+        This is correct only in 2D
+        ie. min_c |\sum c_i x_i|_2^2 st. \sum c_i = 1 , 1 >= c_1 >= 0 for all i, c_i + c_j = 1.0 for some i, j
+        """
+        dmin = 1e8
+        for i in range(len(vecs)):
+            for j in range(i + 1, len(vecs)):
+                if (i, j) not in dps:
+                    dps[(i, j)] = 0.0
+                    for k in range(len(vecs[i])):
+                        dps[(i, j)] += (
+                            torch.mul(vecs[i][k], vecs[j][k]).sum().data.cpu()
+                        )
+                    dps[(j, i)] = dps[(i, j)]
+                if (i, i) not in dps:
+                    dps[(i, i)] = 0.0
+                    for k in range(len(vecs[i])):
+                        dps[(i, i)] += (
+                            torch.mul(vecs[i][k], vecs[i][k]).sum().data.cpu()
+                        )
+                if (j, j) not in dps:
+                    dps[(j, j)] = 0.0
+                    for k in range(len(vecs[i])):
+                        dps[(j, j)] += (
+                            torch.mul(vecs[j][k], vecs[j][k]).sum().data.cpu()
+                        )
+                c, d = MinNormSolver._min_norm_element_from2(
+                    dps[(i, i)], dps[(i, j)], dps[(j, j)]
+                )
+                if d < dmin:
+                    dmin = d
+                    sol = [(i, j), c, d]
+        return sol, dps
+
+    def _projection2simplex(y):
+        R"""
+        Given y, it solves argmin_z |y-z|_2 st \sum z = 1 , 1 >= z_i >= 0 for all i
+        """
+        m = len(y)
+        sorted_y = np.flip(np.sort(y), axis=0)
+        tmpsum = 0.0
+        tmax_f = (np.sum(y) - 1.0) / m
+        for i in range(m - 1):
+            tmpsum += sorted_y[i]
+            tmax = (tmpsum - 1) / (i + 1.0)
+            if tmax > sorted_y[i + 1]:
+                tmax_f = tmax
+                break
+        return np.maximum(y - tmax_f, np.zeros(y.shape))
+
+    def _next_point(cur_val, grad, n):
+        proj_grad = grad - (np.sum(grad) / n)
+        tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
+        tm2 = (1.0 - cur_val[proj_grad > 0]) / (proj_grad[proj_grad > 0])
+
+        skippers = np_sum(tm1 < 1e-7) + np_sum(tm2 < 1e-7)
+        t = 1
+        if len(tm1[tm1 > 1e-7]) > 0:
+            t = np.min(to_numpy(tm1[tm1 > 1e-7]))
+        if len(tm2[tm2 > 1e-7]) > 0:
+            t = min(t, np.min(to_numpy(tm2[tm2 > 1e-7])))
+
+        next_point = proj_grad * t + to_numpy(cur_val)
+        next_point = MinNormSolver._projection2simplex(next_point)
+        return next_point
+
+    def find_min_norm_element(vecs):
+        R"""
+        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
+        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
+        Hence, we find the best 2-task solution, and then run the projected gradient descent until convergence
+        """
+        # Solution lying at the combination of two points
+        dps = {}
+        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+        n = len(vecs)
+        sol_vec = np.zeros(n)
+        sol_vec[init_sol[0][0]] = init_sol[1]
+        sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+        if n < 3:
+            # This is optimal for n=2, so return the solution
+            return sol_vec, init_sol[2]
+
+        iter_count = 0
+
+        grad_mat = np.zeros((n, n))
+        for i in range(n):
+            for j in range(n):
+                grad_mat[i, j] = dps[(i, j)]
+
+        while iter_count < MinNormSolver.MAX_ITER:
+            grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
+            new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)
+            # Re-compute the inner products for line search
+            v1v1 = 0.0
+            v1v2 = 0.0
+            v2v2 = 0.0
+            for i in range(n):
+                for j in range(n):
+                    v1v1 += sol_vec[i] * sol_vec[j] * dps[(i, j)]
+                    v1v2 += sol_vec[i] * new_point[j] * dps[(i, j)]
+                    v2v2 += new_point[i] * new_point[j] * dps[(i, j)]
+            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+            new_sol_vec = nc * sol_vec + (1 - nc) * new_point
+            change = new_sol_vec - sol_vec
+            if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                return sol_vec, nd
+            sol_vec = new_sol_vec
+
+    def find_min_norm_element_FW(vecs):
+        R"""
+        Given a list of vectors (vecs), this method finds the minimum norm element in the convex hull
+        as min |u|_2 st. u = \sum c_i vecs[i] and \sum c_i = 1.
+        It is quite geometric, and the main idea is the fact that if d_{ij} = min |u|_2 st u = c x_i + (1-c) x_j; the solution lies in (0, d_{i,j})
+        Hence, we find the best 2-task solution, and then run the Frank Wolfe until convergence
+        """
+        # Solution lying at the combination of two points
+        dps = {}
+        init_sol, dps = MinNormSolver._min_norm_2d(vecs, dps)
+
+        n = len(vecs)
+        sol_vec = np.zeros(n)
+        sol_vec[init_sol[0][0]] = init_sol[1]
+        sol_vec[init_sol[0][1]] = 1 - init_sol[1]
+
+        if n < 3:
+            # This is optimal for n=2, so return the solution
+            return sol_vec, init_sol[2]
+
+        iter_count = 0
+
+        grad_mat = np.zeros((n, n))
+        for i in range(n):
+            for j in range(n):
+                grad_mat[i, j] = dps[(i, j)]
+
+        while iter_count < MinNormSolver.MAX_ITER:
+            t_iter = np.argmin(np.dot(grad_mat, sol_vec))
+
+            v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
+            v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
+            v2v2 = grad_mat[t_iter, t_iter]
+
+            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
+            new_sol_vec = nc * sol_vec
+            new_sol_vec[t_iter] += 1 - nc
+
+            change = new_sol_vec - sol_vec
+            if np_sum(np.abs(change)) < MinNormSolver.STOP_CRIT:
+                return sol_vec, nd
+            sol_vec = new_sol_vec
+
+
+def gradient_normalizers(grads, losses, normalization_type):
+    gn = {}
+    if normalization_type == "l2":
+        for t in grads:
+            gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]]))
+    elif normalization_type == "loss":
+        for t in grads:
+            gn[t] = losses[t]
+    elif normalization_type == "loss+":
+        for t in grads:
+            gn[t] = losses[t] * np.sqrt(
+                np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]])
+            )
+    elif normalization_type == "none":
+        for t in grads:
+            gn[t] = 1.0
+    else:
+        print("ERROR: Invalid Normalization Type")
+    return gn
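This solver is vendored from Sener & Koltun's MGDA code. A minimal usage sketch, with hypothetical toy gradients chosen so the expected output is obvious:

import torch

from fusion_bench.method.dop.min_norm_solvers import MinNormSolver

# Two hypothetical task gradients over a shared parameter.
g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([0.0, 1.0])

# Each task supplies a list of gradient tensors (one per parameter group);
# dop.py wraps its per-model grads the same way: [[grads[i]] for i in ...].
sol, min_norm = MinNormSolver.find_min_norm_element([[g1], [g2]])
print(sol)  # -> [0.5 0.5]; orthogonal gradients of equal norm get equal weights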
fusion_bench/method/dop/utils.py
ADDED
@@ -0,0 +1,73 @@
+from typing import Tuple
+
+import torch
+from torch import Tensor, nn
+
+from fusion_bench.utils.parameters import state_dict_to_vector
+from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
+
+
+def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform Singular Value Decomposition (SVD) on a tensor.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool): Whether to compute the full-sized U and V matrices.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    u, s, vh = torch.linalg.svd(
+        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
+    )
+    v = vh.T
+    return u, s, v
+
+
+def svd(
+    w: Tensor, full_matrices=True, accelerator=None
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform SVD on a tensor, optionally using a specified accelerator.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool): Whether to compute the full-sized U and V matrices.
+        accelerator (str): The device to perform the computation on.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    if accelerator is None:
+        return _svd(w, full_matrices=full_matrices)
+    original_device = w.device
+    w = w.to(accelerator)
+    u, s, v = _svd(w)
+    return u.to(original_device), s.to(original_device), v.to(original_device)
+
+
+def frobenius_inner_product(w1: Tensor, w2: Tensor) -> Tensor:
+    return torch.trace(w1.T @ w2)
+
+
+def is_leaf_module(module: nn.Module) -> bool:
+    return len(list(module.children())) == 0
+
+
+def get_task_vector_norm(model: nn.Module, pretrained_model: nn.Module) -> Tensor:
+    """
+    Get the vector norm of the task model.
+
+    Args:
+        model (nn.Module): The task model.
+        pretrained_model (nn.Module): The pretrained model.
+
+    Returns:
+        Tensor: The vector norm of the task model.
+    """
+    return torch.linalg.norm(
+        state_dict_to_vector(
+            state_dict_sub(model.state_dict(), pretrained_model.state_dict())
+        )
+    )
fusion_bench/method/opcm/opcm.py
CHANGED
@@ -87,6 +87,7 @@ class OPCMForCLIP(
         # get the average model
         with self.profile("loading model"):
             merged_model = modelpool.load_model(model_names[0])
+        assert merged_model is not None, "Failed to load the first model"
 
         if self.evaluate_on_every_step:
             with self.profile("evaluating model"):

fusion_bench/method/tall_mask/task_arithmetic.py
CHANGED
@@ -15,7 +15,7 @@ from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
     state_dict_binary_mask,
     state_dict_diff_abs,
-
+    state_dict_hadamard_product,
     state_dict_mul,
     state_dict_sub,
     state_dict_sum,
@@ -111,7 +111,7 @@ class TallMaskTaskArithmeticAlgorithm(
 
         with self.profile("compress and retrieve"):
             for model_name in modelpool.model_names:
-                retrieved_task_vector =
+                retrieved_task_vector = state_dict_hadamard_product(
                     tall_masks[model_name], multi_task_vector
                 )
                 retrieved_state_dict = state_dict_add(
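In the second hunk, TALL-mask retrieval now calls state_dict_hadamard_product: the elementwise (Hadamard) product of a task's binary mask with the merged multi-task vector, which zeroes out every entry the mask excludes before the task vector is added back onto the base state dict. A self-contained sketch of the operation on plain dicts (a toy helper for illustration, not the library's implementation):

import torch

def hadamard_product(a: dict, b: dict) -> dict:
    # Elementwise product of two state dicts with identical keys and shapes.
    return {k: a[k] * b[k] for k in a}

mask = {"fc.weight": torch.tensor([[1.0, 0.0], [0.0, 1.0]])}
multi_task_vector = {"fc.weight": torch.tensor([[0.5, 0.5], [0.5, 0.5]])}

retrieved = hadamard_product(mask, multi_task_vector)
# Only the masked entries survive: [[0.5, 0.0], [0.0, 0.5]]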
|