fusion-bench 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +5 -0
- fusion_bench/dataset/fer2013.py +0 -1
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/concrete_subspace/__init__.py +8 -0
- fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
- fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/doge_ta/doge_ta.py +364 -0
- fusion_bench/method/doge_ta/layer_wise_adamerging.py +250 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/isotropic_merging/iso.py +2 -2
- fusion_bench/method/opcm/opcm.py +93 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/task_singular_vector/TSVM.py +3 -3
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +416 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/RECORD +32 -19
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
- fusion_bench_config/method/doge_ta/doge_ta.yaml +4 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -8
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +68 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/top_level.txt +0 -0
fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py
@@ -0,0 +1,832 @@
"""
Defense-Aware Task-wise & Layer-wise Concrete AdaMerging for CLIP ViT models

Examples:

```bash
fusion_bench \
    fabric_logger.name= \
    method=clip_safe_concrete_task_wise_adamerging \
    modelpool= \
    taskpool=
```

```bash
fusion_bench \
    fabric_logger.name= \
    method=clip_safe_concrete_layer_wise_adamerging \
    modelpool= \
    taskpool=
```
"""

import logging
import os
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.autonotebook import tqdm

from fusion_bench.compat.method import ModelFusionAlgorithm
from fusion_bench.compat.modelpool import to_modelpool
from fusion_bench.compat.modelpool.huggingface_clip_vision import (
    HuggingFaceClipVisionPool,
)
from fusion_bench.method.adamerging.entropy_loss import entropy_loss
from fusion_bench.mixins import CLIPClassificationMixin
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
from fusion_bench.models.masks import MaskModel, mask_sparsity
from fusion_bench.models.wrappers.layer_wise_fusion import (
    LayerWiseMergedModel,
    get_layer_wise_weights,
)
from fusion_bench.models.wrappers.task_wise_fusion import (
    TaskWiseMergedModel,
    get_task_wise_weights,
)
from fusion_bench.utils.dtype import parse_dtype
from fusion_bench.utils.parameters import print_parameters

log = logging.getLogger(__name__)

class ConcreteSafeTaskWiseAdaMergingForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    @torch.no_grad()
    def setup_models(self):
        config = self.config
        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
        modelpool = self.modelpool
        # Load the pretrained model
        pretrained_model = modelpool.load_model("_pretrained_")

        # construct PGE mask model
        mask_model = MaskModel(
            pretrained_model,
            ignore_untrained_params=True,
            parameter_type="logits",
        )
        if self.merge_dtype is not None:
            mask_model.to(self.merge_dtype)
        mask_model.fill_(self.config.initial_logits)
        # TODO: ablation study for the initialization of mask model
        # for param in mask_model.parameters():
        #     param.data = param + 0.1 * torch.randn_like(param)
        print("Summary of mask model:")
        print_parameters(mask_model)

        # Load the fine-tuned models
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        task_wise_weight = get_task_wise_weights(
            num_models=len(modelpool.model_names),
            init_values=self.config.scaling_factor,
        )
        self.init_task_wise_weight = deepcopy(task_wise_weight)

        # create a wrapped model
        module = TaskWiseMergedModel(
            task_wise_weight=task_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
            task_vector_dtype=self.merge_dtype,
        )

        self.pertubed_model = nn.Module()
        self.pertubed_model.perturbed_input = nn.Parameter(
            torch.zeros([len(modelpool.model_names), 3, 224, 224]), requires_grad=True
        )
        return module, mask_model

    def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
        config = self.config
        self.init_task_wise_weight = self.to_device(self.init_task_wise_weight)

        # configure optimizer
        lr_scheduler = None
        if self.config.optimizer == "adam":

            ### for merge_weight
            base_optimizer = torch.optim.Adam(
                [module.merge_weight], lr=self.config.base_lr
            )
            module, base_optimizer = self.fabric.setup(module, base_optimizer)

            ### for mask
            optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)

            ### for perturbed noise
            batch_opt_adv = torch.optim.Adam(
                params=self.pertubed_model.parameters(), lr=self.config.adv_lr
            )
            self.pertubed_model, batch_opt_adv = self.fabric.setup(
                self.pertubed_model, batch_opt_adv
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        mask_model.train()
        self.pertubed_model.train()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete Safe AdaMerging Meta-Learn Mask (1/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # sample a shared mask and merge weights
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                # rescale mask
                for name, m in mask.items():
                    mask[name] = m / torch.mean(m)

                # for inner optimization, we do not optimize the mask, so we detach it
                module.merge_weights(
                    task_vector_mask={name: m.detach() for name, m in mask.items()}
                )

            # ------ inner optimization goes here ------

            ### (1) optimize the merging weight
            module.merge_weight.data = deepcopy(self.init_task_wise_weight)
            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]

                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                base_optimizer.step()
                base_optimizer.zero_grad()

            with self.profile("merge weights"):
                module.merge_weights(task_vector_mask=mask)

            ### (2) noise optimization based on the merged model

            # detach merged state_dict
            merged_state_dict = module._merged_state_dict
            detached_merged_state_dict = {
                k: p.detach() for k, p in merged_state_dict.items()
            }
            module._merged_state_dict = detached_merged_state_dict

            total_loss = None
            for task_idx, task in enumerate(self.modelpool.model_names):
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                    perturbed_images = (
                        images + self.pertubed_model.perturbed_input[task_idx]
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    combined_logits = self.compute_logits(module, combined_images, task)
                    logits = combined_logits[: images.size(0)]
                    logits_adv = combined_logits[images.size(0) :]
                    ori_label = torch.argmax(logits, axis=1).long()
                    loss = torch.mean(
                        -F.cross_entropy(logits_adv, ori_label, reduction="mean")
                    )
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("batch_opt_adv optimizer step"):
                batch_opt_adv.step()
                batch_opt_adv.zero_grad()

            ### (3) mask optimization
            total_loss = None
            module._merged_state_dict = merged_state_dict

            for task_idx, task in enumerate(self.modelpool.model_names):
                with self.profile("data loading"), torch.no_grad():
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                    perturbed_images = (
                        images + self.pertubed_model.perturbed_input[task_idx]
                    )
                    perturbed_images = torch.clamp(perturbed_images, min=0, max=1)
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    combined_logits = self.compute_logits(module, combined_images, task)
                    logits = combined_logits[: images.size(0)]
                    logits_adv = combined_logits[images.size(0) :]

                    # ### regu1
                    # ori_label = torch.argmax(logits, axis=1).long()
                    # loss_nat = entropy_loss(logits)
                    # loss_regu = torch.mean(-F.cross_entropy(logits_adv, ori_label, reduction='mean'))

                    ### regu2
                    loss_regu = entropy_loss(logits_adv)
                    loss_nat = entropy_loss(logits)

                    loss = loss_nat + self.config.adv_weight * loss_regu
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

                if lr_scheduler is not None:
                    lr_scheduler.step()

            # metrics.update({"train/loss": loss.item()})
            metrics.update(
                {
                    "train/loss": loss.item(),
                    "loss_nat": loss_nat.item(),
                    "loss_regu": loss_regu.item(),
                }
            )
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)
            self.print_profile_summary()

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a symbolic link to the latest checkpoint
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

                self.print_profile_summary()

    def run_adamerging(self, module: TaskWiseMergedModel, mask):
        module.merge_weight.data = deepcopy(self.init_task_wise_weight)
        base_optimizer = torch.optim.Adam(
            [module.merge_weight], lr=self.config.adamerging_lr
        )
        module, base_optimizer = self.fabric.setup(module, base_optimizer)
        module.train()
        for step_idx in (
            pbar := tqdm(
                range(
                    self.config.max_adamerging_steps if not self.is_debug_mode else 5
                ),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete AdaMerging AdaMerging (2/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            step_idx = step_idx + self.config.max_steps
            with self.profile("merge weights"):
                module.merge_weights(task_vector_mask=mask)

            metrics = {}
            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                base_optimizer.step()
                base_optimizer.zero_grad()

            metrics.update({"train/loss": loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"merge_weight_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"merge_weight": module.merge_weight}
                    self.fabric.save(save_path, state)

                    # Create or update a symbolic link to the latest checkpoint
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(
                            save_dir, "merge_weight_latest_checkpoint.pt"
                        )
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

                self.print_profile_summary()
        return module

    def run(self, modelpool: HuggingFaceClipVisionPool):
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")

        with self.profile("setup models"):
            module, mask_model = self.setup_models()
            mask_model: MaskModel = self.fabric.to_device(mask_model)
            module: TaskWiseMergedModel = self.fabric.to_device(module)
            self.pertubed_model = self.fabric.to_device(self.pertubed_model)
            self.setup_zero_shot_classification_head()

        if config.mask_checkpoint is None:
            self.train_mask(module=module, mask_model=mask_model)
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", config.mask_checkpoint)
            self.fabric.load(config.mask_checkpoint, {"model": mask_model})

        # run adamerging
        with torch.no_grad():
            mask = mask_model.sample_mask(
                mask_type=config.eval_mask_type,
                temperature=config.temperature,
            )
            # rescale mask
            for name, m in mask.items():
                mask[name] = m / torch.mean(m)
        module = self.run_adamerging(module, mask=mask)

        with torch.no_grad():
            model = module.merge_and_unload(mask)
            return model

class ConcreteSafeLayerWiseAdaMergingForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    @torch.no_grad()
    def setup_models(self):
        modelpool = self.modelpool
        self.merge_dtype = parse_dtype(self.config.get("merge_dtype", None))
        # Load the pretrained model
        pretrained_model = modelpool.load_model("_pretrained_")

        # construct PGE mask model
        mask_model = MaskModel(
            pretrained_model,
            ignore_untrained_params=True,
            parameter_type="logits",
        )
        if self.merge_dtype is not None:
            mask_model.to(self.merge_dtype)
        mask_model.fill_(self.config.initial_logits)
        # TODO: ablation study for the initialization of mask model
        # for param in mask_model.parameters():
        #     param.data = param + 0.1 * torch.randn_like(param)
        print("Summary of mask model:")
        print_parameters(mask_model)

        # Load the fine-tuned models
        finetuned_models = [
            modelpool.load_model(name) for name in modelpool.model_names
        ]

        layer_wise_weight = get_layer_wise_weights(
            num_models=len(modelpool.model_names),
            num_layers=len(
                tuple(filter(lambda p: p.requires_grad, pretrained_model.parameters()))
            ),
            init_values=self.config.scaling_factor,
        )
        self.init_layer_wise_weight = deepcopy(layer_wise_weight)

        # create a wrapped model
        module = LayerWiseMergedModel(
            layer_wise_weight=layer_wise_weight,
            pretrained_model=pretrained_model,
            finetuned_models=finetuned_models,
            clamp_weights=self.config.clamp_weights,
            tie_weights=self.config.tie_weights,
            strict=self.config.strict,
            layer_vector_dtype=self.merge_dtype,
        )

        self.pertubed_model = nn.Module()
        self.pertubed_model.perturbed_input = nn.Parameter(
            torch.zeros([len(modelpool.model_names), 3, 224, 224]), requires_grad=True
        )
        return module, mask_model

    def train_mask(self, module: LayerWiseMergedModel, mask_model: MaskModel):
        config = self.config
        self.init_layer_wise_weight = self.to_device(self.init_layer_wise_weight)

        # configure optimizer
        lr_scheduler = None
        if self.config.optimizer == "adam":
            base_optimizer = torch.optim.Adam(
                [module.merge_weight], lr=self.config.base_lr
            )
            optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
            print(f"{optimizer=}")
            # TODO: ablation study for the learning rate scheduler. It should yield similar results.
            # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            #     optimizer, self.config.max_steps, eta_min=0.1
            # )
            module, base_optimizer = self.fabric.setup(module, base_optimizer)
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)

            batch_opt_adv = torch.optim.Adam(
                params=self.pertubed_model.parameters(), lr=self.config.adv_lr
            )
            self.pertubed_model, batch_opt_adv = self.fabric.setup(
                self.pertubed_model, batch_opt_adv
            )
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        module.train()
        mask_model.train()
        # self.pertubed_model.train()
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete Safe AdaMerging Meta-Learn Mask (1/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # sample a shared mask and merge weights
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                # rescale mask
                for name, m in mask.items():
                    mask[name] = m / torch.mean(m)

                # for inner optimization, we do not optimize the mask, so we detach it
                module.merge_weights(
                    task_vector_mask={name: m.detach() for name, m in mask.items()}
                )

            # ------ inner optimization goes here ------
            module.merge_weight.data = deepcopy(self.init_layer_wise_weight)
            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                base_optimizer.step()
                base_optimizer.zero_grad()

            with self.profile("merge weights"):
                module.merge_weights(task_vector_mask=mask)

            # ------------------------------------------

            ### (2) noise optimization based on the merged model

            # detach merged state_dict
            merged_state_dict = module._merged_state_dict
            detached_merged_state_dict = {
                k: p.detach() for k, p in merged_state_dict.items()
            }
            module._merged_state_dict = detached_merged_state_dict

            total_loss = None
            for task_idx, task in enumerate(self.modelpool.model_names):
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                    perturbed_images = (
                        images + self.pertubed_model.perturbed_input[task_idx]
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    combined_logits = self.compute_logits(module, combined_images, task)
                    logits = combined_logits[: images.size(0)]
                    logits_adv = combined_logits[images.size(0) :]
                    ori_label = torch.argmax(logits, axis=1).long()
                    loss = torch.mean(
                        -F.cross_entropy(logits_adv, ori_label, reduction="mean")
                    )
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("batch_opt_adv optimizer step"):
                batch_opt_adv.step()
                batch_opt_adv.zero_grad()

            ### (3) mask optimization
            total_loss = None
            module._merged_state_dict = merged_state_dict

            for task_idx, task in enumerate(self.modelpool.model_names):
                with self.profile("data loading"), torch.no_grad():
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                    perturbed_images = (
                        images + self.pertubed_model.perturbed_input[task_idx]
                    )
                    perturbed_images = torch.clamp(perturbed_images, min=0, max=1)
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    combined_logits = self.compute_logits(module, combined_images, task)
                    logits = combined_logits[: images.size(0)]
                    logits_adv = combined_logits[images.size(0) :]

                    # ### regu1
                    # ori_label = torch.argmax(logits, axis=1).long()
                    # loss_nat = entropy_loss(logits)
                    # loss_regu = torch.mean(-F.cross_entropy(logits_adv, ori_label, reduction='mean'))

                    ### regu2
                    loss_regu = entropy_loss(logits_adv)
                    loss_nat = entropy_loss(logits)

                    loss = loss_nat + self.config.adv_weight * loss_regu
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

                if lr_scheduler is not None:
                    lr_scheduler.step()

            # metrics.update({"train/loss": loss.item()})
            metrics.update(
                {
                    "train/loss": loss.item(),
                    "loss_nat": loss_nat.item(),
                    "loss_regu": loss_regu.item(),
                }
            )
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)
            self.print_profile_summary()

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a symbolic link to the latest checkpoint
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

                self.print_profile_summary()

    def run_adamerging(self, module: LayerWiseMergedModel, mask):
        module.merge_weight.data = deepcopy(self.init_layer_wise_weight)
        base_optimizer = torch.optim.Adam(
            [module.merge_weight], lr=self.config.adamerging_lr
        )
        module, base_optimizer = self.fabric.setup(module, base_optimizer)
        module.train()
        for step_idx in (
            pbar := tqdm(
                range(
                    self.config.max_adamerging_steps if not self.is_debug_mode else 5
                ),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "Concrete AdaMerging AdaMerging (2/2)",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            step_idx = step_idx + self.config.max_steps
            with self.profile("merge weights"):
                module.merge_weights(task_vector_mask=mask)

            metrics = {}
            total_loss = None
            for task in self.modelpool.model_names:
                with self.profile("data loading"):
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    # NOTE: The labels are not allowed to be used during test-time adaptation
                    images = batch[0]
                with self.profile("forward pass"):
                    logits = self.compute_logits(module, images, task)
                    loss = entropy_loss(logits)
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("base optimizer step"):
                base_optimizer.step()
                base_optimizer.zero_grad()

            metrics.update({"train/loss": loss.item()})
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    if not os.path.exists(save_dir):
                        os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"merge_weight_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"merge_weight": module.merge_weight}
                    self.fabric.save(save_path, state)

                    # Create or update a symbolic link to the latest checkpoint
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(
                            save_dir, "merge_weight_latest_checkpoint.pt"
                        )
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

                self.print_profile_summary()
        return module

    # def run_adamerging(self, module: LayerWiseMergedModel, mask):
    #     module.merge_weight.data = deepcopy(self.init_layer_wise_weight)
    #     base_optimizer = torch.optim.Adam(
    #         [module.merge_weight], lr=self.config.adamerging_lr
    #     )
    #     module, base_optimizer = self.fabric.setup(module, base_optimizer)
    #     module.train()
    #     for step_idx in (
    #         pbar := tqdm(
    #             range(
    #                 self.config.max_adamerging_steps if not self.is_debug_mode else 5
    #             ),
    #             ("[DEBUG MODE] " if self.is_debug_mode else "")
    #             + "Concrete AdaMerging AdaMerging (2/2)",
    #             dynamic_ncols=True,
    #             disable=not self.fabric.is_global_zero,
    #         )
    #     ):
    #         step_idx = step_idx + self.config.max_steps
    #         with self.profile("merge weights"):
    #             module.merge_weights(task_vector_mask=mask)

    #         metrics = {}
    #         total_loss = None
    #         for task_idx, task in enumerate(self.modelpool.model_names):
    #             with self.profile("data loading"), torch.no_grad():
    #                 batch = next(self.get_shuffled_test_loader_iter(task))
    #                 # NOTE: The labels are not allowed to be used during test-time adaptation
    #                 images = batch[0]
    #                 perturbed_images = images + self.pertubed_model.perturbed_input[task_idx]
    #                 perturbed_images = torch.clamp(perturbed_images, min=0, max=1)
    #                 combined_images = torch.cat((images, perturbed_images), dim=0)

    #             with self.profile("forward pass"):
    #                 combined_logits = self.compute_logits(module, combined_images, task)
    #                 logits = combined_logits[:images.size(0)]
    #                 logits_adv = combined_logits[images.size(0):]

    #                 # ### regu1
    #                 # ori_label = torch.argmax(logits, axis=1).long()
    #                 # loss_nat = entropy_loss(logits)
    #                 # loss_regu = torch.mean(-F.cross_entropy(logits_adv, ori_label, reduction='mean'))

    #                 ### regu2
    #                 loss_regu = entropy_loss(logits_adv)
    #                 loss_nat = entropy_loss(logits)

    #                 loss = loss_nat + self.config.adv_weight*loss_regu
    #                 total_loss = loss if total_loss is None else total_loss + loss
    #         metrics.update({"train/loss": loss.item(),"loss_nat": loss_nat.item(),"loss_regu": loss_regu.item()})

    #         self.fabric.log_dict(metrics, step=step_idx)
    #         pbar.set_postfix(metrics)
    #         self.print_profile_summary()

    #         if (step_idx + 1) % self.config.save_interval == 0:
    #             with self.profiler.profile("save checkpoint"):
    #                 save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
    #                 if not os.path.exists(save_dir):
    #                     os.makedirs(save_dir, exist_ok=True)
    #                 save_path = os.path.join(save_dir, f"merge_weight_{step_idx}.pt")
    #                 print(f"saving checkpoint to {save_path}")
    #                 state = {"merge_weight": module.merge_weight}
    #                 self.fabric.save(save_path, state)

    #                 # Create or update a symbolic link to the latest checkpoint
    #                 if self.fabric.is_global_zero:
    #                     symlink_path = os.path.join(
    #                         save_dir, "merge_weight_latest_checkpoint.pt"
    #                     )
    #                     if os.path.exists(symlink_path):
    #                         os.remove(symlink_path)
    #                     os.link(os.path.abspath(save_path), symlink_path)

    #             self.print_profile_summary()
    #     return module

    def run(self, modelpool: HuggingFaceClipVisionPool):
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")

        with self.profile("setup models"):
            module, mask_model = self.setup_models()
            mask_model: MaskModel = self.fabric.to_device(mask_model)
            module: LayerWiseMergedModel = self.fabric.to_device(module)
            self.pertubed_model = self.fabric.to_device(self.pertubed_model)
            self.setup_zero_shot_classification_head()

        if config.mask_checkpoint is None:
            self.train_mask(module=module, mask_model=mask_model)
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", config.mask_checkpoint)
            self.fabric.load(config.mask_checkpoint, {"model": mask_model})

        # run adamerging
        with torch.no_grad():
            mask = mask_model.sample_mask(
                mask_type=config.eval_mask_type,
                temperature=config.temperature,
            )
            # rescale mask
            for name, m in mask.items():
                mask[name] = m / torch.mean(m)
        module = self.run_adamerging(module, mask=mask)

        with torch.no_grad():
            model = module.merge_and_unload(mask)
            return model
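
Both new algorithms read their hyperparameters from `self.config`. As a quick reference, the sketch below collects the option names that the code above actually accesses into a single config object. The key names come straight from the diff; the values, and the use of `OmegaConf.create`, are illustrative assumptions only, since the real defaults live in the new `clip_safe_concrete_*.yaml` files listed at the top of this diff.

```python
# Hypothetical sketch, not part of the diff: the keys below are the `self.config.*`
# options read by ConcreteSafeTaskWise/LayerWiseAdaMergingForCLIP above; every
# value shown is an illustrative placeholder, not a documented default.
from omegaconf import OmegaConf

method_config = OmegaConf.create(
    {
        "merge_dtype": None,           # parsed via parse_dtype(); None keeps the model dtype
        "initial_logits": 0.0,         # initial value filled into the PGE mask model
        "scaling_factor": 0.3,         # init value for the task-/layer-wise merge weights
        "clamp_weights": False,        # passed through to TaskWise/LayerWiseMergedModel
        "tie_weights": True,
        "strict": False,
        "optimizer": "adam",           # only "adam" is accepted by train_mask()
        "lr": 1e-3,                    # mask optimizer learning rate
        "base_lr": 1e-3,               # merge-weight optimizer learning rate
        "adv_lr": 1e-1,                # learning rate for the perturbation parameters
        "adamerging_lr": 1e-3,         # learning rate for the final AdaMerging stage
        "adv_weight": 1.0,             # weight of the adversarial entropy term (regu2)
        "temperature": 0.5,            # temperature for Concrete mask sampling
        "max_steps": 2000,             # mask meta-learning steps (stage 1/2)
        "max_adamerging_steps": 1000,  # AdaMerging steps (stage 2/2)
        "save_interval": 100,          # checkpoint every N steps
        "eval_mask_type": "discrete",  # mask type sampled for the final merge
        "mask_checkpoint": None,       # optional path to a pre-trained mask checkpoint
    }
)
```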