fusion-bench 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- fusion_bench/__main__.py +4 -0
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/method/__init__.py +26 -4
- fusion_bench/method/classification/__init__.py +1 -0
- fusion_bench/method/classification/clip_finetune.py +1 -3
- fusion_bench/method/classification/continual_clip_finetune.py +297 -0
- fusion_bench/method/dare/__init__.py +1 -0
- fusion_bench/method/dare/task_arithmetic.py +14 -7
- fusion_bench/method/dare/ties_merging.py +100 -0
- fusion_bench/method/isotropic_merging/__init__.py +15 -0
- fusion_bench/method/isotropic_merging/iso.py +114 -0
- fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
- fusion_bench/method/opcm/__init__.py +4 -0
- fusion_bench/method/opcm/opcm.py +277 -0
- fusion_bench/method/opcm/task_arithmetic.py +115 -0
- fusion_bench/method/opcm/ties_merging.py +156 -0
- fusion_bench/method/opcm/utils.py +73 -0
- fusion_bench/method/opcm/weight_average.py +120 -0
- fusion_bench/method/slerp/slerp.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +22 -2
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +91 -93
- fusion_bench/method/ties_merging/ties_merging.py +10 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +22 -0
- fusion_bench/mixins/clip_classification.py +4 -1
- fusion_bench/programs/fabric_fusion_program.py +22 -11
- fusion_bench/scripts/cli.py +1 -0
- fusion_bench/taskpool/base_pool.py +1 -1
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -7
- fusion_bench/utils/__init__.py +2 -1
- fusion_bench/utils/dict.py +43 -0
- fusion_bench/utils/expr.py +90 -0
- fusion_bench/utils/fabric.py +17 -0
- fusion_bench/utils/instantiate.py +7 -1
- fusion_bench/utils/json.py +30 -0
- fusion_bench/utils/parameters.py +27 -7
- fusion_bench/utils/path.py +15 -0
- fusion_bench/utils/plot/color_data.py +1726 -0
- fusion_bench/utils/rich_utils.py +15 -0
- fusion_bench/utils/set.py +8 -0
- fusion_bench/utils/tensorboard.py +51 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/METADATA +17 -18
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/RECORD +58 -29
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +28 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +26 -0
- fusion_bench_config/method/clip_finetune.yaml +2 -2
- fusion_bench_config/method/dare/ties_merging.yaml +15 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
- fusion_bench_config/method/opcm/opcm.yaml +12 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +12 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +18 -0
- fusion_bench_config/method/opcm/weight_average.yaml +10 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +18 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.8.dist-info → fusion_bench-0.2.10.dist-info}/top_level.txt +0 -0
fusion_bench/method/isotropic_merging/iso_utils.py (new file)
@@ -0,0 +1,176 @@
import math
from typing import List

import torch

from fusion_bench.utils import timeit_context
from fusion_bench.utils.type import StateDictType


def iso_c(
    task_vectors: List[StateDictType],
    accelerator="cuda",
    exclude_keys: List[str] = None,
) -> StateDictType:
    exclude_keys = [] if exclude_keys is None else exclude_keys

    with torch.no_grad(), timeit_context("ISO-C Merging"):
        new_vector = {}
        for key in task_vectors[0]:
            print(f"Merging {key}...")
            original_device = task_vectors[0][key].device
            tvs = [
                task_vector[key].to(device=accelerator, non_blocking=True)
                for task_vector in task_vectors
            ]
            num_tvs = len(tvs)
            new_vector[key] = sum(tvs) / num_tvs
            del tvs  # free memory

            if len(task_vectors[0][key].shape) == 2 and key not in exclude_keys:
                # if the key is a 2D matrix, we need to merge the task vectors in the common space
                new_vector[key] *= num_tvs
                U, S, V = torch.linalg.svd(new_vector[key], full_matrices=False)
                S_mean = torch.ones_like(S) * S.mean()

                new_vector[key] = torch.linalg.multi_dot(
                    (
                        U,
                        torch.diag(S_mean),
                        V,
                    )
                )
            new_vector[key] = new_vector[key].to(
                device=original_device, non_blocking=True
            )
    return new_vector


@torch.no_grad()
def iso_cts(
    task_vectors: List[StateDictType],
    common_space_fraction: float,
    accelerator: str = "cuda",
    exclude_keys: List[str] = None,
):
    exclude_keys = [] if exclude_keys is None else exclude_keys
    new_vector = {}

    print("ISO-CTS Merging")
    for key in task_vectors[0]:
        shape_ = task_vectors[0][key].shape
        original_device = task_vectors[0][key].device
        is_2d_matrix = (len(shape_) == 2) and (key not in exclude_keys)
        if not is_2d_matrix:
            print(f"Combining by avg {key}...")
            for i, task_vector in enumerate(task_vectors):
                vec = task_vector[key].to(device=accelerator, non_blocking=True)
                if i == 0:
                    new_vector[key] = vec.clone()
                else:
                    new_vector[key] += (vec - new_vector[key]) / (i + 1)

            # move the new vector to the original device
            new_vector[key] = new_vector[key].to(
                device=original_device, non_blocking=True
            )
            continue

        print(f"Computing common space using sum for {key}...")
        combined_w = sum(
            [
                task_vector[key].to(device=accelerator, non_blocking=True)
                for task_vector in task_vectors
            ]
        )

        ### Calculate the common space size (making sure that task specific space is equally divisible) ###
        common_space_index_s = int(min(shape_) * common_space_fraction)
        _task_specific_total_space_index_s = round(
            (min(shape_) - common_space_index_s) / len(task_vectors)
        ) * len(task_vectors)
        common_space_index_s = min(shape_) - _task_specific_total_space_index_s

        u, s, v = torch.linalg.svd(combined_w, full_matrices=False)
        common_space_u = u[:, :common_space_index_s]
        common_space_s = s[:common_space_index_s]
        common_space_v = v[:common_space_index_s, :]
        ###################################################################

        ### Calculate task specific space ###
        n_dims_per_task = int((min(shape_) - common_space_index_s) / len(task_vectors))
        for i, task_vector in enumerate(task_vectors):
            w = task_vector[key].to(device=accelerator)

            # calculate the projection onto task specific space to remove the common space
            w_ts = w - common_space_u @ common_space_u.T @ w
            u_ts, s_ts, v_ts = torch.linalg.svd(w_ts, full_matrices=False)

            if i == 0:
                combined_space_u = torch.zeros_like(u_ts, device=accelerator)
                combined_space_s = torch.zeros_like(s_ts, device=accelerator)
                combined_space_v = torch.zeros_like(v_ts, device=accelerator)

            combined_space_u[:, i * n_dims_per_task : (i + 1) * n_dims_per_task] = u_ts[
                :, :n_dims_per_task
            ]
            combined_space_s[i * n_dims_per_task : (i + 1) * n_dims_per_task] = s_ts[
                :n_dims_per_task
            ]
            combined_space_v[i * n_dims_per_task : (i + 1) * n_dims_per_task, :] = v_ts[
                :n_dims_per_task, :
            ]
        ###################################################################

        combined_space_u[
            :,
            len(task_vectors) * n_dims_per_task : len(task_vectors) * n_dims_per_task
            + common_space_index_s,
        ] = common_space_u
        combined_space_s[
            len(task_vectors) * n_dims_per_task : len(task_vectors) * n_dims_per_task
            + common_space_index_s
        ] = common_space_s
        combined_space_v[
            len(task_vectors) * n_dims_per_task : len(task_vectors) * n_dims_per_task
            + common_space_index_s,
            :,
        ] = common_space_v

        ### Orthogonalize combined_space_u and combined_space_v ###
        u_combined_space_u, s_combined_space_u, v_combined_space_u = torch.linalg.svd(
            combined_space_u, full_matrices=False
        )
        u_combined_space_v, s_combined_space_v, v_combined_space_v = torch.linalg.svd(
            combined_space_v, full_matrices=False
        )
        combined_space_u = u_combined_space_u @ v_combined_space_u
        combined_space_v = u_combined_space_v @ v_combined_space_v
        ###################################################################

        combined_space_s = torch.ones_like(combined_space_s) * combined_space_s.mean()

        new_vector[key] = torch.linalg.multi_dot(
            (
                combined_space_u,
                torch.diag(combined_space_s),
                combined_space_v,
            )
        )
        new_vector[key] = new_vector[key].to(device=original_device, non_blocking=True)

    return new_vector


def check_parameterNamesMatch(checkpoints):
    parameter_names = set(checkpoints[0].keys())

    if len(checkpoints) >= 2:
        # raise ValueError("Number of models is less than 2.")
        for checkpoint in checkpoints[1:]:
            current_parameterNames = set(checkpoint.keys())
            if current_parameterNames != parameter_names:
                raise ValueError(
                    "Differing parameter names in models. "
                    f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
                )
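For orientation, a minimal sketch of calling `iso_c` from the file above on toy task vectors (the shapes and the "encoder.*" key names are made up for illustration): every 2D matrix in the merged result ends up with a flat, isotropic singular-value spectrum, which is the point of the ISO-C construction.

# Hypothetical toy example; only iso_c itself comes from the package.
import torch

from fusion_bench.method.isotropic_merging.iso_utils import iso_c

task_vectors = [
    {"encoder.weight": torch.randn(8, 8), "encoder.bias": torch.randn(8)}
    for _ in range(2)
]
merged = iso_c(task_vectors, accelerator="cpu")

# 2D entries are summed, SVD'd, and rebuilt with a uniform spectrum,
# so all singular values of the merged weight equal their mean.
S = torch.linalg.svdvals(merged["encoder.weight"])
print(torch.allclose(S, S.mean().expand_as(S), atol=1e-5))  # True

1D entries such as the bias are simply averaged and never touched by the SVD branch.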
fusion_bench/method/opcm/opcm.py (new file)
@@ -0,0 +1,277 @@
import os
import random
import time
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast

import lightning as L
import numpy as np
import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from tqdm.auto import tqdm
from transformers import CLIPVisionModel

from fusion_bench import BaseAlgorithm, BaseModelPool
from fusion_bench.mixins import LightningFabricMixin
from fusion_bench.taskpool import CLIPVisionModelTaskPool
from fusion_bench.utils import instantiate
from fusion_bench.utils.json import load_from_json, save_to_json
from fusion_bench.utils.parameters import state_dict_to_vector
from fusion_bench.utils.state_dict_arithmetic import state_dict_sub

from .utils import frobenius_inner_product, get_task_vector_norm, is_leaf_module, svd

if TYPE_CHECKING:
    from torch.utils.tensorboard import SummaryWriter


class OPCMForCLIP(
    BaseAlgorithm,
    LightningFabricMixin,
):
    def __init__(
        self,
        alpha: float,
        shuffle_order: bool = True,
        seed: Optional[int] = None,
        save_on_every_step: bool = True,
        evaluate_on_every_step: bool = False,
        **kwargs,
    ):
        """
        Continual Model Merging via SVD Projection.

        Args:
            alpha (float): the scaling factor for the SVD projection.
            shuffle_order (bool): whether to shuffle the order of the models.
            seed (Optional[int]): the seed to use.
            save_on_every_step (bool): whether to save the merged model on every step.
            evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
        """
        self.alpha = alpha
        self.shuffle_order = shuffle_order
        self.seed = seed
        self.save_on_every_step = save_on_every_step
        self.evaluate_on_every_step = evaluate_on_every_step
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        if self.seed is not None:
            L.seed_everything(self.seed)
        accelerator = self.fabric.device

        pretrained_model = modelpool.load_pretrained_model()

        model_names = modelpool.model_names
        if self.shuffle_order:
            random.shuffle(model_names)

        self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
        self._test_datasets = deepcopy(self.taskpool._test_datasets)
        """Configuration for the test datasets"""

        # log the model names
        if self.log_dir is not None:
            save_to_json(model_names, Path(self.log_dir) / "model_names.json")
            tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
            tensorboard_summarywriter.add_text(
                "global/model_names", str(model_names), global_step=0
            )

        # initialize the merged model with the first task model
        merged_model = modelpool.load_model(model_names[0])

        if self.evaluate_on_every_step:
            self.taskpool._is_setup = False
            self.taskpool._test_datasets = DictConfig(
                {model_names[0]: self._test_datasets[model_names[0]]}
            )
            report = self.taskpool.evaluate(deepcopy(merged_model))
            save_to_json(report, Path(self.log_dir) / "report_0.json")

        self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
        self.all_task_vector_norm = [self.avg_task_vector_norm]
        self.fabric.log("model/task_vector_norm", self.avg_task_vector_norm, step=0)
        self.fabric.log("model/avg_task_vector_norm", self.avg_task_vector_norm, step=0)
        self.fabric.log(
            "model/merged_task_vector_norm", self.avg_task_vector_norm, step=0
        )

        self.previous_lambda_t = 1
        self.lambda_t = None
        self.fabric.log("model/lambda_t", self.previous_lambda_t, step=0)
        self.fabric.log("empirical/lambda_t", 1, step=0)

        if self.save_on_every_step:
            self.save_merged_model(merged_model, 0)

        for model_idx, model_name in tqdm(
            enumerate(model_names[1:]), desc="Processing models"
        ):
            model_idx += 1
            task_model = modelpool.load_model(model_name)

            self.all_task_vector_norm.append(
                get_task_vector_norm(task_model, pretrained_model)
            )
            self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
            self.fabric.log(
                "model/task_vector_norm", self.all_task_vector_norm[-1], step=model_idx
            )
            self.fabric.log(
                "model/avg_task_vector_norm", self.avg_task_vector_norm, step=model_idx
            )

            self.lambda_t = 1  # temporary value

            for module_name, module in tqdm(
                list(merged_model.named_modules()),
                desc=f"Processing {model_name}",
                leave=False,
            ):
                if not is_leaf_module(module):
                    continue

                if isinstance(module, nn.Linear):
                    module.weight.data = self.merge_linear_weights(
                        module.weight,
                        pretrained_model.get_submodule(module_name).weight,
                        task_model.get_submodule(module_name).weight,
                        param_name=".".join([module_name, "weight"]),
                        alpha=self.alpha,
                        accelerator=accelerator,
                    )
                    if module.bias is not None:
                        module.bias.data = self.merge_other_parameters(
                            module.bias,
                            pretrained_model.get_submodule(module_name).bias,
                            task_model.get_submodule(module_name).bias,
                            param_name=".".join([module_name, "bias"]),
                            accelerator=accelerator,
                        )
                else:
                    for param_name, param in module.named_parameters():
                        param.data = self.merge_other_parameters(
                            merged_W=param,
                            pretrained_W=pretrained_model.get_submodule(
                                module_name
                            ).get_parameter(param_name),
                            task_W=task_model.get_submodule(module_name).get_parameter(
                                param_name
                            ),
                            param_name=".".join([module_name, param_name]),
                            accelerator=accelerator,
                        )

            task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
            self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
            for param_name, param in merged_model.named_parameters():
                param.data = pretrained_model.get_parameter(param_name) + (
                    param - pretrained_model.get_parameter(param_name)
                ) * (self.avg_task_vector_norm / task_vector_norm)
            self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
            self.fabric.log(
                "empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
            )
            self.previous_lambda_t = self.lambda_t
            self.lambda_t = None

            self.fabric.log(
                "model/merged_task_vector_norm",
                get_task_vector_norm(merged_model, pretrained_model),
                step=model_idx,
            )

            if self.save_on_every_step:
                self.save_merged_model(merged_model, model_idx)

            if self.evaluate_on_every_step:
                self.taskpool._is_setup = False
                self.taskpool._test_datasets = DictConfig(
                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
                )
                report = self.taskpool.evaluate(deepcopy(merged_model))
                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")

        return merged_model

    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
        merged_model.save_pretrained(
            Path(self.log_dir) / "checkpoints" / f"merged_model_{step}"
        )

    def merge_linear_weights(
        self,
        merged_W: Tensor,
        pretrained_W: Tensor,
        task_W: Tensor,
        param_name: str,
        alpha: float,
        accelerator: str = "cpu",
    ):
        original_device = merged_W.device
        merged_W = merged_W.to(accelerator)
        pretrained_W = pretrained_W.to(accelerator)
        task_W = task_W.to(accelerator)

        previous_merged_tv = merged_W - pretrained_W
        task_tv = task_W - pretrained_W

        u, s, v = svd(previous_merged_tv)
        rank = s.size(0)
        split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()

        projected_task_tv = u.T @ task_tv @ v
        projected_task_tv.diag().fill_(0)

        projected_task_tv[:split_rank, :split_rank] = 0

        cleaned_task_tv = u @ projected_task_tv @ v.T

        previous_lambda_t = self.previous_lambda_t
        lambda_t = self.lambda_t
        new_merged_W = (
            pretrained_W
            + (previous_lambda_t * previous_merged_tv + cleaned_task_tv) / lambda_t
        )
        return new_merged_W.to(original_device)

    def merge_other_parameters(
        self,
        merged_W: Tensor,
        pretrained_W: Tensor,
        task_W: Tensor,
        param_name: str,
        accelerator: str = "cpu",
    ):
        original_device = merged_W.device
        merged_W = merged_W.to(accelerator)
        pretrained_W = pretrained_W.to(accelerator)
        task_W = task_W.to(accelerator)

        previous_merged_tv = merged_W - pretrained_W
        task_tv = task_W - pretrained_W

        previous_lambda_t = self.previous_lambda_t
        lambda_t = self.lambda_t

        new_merged_W = (
            pretrained_W + (previous_lambda_t * previous_merged_tv + task_tv) / lambda_t
        )
        return new_merged_W.to(original_device)

    def compute_lambda_t(
        self, previous_merged_tv: Tensor, task_tv: Tensor, previous_lambda_t: float
    ):
        previous_merged_tv = torch.flatten(previous_merged_tv)
        task_tv = torch.flatten(task_tv)

        lambda_t = torch.linalg.vector_norm(
            previous_lambda_t * previous_merged_tv + task_tv
        ) / torch.linalg.vector_norm(previous_merged_tv)
        return lambda_t.item()
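The heart of `merge_linear_weights` above is a spectral-projection step. Below is a self-contained sketch of just that step, using `torch.linalg.svd` directly in place of the module's `svd` helper (assumed here to return `U, S, V` with `W == U @ diag(S) @ V.T`); `alpha` plays the same role as in `OPCMForCLIP`, setting how much of the previous merge's spectral energy is shielded from the incoming task vector.

# Standalone numeric sketch (random toy matrices, not the package's svd helper).
import torch

alpha = 0.5
previous_merged_tv = torch.randn(16, 16)
task_tv = torch.randn(16, 16)

# SVD basis of the previously merged task vector
u, s, vh = torch.linalg.svd(previous_merged_tv, full_matrices=False)
v = vh.T
# first index where the cumulative spectral energy exceeds alpha
split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()

# express the new task vector in that basis and zero the block that
# overlaps the dominant (protected) subspace of the previous merge
projected = u.T @ task_tv @ v
projected[:split_rank, :split_rank] = 0
cleaned_task_tv = u @ projected @ v.T

# the cleaned update has no component left in the protected block
overlap = u[:, :split_rank].T @ cleaned_task_tv @ v[:, :split_rank]
print(overlap.abs().max())  # ~0 up to floating-point error

The cleaned task vector is then folded into the running merge as `pretrained + (previous_lambda_t * previous_merged_tv + cleaned_task_tv) / lambda_t`, with `lambda_t` chosen in `run` so the merged task vector's norm stays near the average per-task norm.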
fusion_bench/method/opcm/task_arithmetic.py (new file)
@@ -0,0 +1,115 @@
import os
import random
import time
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, cast

import lightning as L
import numpy as np
import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from tqdm.auto import tqdm
from transformers import CLIPVisionModel

from fusion_bench import BaseAlgorithm, BaseModelPool
from fusion_bench.mixins import LightningFabricMixin
from fusion_bench.taskpool import CLIPVisionModelTaskPool
from fusion_bench.utils.json import load_from_json, save_to_json
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub

if TYPE_CHECKING:
    from torch.utils.tensorboard import SummaryWriter


class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
    def __init__(
        self,
        scaling_factor: float,
        shuffle_order: bool = True,
        seed: Optional[int] = None,
        save_on_every_step: bool = True,
        evaluate_on_every_step: bool = False,
        **kwargs,
    ):
        """
        Continual Model Merging via Task Arithmetic.

        Args:
            scaling_factor (float): the scaling factor to use.
            shuffle_order (bool): whether to shuffle the order of the models.
            seed (Optional[int]): the seed to use.
            save_on_every_step (bool): whether to save the merged model on every step.
            evaluate_on_every_step (bool): whether to evaluate the merged model on every step.
        """
        self.scaling_factor = scaling_factor
        self.shuffle_order = shuffle_order
        self.seed = seed
        self.save_on_every_step = save_on_every_step
        self.evaluate_on_every_step = evaluate_on_every_step
        super().__init__(**kwargs)

    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        if self.seed is not None:
            L.seed_everything(self.seed)

        model_names = modelpool.model_names
        if self.shuffle_order:
            random.shuffle(model_names)

        self.taskpool = cast(CLIPVisionModelTaskPool, self._program.taskpool)
        self._test_datasets = deepcopy(self.taskpool._test_datasets)
        """Configuration for the test datasets"""

        # log the model names
        if self.log_dir is not None:
            save_to_json(model_names, Path(self.log_dir) / "model_names.json")
            tensorboard_summarywriter: "SummaryWriter" = self.tensorboard_summarywriter
            tensorboard_summarywriter.add_text(
                "global/model_names", str(model_names), global_step=0
            )

        # initialize the merged model from the pretrained model
        pretrained_model = modelpool.load_pretrained_model()
        merged_model = deepcopy(pretrained_model)

        for model_idx, model_name in tqdm(
            enumerate(model_names), desc="Processing models"
        ):
            task_model = modelpool.load_model(model_name)

            for param_name, param in task_model.named_parameters():
                if not param.requires_grad:
                    continue

                task_param = param
                merged_param = merged_model.get_parameter(param_name)
                pretrained_param = pretrained_model.get_parameter(param_name)

                new_param = merged_param + self.scaling_factor * (
                    task_param - pretrained_param
                )
                merged_model.get_parameter(param_name).data = new_param

            if self.save_on_every_step:
                self.save_merged_model(merged_model, model_idx)

            if self.evaluate_on_every_step:
                self.taskpool._is_setup = False
                self.taskpool._test_datasets = DictConfig(
                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
                )
                report = self.taskpool.evaluate(deepcopy(merged_model))
                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")

        return merged_model

    def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
        os.makedirs(Path(self.log_dir) / "checkpoints", exist_ok=True)
        torch.save(
            merged_model.state_dict(),
            Path(self.log_dir) / "checkpoints" / f"model_{step}.pth",
        )