fusion-bench 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +5 -0
- fusion_bench/dataset/fer2013.py +0 -1
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/concrete_subspace/__init__.py +8 -0
- fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
- fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/doge_ta/doge_ta.py +364 -0
- fusion_bench/method/doge_ta/layer_wise_adamerging.py +250 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/isotropic_merging/iso.py +2 -2
- fusion_bench/method/opcm/opcm.py +93 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/task_singular_vector/TSVM.py +3 -3
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +416 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/RECORD +32 -19
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
- fusion_bench_config/method/doge_ta/doge_ta.yaml +4 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -8
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +68 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/top_level.txt +0 -0
fusion_bench/method/opcm/opcm.py
CHANGED
|
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
|
|
|
15
15
|
from transformers import CLIPVisionModel
|
|
16
16
|
|
|
17
17
|
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
18
|
-
from fusion_bench.mixins import LightningFabricMixin
|
|
18
|
+
from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
|
|
19
19
|
from fusion_bench.taskpool import CLIPVisionModelTaskPool
|
|
20
20
|
from fusion_bench.utils import instantiate
|
|
21
21
|
from fusion_bench.utils.json import load_from_json, save_to_json
|
|
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
|
|
|
31
31
|
class OPCMForCLIP(
|
|
32
32
|
BaseAlgorithm,
|
|
33
33
|
LightningFabricMixin,
|
|
34
|
+
SimpleProfilerMixin,
|
|
34
35
|
):
|
|
35
36
|
def __init__(
|
|
36
37
|
self,
|
|
@@ -64,7 +65,8 @@ class OPCMForCLIP(
|
|
|
64
65
|
L.seed_everything(self.seed)
|
|
65
66
|
accelerator = self.fabric.device
|
|
66
67
|
|
|
67
|
-
|
|
68
|
+
with self.profile("loading model"):
|
|
69
|
+
pretrained_model = modelpool.load_pretrained_model()
|
|
68
70
|
|
|
69
71
|
model_names = modelpool.model_names
|
|
70
72
|
if self.shuffle_order:
|
|
@@ -83,15 +85,17 @@ class OPCMForCLIP(
|
|
|
83
85
|
)
|
|
84
86
|
|
|
85
87
|
# get the average model
|
|
86
|
-
|
|
88
|
+
with self.profile("loading model"):
|
|
89
|
+
merged_model = modelpool.load_model(model_names[0])
|
|
87
90
|
|
|
88
91
|
if self.evaluate_on_every_step:
|
|
89
|
-
self.
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
92
|
+
with self.profile("evaluating model"):
|
|
93
|
+
self.taskpool._is_setup = False
|
|
94
|
+
self.taskpool._test_datasets = DictConfig(
|
|
95
|
+
{model_names[0]: self._test_datasets[model_names[0]]}
|
|
96
|
+
)
|
|
97
|
+
report = self.taskpool.evaluate(deepcopy(merged_model))
|
|
98
|
+
save_to_json(report, Path(self.log_dir) / "report_0.json")
|
|
95
99
|
|
|
96
100
|
self.avg_task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
|
|
97
101
|
self.all_task_vector_norm = [self.avg_task_vector_norm]
|
|
@@ -113,90 +117,95 @@ class OPCMForCLIP(
|
|
|
113
117
|
enumerate(model_names[1:]), desc="Processing models"
|
|
114
118
|
):
|
|
115
119
|
model_idx += 1
|
|
116
|
-
|
|
120
|
+
with self.profile("loading model"):
|
|
121
|
+
task_model = modelpool.load_model(model_name)
|
|
117
122
|
|
|
118
|
-
self.
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
123
|
+
with self.profile("merging model"):
|
|
124
|
+
self.all_task_vector_norm.append(
|
|
125
|
+
get_task_vector_norm(task_model, pretrained_model)
|
|
126
|
+
)
|
|
127
|
+
self.avg_task_vector_norm = np.mean(self.all_task_vector_norm)
|
|
128
|
+
self.fabric.log(
|
|
129
|
+
"model/task_vector_norm", self.all_task_vector_norm[-1], step=model_idx
|
|
130
|
+
)
|
|
131
|
+
self.fabric.log(
|
|
132
|
+
"model/avg_task_vector_norm", self.avg_task_vector_norm, step=model_idx
|
|
133
|
+
)
|
|
128
134
|
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
accelerator=accelerator,
|
|
147
|
-
)
|
|
148
|
-
if module.bias is not None:
|
|
149
|
-
module.bias.data = self.merge_other_parameters(
|
|
150
|
-
module.bias,
|
|
151
|
-
pretrained_model.get_submodule(module_name).bias,
|
|
152
|
-
task_model.get_submodule(module_name).bias,
|
|
153
|
-
param_name=".".join([module_name, "bias"]),
|
|
135
|
+
self.lambda_t = 1 # temporary value
|
|
136
|
+
|
|
137
|
+
for module_name, module in tqdm(
|
|
138
|
+
list(merged_model.named_modules()),
|
|
139
|
+
desc=f"Processing {model_name}",
|
|
140
|
+
leave=False,
|
|
141
|
+
):
|
|
142
|
+
if not is_leaf_module(module):
|
|
143
|
+
continue
|
|
144
|
+
|
|
145
|
+
if isinstance(module, nn.Linear):
|
|
146
|
+
module.weight.data = self.merge_linear_weights(
|
|
147
|
+
module.weight,
|
|
148
|
+
pretrained_model.get_submodule(module_name).weight,
|
|
149
|
+
task_model.get_submodule(module_name).weight,
|
|
150
|
+
param_name=".".join([module_name, "weight"]),
|
|
151
|
+
alpha=self.alpha,
|
|
154
152
|
accelerator=accelerator,
|
|
155
153
|
)
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
module_name
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
154
|
+
if module.bias is not None:
|
|
155
|
+
module.bias.data = self.merge_other_parameters(
|
|
156
|
+
module.bias,
|
|
157
|
+
pretrained_model.get_submodule(module_name).bias,
|
|
158
|
+
task_model.get_submodule(module_name).bias,
|
|
159
|
+
param_name=".".join([module_name, "bias"]),
|
|
160
|
+
accelerator=accelerator,
|
|
161
|
+
)
|
|
162
|
+
else:
|
|
163
|
+
for param_name, param in module.named_parameters():
|
|
164
|
+
param.data = self.merge_other_parameters(
|
|
165
|
+
merged_W=param,
|
|
166
|
+
pretrained_W=pretrained_model.get_submodule(
|
|
167
|
+
module_name
|
|
168
|
+
).get_parameter(param_name),
|
|
169
|
+
task_W=task_model.get_submodule(module_name).get_parameter(
|
|
170
|
+
param_name
|
|
171
|
+
),
|
|
172
|
+
param_name=".".join([module_name, param_name]),
|
|
173
|
+
accelerator=accelerator,
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
task_vector_norm = get_task_vector_norm(merged_model, pretrained_model)
|
|
177
|
+
self.lambda_t *= task_vector_norm / self.avg_task_vector_norm
|
|
178
|
+
for param_name, param in merged_model.named_parameters():
|
|
179
|
+
param.data = pretrained_model.get_parameter(param_name) + (
|
|
180
|
+
param - pretrained_model.get_parameter(param_name)
|
|
181
|
+
) * (self.avg_task_vector_norm / task_vector_norm)
|
|
182
|
+
self.fabric.log("model/lambda_t", self.lambda_t, step=model_idx)
|
|
183
|
+
self.fabric.log(
|
|
184
|
+
"empirical/lambda_t", np.sqrt(model_idx + 1), step=model_idx
|
|
185
|
+
)
|
|
186
|
+
self.previous_lambda_t = self.lambda_t
|
|
187
|
+
self.lambda_t = None
|
|
182
188
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
189
|
+
self.fabric.log(
|
|
190
|
+
"model/merged_task_vector_norm",
|
|
191
|
+
get_task_vector_norm(merged_model, pretrained_model),
|
|
192
|
+
step=model_idx,
|
|
193
|
+
)
|
|
188
194
|
|
|
189
195
|
if self.save_on_every_step:
|
|
190
|
-
self.
|
|
196
|
+
with self.profile("saving model"):
|
|
197
|
+
self.save_merged_model(merged_model, model_idx)
|
|
191
198
|
|
|
192
199
|
if self.evaluate_on_every_step:
|
|
193
|
-
self.
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
200
|
+
with self.profile("evaluating model"):
|
|
201
|
+
self.taskpool._is_setup = False
|
|
202
|
+
self.taskpool._test_datasets = DictConfig(
|
|
203
|
+
{n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
|
|
204
|
+
)
|
|
205
|
+
report = self.taskpool.evaluate(deepcopy(merged_model))
|
|
206
|
+
save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
|
|
199
207
|
|
|
208
|
+
self.print_profile_summary()
|
|
200
209
|
return merged_model
|
|
201
210
|
|
|
202
211
|
def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
|
|
@@ -227,7 +236,7 @@ class OPCMForCLIP(
|
|
|
227
236
|
split_rank = (s.cumsum(dim=0) / s.sum() > alpha).float().argmax().item()
|
|
228
237
|
|
|
229
238
|
projected_task_tv = u.T @ task_tv @ v
|
|
230
|
-
projected_task_tv.
|
|
239
|
+
projected_task_tv.diagonal().fill_(0)
|
|
231
240
|
|
|
232
241
|
projected_task_tv[:split_rank, :split_rank] = 0
|
|
233
242
|
|
|
@@ -15,7 +15,7 @@ from tqdm.auto import tqdm
|
|
|
15
15
|
from transformers import CLIPVisionModel
|
|
16
16
|
|
|
17
17
|
from fusion_bench import BaseAlgorithm, BaseModelPool
|
|
18
|
-
from fusion_bench.mixins import LightningFabricMixin
|
|
18
|
+
from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
|
|
19
19
|
from fusion_bench.taskpool import CLIPVisionModelTaskPool
|
|
20
20
|
from fusion_bench.utils.json import load_from_json, save_to_json
|
|
21
21
|
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
|
|
@@ -24,7 +24,11 @@ if TYPE_CHECKING:
|
|
|
24
24
|
from torch.utils.tensorboard import SummaryWriter
|
|
25
25
|
|
|
26
26
|
|
|
27
|
-
class ContinualTaskArithmeticForCLIP(
|
|
27
|
+
class ContinualTaskArithmeticForCLIP(
|
|
28
|
+
BaseAlgorithm,
|
|
29
|
+
LightningFabricMixin,
|
|
30
|
+
SimpleProfilerMixin,
|
|
31
|
+
):
|
|
28
32
|
def __init__(
|
|
29
33
|
self,
|
|
30
34
|
scaling_factor: float,
|
|
@@ -79,32 +83,42 @@ class ContinualTaskArithmeticForCLIP(BaseAlgorithm, LightningFabricMixin):
|
|
|
79
83
|
for model_idx, model_name in tqdm(
|
|
80
84
|
enumerate(model_names), desc="Processing models"
|
|
81
85
|
):
|
|
82
|
-
|
|
86
|
+
with self.profile("loading model"):
|
|
87
|
+
task_model = modelpool.load_model(model_name)
|
|
83
88
|
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
89
|
+
with self.profile("merging model"):
|
|
90
|
+
for param_name, param in task_model.named_parameters():
|
|
91
|
+
if not param.requires_grad:
|
|
92
|
+
continue
|
|
87
93
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
94
|
+
task_param = param
|
|
95
|
+
merged_param = merged_model.get_parameter(param_name)
|
|
96
|
+
pretrained_param = pretrained_model.get_parameter(param_name)
|
|
91
97
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
98
|
+
new_param = merged_param + self.scaling_factor * (
|
|
99
|
+
task_param - pretrained_param
|
|
100
|
+
)
|
|
101
|
+
merged_model.get_parameter(param_name).data = new_param
|
|
96
102
|
|
|
97
103
|
if self.save_on_every_step:
|
|
98
|
-
self.
|
|
104
|
+
with self.profile("saving model"):
|
|
105
|
+
self.save_merged_model(merged_model, model_idx)
|
|
99
106
|
|
|
100
107
|
if self.evaluate_on_every_step:
|
|
101
|
-
self.
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
+
with self.profile("evaluating model"):
|
|
109
|
+
self.taskpool._is_setup = False
|
|
110
|
+
self.taskpool._test_datasets = DictConfig(
|
|
111
|
+
{
|
|
112
|
+
n: self._test_datasets[n]
|
|
113
|
+
for n in model_names[: model_idx + 1]
|
|
114
|
+
}
|
|
115
|
+
)
|
|
116
|
+
report = self.taskpool.evaluate(deepcopy(merged_model))
|
|
117
|
+
save_to_json(
|
|
118
|
+
report, Path(self.log_dir) / f"report_{model_idx}.json"
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
self.print_profile_summary()
|
|
108
122
|
return merged_model
|
|
109
123
|
|
|
110
124
|
def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
|
|
@@ -20,7 +20,7 @@ from fusion_bench.method.ties_merging.ties_merging_utils import (
|
|
|
20
20
|
ties_merging,
|
|
21
21
|
vector_to_state_dict,
|
|
22
22
|
)
|
|
23
|
-
from fusion_bench.mixins import LightningFabricMixin
|
|
23
|
+
from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
|
|
24
24
|
from fusion_bench.taskpool import CLIPVisionModelTaskPool
|
|
25
25
|
from fusion_bench.utils.json import load_from_json, save_to_json
|
|
26
26
|
from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
|
|
@@ -29,7 +29,11 @@ if TYPE_CHECKING:
|
|
|
29
29
|
from torch.utils.tensorboard import SummaryWriter
|
|
30
30
|
|
|
31
31
|
|
|
32
|
-
class ContinualTiesMergingForCLIP(
|
|
32
|
+
class ContinualTiesMergingForCLIP(
|
|
33
|
+
BaseAlgorithm,
|
|
34
|
+
LightningFabricMixin,
|
|
35
|
+
SimpleProfilerMixin,
|
|
36
|
+
):
|
|
33
37
|
def __init__(
|
|
34
38
|
self,
|
|
35
39
|
scaling_factor: float,
|
|
@@ -84,68 +88,83 @@ class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
|
|
|
84
88
|
)
|
|
85
89
|
|
|
86
90
|
# get the average model
|
|
87
|
-
|
|
91
|
+
with self.profile("loading model"):
|
|
92
|
+
pretrained_model = modelpool.load_pretrained_model()
|
|
88
93
|
merged_model = deepcopy(pretrained_model)
|
|
89
94
|
|
|
90
95
|
for model_idx, model_name in tqdm(
|
|
91
96
|
enumerate(model_names), desc="Processing models"
|
|
92
97
|
):
|
|
93
|
-
|
|
98
|
+
with self.profile("loading model"):
|
|
99
|
+
task_model = modelpool.load_model(model_name)
|
|
94
100
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
)
|
|
99
|
-
if model_idx == 0:
|
|
100
|
-
# if is the first model, the merged task vector is equal to the task vector
|
|
101
|
-
ties_merging_state_dict = task_vector
|
|
102
|
-
else:
|
|
103
|
-
# if is not the first model, we need to merge the task vector with the previous merged task vector
|
|
104
|
-
merged_tv = state_dict_sub(
|
|
105
|
-
merged_model.state_dict(),
|
|
101
|
+
with self.profile("merging model"):
|
|
102
|
+
task_vector = state_dict_sub(
|
|
103
|
+
task_model.state_dict(),
|
|
106
104
|
pretrained_model.state_dict(),
|
|
107
105
|
)
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
106
|
+
if model_idx == 0:
|
|
107
|
+
# if is the first model, the merged task vector is equal to the task vector
|
|
108
|
+
ties_merging_state_dict = task_vector
|
|
109
|
+
else:
|
|
110
|
+
# if is not the first model, we need to merge the task vector with the previous merged task vector
|
|
111
|
+
merged_tv = state_dict_sub(
|
|
112
|
+
merged_model.state_dict(),
|
|
113
|
+
pretrained_model.state_dict(),
|
|
114
|
+
)
|
|
115
|
+
tv_flat_checks = torch.vstack(
|
|
116
|
+
[
|
|
117
|
+
state_dict_to_vector(
|
|
118
|
+
merged_tv, remove_keys=self.remove_keys
|
|
119
|
+
),
|
|
120
|
+
state_dict_to_vector(
|
|
121
|
+
task_vector, remove_keys=self.remove_keys
|
|
122
|
+
),
|
|
123
|
+
]
|
|
124
|
+
)
|
|
125
|
+
# perform the TIES merging
|
|
126
|
+
ties_merging_tv = ties_merging(
|
|
127
|
+
tv_flat_checks,
|
|
128
|
+
reset_thresh=self.threshold,
|
|
129
|
+
merge_func=self.merge_func,
|
|
130
|
+
)
|
|
131
|
+
# convert the merged task vector back to a state dict
|
|
132
|
+
ties_merging_state_dict = vector_to_state_dict(
|
|
133
|
+
ties_merging_tv,
|
|
134
|
+
merged_model.state_dict(),
|
|
135
|
+
remove_keys=self.remove_keys,
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
for param_name, param in task_model.named_parameters():
|
|
139
|
+
if not param.requires_grad:
|
|
140
|
+
continue
|
|
141
|
+
|
|
142
|
+
merged_param = merged_model.get_parameter(param_name)
|
|
143
|
+
new_param = (
|
|
144
|
+
merged_param
|
|
145
|
+
+ self.scaling_factor * ties_merging_state_dict[param_name]
|
|
146
|
+
)
|
|
147
|
+
merged_model.get_parameter(param_name).data = new_param
|
|
137
148
|
|
|
138
149
|
if self.save_on_every_step:
|
|
139
|
-
self.
|
|
150
|
+
with self.profile("saving model"):
|
|
151
|
+
self.save_merged_model(merged_model, model_idx)
|
|
140
152
|
|
|
141
153
|
if self.evaluate_on_every_step:
|
|
142
|
-
self.
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
154
|
+
with self.profile("evaluating model"):
|
|
155
|
+
self.taskpool._is_setup = False
|
|
156
|
+
self.taskpool._test_datasets = DictConfig(
|
|
157
|
+
{
|
|
158
|
+
n: self._test_datasets[n]
|
|
159
|
+
for n in model_names[: model_idx + 1]
|
|
160
|
+
}
|
|
161
|
+
)
|
|
162
|
+
report = self.taskpool.evaluate(deepcopy(merged_model))
|
|
163
|
+
save_to_json(
|
|
164
|
+
report, Path(self.log_dir) / f"report_{model_idx}.json"
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
self.print_profile_summary()
|
|
149
168
|
return merged_model
|
|
150
169
|
|
|
151
170
|
def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
|
|
@@ -9,19 +9,19 @@ fusion_bench \
|
|
|
9
9
|
```
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
from typing import List, Optional, Union
|
|
12
|
+
from typing import Iterable, List, Optional, Union
|
|
13
13
|
|
|
14
14
|
import torch
|
|
15
|
-
from torch import Tensor, nn
|
|
16
15
|
from omegaconf import ListConfig
|
|
16
|
+
from torch import Tensor, nn
|
|
17
17
|
|
|
18
18
|
from fusion_bench import BaseAlgorithm
|
|
19
19
|
from fusion_bench.mixins import LightningFabricMixin
|
|
20
20
|
from fusion_bench.utils import timeit_context
|
|
21
21
|
from fusion_bench.utils.state_dict_arithmetic import (
|
|
22
22
|
state_dict_add,
|
|
23
|
-
state_dict_sub,
|
|
24
23
|
state_dict_mul,
|
|
24
|
+
state_dict_sub,
|
|
25
25
|
)
|
|
26
26
|
from fusion_bench.utils.type import StateDictType
|
|
27
27
|
|
|
@@ -16,6 +16,7 @@ import torch
|
|
|
16
16
|
from torch import Tensor, nn
|
|
17
17
|
from torch.func import functional_call
|
|
18
18
|
|
|
19
|
+
from fusion_bench.models.utils import del_attr, get_attr, set_attr
|
|
19
20
|
from fusion_bench.utils.type import StateDictType, TorchModelType
|
|
20
21
|
|
|
21
22
|
__all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
|
|
@@ -23,52 +24,6 @@ __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
|
|
|
23
24
|
log = logging.getLogger(__name__)
|
|
24
25
|
|
|
25
26
|
|
|
26
|
-
def del_attr(obj, names: List[str]):
|
|
27
|
-
"""
|
|
28
|
-
Deletes an attribute from an object recursively.
|
|
29
|
-
|
|
30
|
-
Args:
|
|
31
|
-
obj (object): Object to delete attribute from.
|
|
32
|
-
names (list): List of attribute names to delete recursively.
|
|
33
|
-
"""
|
|
34
|
-
if len(names) == 1:
|
|
35
|
-
delattr(obj, names[0])
|
|
36
|
-
else:
|
|
37
|
-
del_attr(getattr(obj, names[0]), names[1:])
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def set_attr(obj, names: List[str], val):
|
|
41
|
-
"""
|
|
42
|
-
Sets an attribute of an object recursively.
|
|
43
|
-
|
|
44
|
-
Args:
|
|
45
|
-
obj (object): Object to set attribute of.
|
|
46
|
-
names (list): List of attribute names to set recursively.
|
|
47
|
-
val (object): Value to set the attribute to.
|
|
48
|
-
"""
|
|
49
|
-
if len(names) == 1:
|
|
50
|
-
setattr(obj, names[0], val)
|
|
51
|
-
else:
|
|
52
|
-
set_attr(getattr(obj, names[0]), names[1:], val)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def get_attr(obj, names: List[str]):
|
|
56
|
-
"""
|
|
57
|
-
Gets an attribute of an object recursively.
|
|
58
|
-
|
|
59
|
-
Args:
|
|
60
|
-
obj (object): Object to get attribute of.
|
|
61
|
-
names (list): List of attribute names to get recursively.
|
|
62
|
-
|
|
63
|
-
Returns:
|
|
64
|
-
object: The attribute of the object.
|
|
65
|
-
"""
|
|
66
|
-
if len(names) == 1:
|
|
67
|
-
return getattr(obj, names[0])
|
|
68
|
-
else:
|
|
69
|
-
return get_attr(getattr(obj, names[0]), names[1:])
|
|
70
|
-
|
|
71
|
-
|
|
72
27
|
def get_layer_wise_weights(
|
|
73
28
|
num_models: int,
|
|
74
29
|
num_layers: int,
|