fusion-bench 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +5 -0
- fusion_bench/dataset/fer2013.py +0 -1
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/concrete_subspace/__init__.py +8 -0
- fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
- fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +46 -0
- fusion_bench/method/doge_ta/doge_ta.py +364 -0
- fusion_bench/method/doge_ta/layer_wise_adamerging.py +250 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/isotropic_merging/iso.py +2 -2
- fusion_bench/method/opcm/opcm.py +93 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/task_singular_vector/TSVM.py +3 -3
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +416 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/RECORD +32 -19
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/WHEEL +1 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
- fusion_bench_config/method/doge_ta/doge_ta.yaml +4 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -8
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +68 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.12.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,744 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Post-Defense Methods on the merged models (CLIP ViT)
|
|
3
|
+
|
|
4
|
+
Examples:
|
|
5
|
+
|
|
6
|
+
```bash
|
|
7
|
+
fusion_bench \
|
|
8
|
+
fabric.loggers.name= \
|
|
9
|
+
method=clip_post_defense_AWM \
|
|
10
|
+
modelpool= \
|
|
11
|
+
taskpool=
|
|
12
|
+
```
|
|
13
|
+
|
|
14
|
+
```bash
|
|
15
|
+
fusion_bench \
|
|
16
|
+
fabric.loggers.name= \
|
|
17
|
+
method=clip_post_defense_SAU \
|
|
18
|
+
modelpool= \
|
|
19
|
+
taskpool=
|
|
20
|
+
```
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import logging
|
|
24
|
+
import os
|
|
25
|
+
from typing import cast
|
|
26
|
+
|
|
27
|
+
import torch
|
|
28
|
+
import torch.nn as nn
|
|
29
|
+
import torch.nn.functional as F
|
|
30
|
+
from tqdm.autonotebook import tqdm
|
|
31
|
+
|
|
32
|
+
from fusion_bench.compat.method import ModelFusionAlgorithm
|
|
33
|
+
from fusion_bench.compat.modelpool import to_modelpool
|
|
34
|
+
from fusion_bench.compat.modelpool.huggingface_clip_vision import (
|
|
35
|
+
HuggingFaceClipVisionPool,
|
|
36
|
+
)
|
|
37
|
+
from fusion_bench.method.adamerging.entropy_loss import entropy_loss
|
|
38
|
+
from fusion_bench.mixins import CLIPClassificationMixin
|
|
39
|
+
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
40
|
+
from fusion_bench.models.masks import MaskModel, mask_sparsity
|
|
41
|
+
from fusion_bench.models.wrappers.layer_wise_fusion import (
|
|
42
|
+
LayerWiseMergedModel,
|
|
43
|
+
get_layer_wise_weights,
|
|
44
|
+
)
|
|
45
|
+
from fusion_bench.models.wrappers.task_wise_fusion import (
|
|
46
|
+
TaskWiseMergedModel,
|
|
47
|
+
get_task_wise_weights,
|
|
48
|
+
)
|
|
49
|
+
from fusion_bench.tasks.clip_classification import get_classnames_and_templates
|
|
50
|
+
from fusion_bench.utils.dtype import parse_dtype
|
|
51
|
+
from fusion_bench.utils.parameters import print_parameters
|
|
52
|
+
|
|
53
|
+
log = logging.getLogger(__name__)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class PostDefenseAWMAlgorithmForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    """Post-defense via Adversarial Weight Masking (AWM) for a merged CLIP vision model.

    Learns a per-parameter mask over the task vector (``merged - pretrained``)
    while jointly optimizing one additive adversarial input perturbation per
    test-time-adaptation (TTA) dataset. The mask is trained so that natural
    predictions stay confident (low entropy) and the learned perturbation loses
    its effect (low entropy on perturbed inputs as well).
    """

    @torch.no_grad()
    def setup_models(self):
        """Load the pretrained/merged models and build the mask and perturbation modules.

        Returns:
            tuple: ``(merge_model, pretrained_model, mask_model)``.
        """
        config = self.config
        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
        modelpool = self.modelpool

        # Load the pretrained backbone and the merged model to be defended.
        pretrained_model = modelpool.load_model("_pretrained_")
        merge_model = modelpool.load_model("merge")

        # Construct the PGE mask model over the merged model's trained parameters.
        mask_model = MaskModel(
            merge_model,
            ignore_untrained_params=True,
            parameter_type="logits",
        )
        if self.merge_dtype is not None:
            mask_model.to(self.merge_dtype)
        mask_model.fill_(self.config.initial_logits)
        # TODO: ablation study for the initialization of the mask model.
        print("Summary of mask model:")
        print_parameters(mask_model)

        # One learnable additive input perturbation per TTA dataset.
        # NOTE(review): assumes 3x224x224 inputs — confirm against the CLIP image processor.
        self.pertubed_model = nn.Module()
        self.pertubed_model.perturbed_input = nn.Parameter(
            torch.zeros([len(self.modelpool.config.tta_datasets), 3, 224, 224]),
            requires_grad=True,
        )

        return merge_model, pretrained_model, mask_model

    @staticmethod
    def _rescale_mask(mask):
        """Normalize each mask tensor to unit mean, in place, and return it."""
        for name, m in mask.items():
            mask[name] = m / torch.mean(m)
        return mask

    @staticmethod
    def _mask_task_vector(merged_state_dict, pretrained_model_dict, mask):
        """Apply the mask to the task vector, in place, and return the state dict.

        Implements ``pretrained + mask * (merged - pretrained)`` (option (2):
        masking the task vector, similar to a concrete mask; initial logits can
        be 0). Option (1) — directly pruning the merged weights with
        ``merged * mask`` — would require initial logits larger than 3.
        """
        for name in merged_state_dict:
            merged_state_dict[name] = (
                merged_state_dict[name] - pretrained_model_dict[name]
            ) * mask[name] + pretrained_model_dict[name]
        return merged_state_dict

    def train_mask(self, merge_model, pretrained_model, mask_model: MaskModel):
        """Alternately optimize the input perturbations and the parameter mask.

        Args:
            merge_model: the merged model whose task vector is being masked.
            pretrained_model: the frozen pretrained backbone.
            mask_model: the mask model producing per-parameter mask logits.
        """
        config = self.config

        # Configure the mask optimizer.
        lr_scheduler = None
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)

        # Optimizer for the adversarial input perturbations.
        batch_opt_adv = torch.optim.Adam(
            params=self.pertubed_model.parameters(), lr=self.config.adv_lr
        )
        self.pertubed_model, batch_opt_adv = self.fabric.setup(
            self.pertubed_model, batch_opt_adv
        )

        # Only the mask and the perturbation are trained.
        merge_model.requires_grad_(False)
        pretrained_model.requires_grad_(False)

        mask_model.train()
        optimizer.zero_grad()

        self.pertubed_model.train()
        batch_opt_adv.zero_grad()

        pretrained_model_dict = pretrained_model.state_dict(keep_vars=True)
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "clip_post_defense_AWM",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # Sample a shared mask and merge the weights with it.
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                mask = self._rescale_mask(mask)
                merged_state_dict = merge_model.state_dict(keep_vars=True)
                self._mask_task_vector(merged_state_dict, pretrained_model_dict, mask)

            # ------ noise optimization based on the merged model ------
            # Detach the merged state dict so perturbation gradients do not
            # flow into the mask.
            detached_merged_state_dict = {
                k: p.detach() for k, p in merged_state_dict.items()
            }
            merge_model_forward = lambda *args, **kwargs: torch.func.functional_call(
                merge_model, detached_merged_state_dict, args=args, kwargs=kwargs
            )

            total_loss = None
            for task_idx, task in enumerate(
                [c["name"] for c in self.modelpool.config.tta_datasets]
            ):
                with self.profile("data loading"):
                    # NOTE: labels are not allowed to be used during test-time adaptation.
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    images = batch[0]
                    perturbed_images = (
                        images + self.pertubed_model.perturbed_input[task_idx]
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    combined_logits = self.compute_logits(
                        merge_model_forward, combined_images, task
                    )
                    num_image = images.size(0)
                    logits, logits_adv = (
                        combined_logits[:num_image],
                        combined_logits[num_image:],
                    )
                    # Pseudo-labels from the clean predictions stand in for the
                    # (unavailable) ground truth.
                    ori_label = torch.argmax(logits, dim=1).long()

                # Maximize the cross entropy of the perturbed predictions
                # w.r.t. the pseudo-labels. ``reduction="mean"`` already yields
                # a scalar, so no extra ``torch.mean`` is needed.
                loss = -F.cross_entropy(logits_adv, ori_label, reduction="mean")
                total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("batch_opt_adv optimizer step"):
                batch_opt_adv.step()
                batch_opt_adv.zero_grad()

            # ------ inner optimization goes here ------
            # NOTE: the algorithmic parameters of task arithmetic are assumed to
            # be chosen on a validation set, so the inner optimization step is
            # skipped here.
            # -----------------------------------------
            # ------ mask optimization ------
            merge_model_forward = lambda *args, **kwargs: torch.func.functional_call(
                merge_model, merged_state_dict, args=args, kwargs=kwargs
            )
            total_loss = None

            for task_idx, task in enumerate(
                [c["name"] for c in self.modelpool.config.tta_datasets]
            ):
                with self.profile("data loading"), torch.no_grad():
                    # NOTE: labels are not allowed to be used during test-time adaptation.
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    images = batch[0]
                    perturbed_images = torch.clamp(
                        images + self.pertubed_model.perturbed_input[task_idx],
                        min=0,
                        max=1,
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    combined_logits = self.compute_logits(
                        merge_model_forward, combined_images, task
                    )
                    num_image = images.size(0)
                    logits, logits_adv = (
                        combined_logits[:num_image],
                        combined_logits[num_image:],
                    )

                # Entropy minimization on clean inputs; the regularizer also
                # minimizes entropy on the perturbed inputs ("regu2" variant).
                loss_nat = entropy_loss(logits)
                loss_regu = entropy_loss(logits_adv)
                loss = loss_nat + self.config.adv_weight * loss_regu
                total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

            if lr_scheduler is not None:
                lr_scheduler.step()
            metrics.update(
                {
                    "train/loss": loss.item(),
                    "loss_nat": loss_nat.item(),
                    "loss_regu": loss_regu.item(),
                }
            )
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)
            self.print_profile_summary()

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a hard link to the latest checkpoint.
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

        self.print_profile_summary()

    def run(self, modelpool: HuggingFaceClipVisionPool):
        """Train (or load) the mask, apply it to the task vector, and return the model.

        Args:
            modelpool: pool providing the ``_pretrained_`` and ``merge`` models.

        Returns:
            The merged model with the learned mask folded into its weights.
        """
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")

        with self.profile("setup models"):
            merge_model, pretrained_model, mask_model = self.setup_models()
            mask_model: MaskModel = self.fabric.to_device(mask_model)
            merge_model = self.fabric.to_device(merge_model)
            pretrained_model = self.fabric.to_device(pretrained_model)
            self.pertubed_model = self.fabric.to_device(self.pertubed_model)
            self.setup_zero_shot_classification_head(
                task_names=[c["name"] for c in self.modelpool.config.tta_datasets]
            )

        if config.mask_checkpoint is None:
            self.train_mask(
                merge_model=merge_model,
                pretrained_model=pretrained_model,
                mask_model=mask_model,
            )
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", config.mask_checkpoint)
            self.fabric.load(config.mask_checkpoint, {"model": mask_model})

        with torch.no_grad():
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            mask = mask_model.sample_mask(
                mask_type=config.eval_mask_type,
                temperature=config.temperature,
            )
            mask = self._rescale_mask(mask)
            pretrained_model_dict = pretrained_model.state_dict(keep_vars=True)
            merged_state_dict = merge_model.state_dict(keep_vars=True)
            self._mask_task_vector(merged_state_dict, pretrained_model_dict, mask)
            merge_model.load_state_dict(merged_state_dict)
        return merge_model
|
|
349
|
+
|
|
350
|
+
class PostDefenseSAUAlgorithmForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    """Post-defense via Shared Adversarial Unlearning (SAU) for a merged CLIP vision model.

    Like the AWM variant, this learns a per-parameter mask over the task vector
    (``merged - pretrained``) while optimizing one adversarial input perturbation
    per TTA dataset. SAU additionally keeps a frozen reference copy of the merged
    model and penalizes "shared" adversarial examples — perturbations that fool
    both the masked model and the reference model in the same way.
    """

    @torch.no_grad()
    def setup_models(self):
        """Load models, including a frozen reference copy, and build mask/perturbation.

        Returns:
            tuple: ``(merge_model, merge_model_ref, pretrained_model, mask_model)``.
        """
        config = self.config
        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
        modelpool = self.modelpool

        # Load the pretrained backbone, the merged model to defend, and a
        # second (frozen) copy of the merged model used as a reference.
        pretrained_model = modelpool.load_model("_pretrained_")
        merge_model = modelpool.load_model("merge")
        merge_model_ref = modelpool.load_model("merge")

        # Construct the PGE mask model over the merged model's trained parameters.
        mask_model = MaskModel(
            merge_model,
            ignore_untrained_params=True,
            parameter_type="logits",
        )
        if self.merge_dtype is not None:
            mask_model.to(self.merge_dtype)
        mask_model.fill_(self.config.initial_logits)
        # TODO: ablation study for the initialization of the mask model.
        print("Summary of mask model:")
        print_parameters(mask_model)

        # One learnable additive input perturbation per TTA dataset.
        # NOTE(review): assumes 3x224x224 inputs — confirm against the CLIP image processor.
        self.pertubed_model = nn.Module()
        self.pertubed_model.perturbed_input = nn.Parameter(
            torch.zeros([len(self.modelpool.config.tta_datasets), 3, 224, 224]),
            requires_grad=True,
        )

        return merge_model, merge_model_ref, pretrained_model, mask_model

    @staticmethod
    def _rescale_mask(mask):
        """Normalize each mask tensor to unit mean, in place, and return it."""
        for name, m in mask.items():
            mask[name] = m / torch.mean(m)
        return mask

    @staticmethod
    def _mask_task_vector(merged_state_dict, pretrained_model_dict, mask):
        """Apply the mask to the task vector, in place, and return the state dict.

        Implements ``pretrained + mask * (merged - pretrained)`` (option (2):
        masking the task vector; initial logits can be 0). Option (1) — directly
        pruning the merged weights — would require initial logits larger than 3.
        """
        for name in merged_state_dict:
            merged_state_dict[name] = (
                merged_state_dict[name] - pretrained_model_dict[name]
            ) * mask[name] + pretrained_model_dict[name]
        return merged_state_dict

    def train_mask(
        self, merge_model, merge_model_ref, pretrained_model, mask_model: MaskModel
    ):
        """Alternately optimize the input perturbations and the parameter mask.

        Args:
            merge_model: the merged model whose task vector is being masked.
            merge_model_ref: frozen copy of the merged model used as reference.
            pretrained_model: the frozen pretrained backbone.
            mask_model: the mask model producing per-parameter mask logits.
        """
        config = self.config

        # Configure the mask optimizer.
        lr_scheduler = None
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)

        # Optimizer for the adversarial input perturbations.
        batch_opt_adv = torch.optim.Adam(
            params=self.pertubed_model.parameters(), lr=self.config.adv_lr
        )
        self.pertubed_model, batch_opt_adv = self.fabric.setup(
            self.pertubed_model, batch_opt_adv
        )

        # Only the mask and the perturbation are trained.
        merge_model.requires_grad_(False)
        merge_model_ref.requires_grad_(False)
        pretrained_model.requires_grad_(False)

        mask_model.train()
        optimizer.zero_grad()

        self.pertubed_model.train()
        batch_opt_adv.zero_grad()

        pretrained_model_dict = pretrained_model.state_dict(keep_vars=True)
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "clip_post_defense_SAU",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # Sample a shared mask and merge the weights with it.
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                mask = self._rescale_mask(mask)
                merged_state_dict = merge_model.state_dict(keep_vars=True)
                self._mask_task_vector(merged_state_dict, pretrained_model_dict, mask)

            # ------ noise optimization based on the merged model ------
            # Detach the merged state dict so perturbation gradients do not
            # flow into the mask.
            detached_merged_state_dict = {
                k: p.detach() for k, p in merged_state_dict.items()
            }
            merge_model_forward = lambda *args, **kwargs: torch.func.functional_call(
                merge_model, detached_merged_state_dict, args=args, kwargs=kwargs
            )

            total_loss = None
            for task_idx, task in enumerate(
                [c["name"] for c in self.modelpool.config.tta_datasets]
            ):
                with self.profile("data loading"):
                    # NOTE: labels are not allowed to be used during test-time adaptation.
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    images = batch[0]
                    perturbed_images = (
                        images + self.pertubed_model.perturbed_input[task_idx]
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    num_image = images.size(0)

                    combined_logits = self.compute_logits(
                        merge_model_forward, combined_images, task
                    )
                    logits, logits_adv = (
                        combined_logits[:num_image],
                        combined_logits[num_image:],
                    )
                    ori_label = torch.argmax(logits, dim=1).long()
                    pert_label = torch.argmax(logits_adv, dim=1).long()

                    combined_logits_ref = self.compute_logits(
                        merge_model_ref, combined_images, task
                    )
                    logits_ref, logits_adv_ref = (
                        combined_logits_ref[:num_image],
                        combined_logits_ref[num_image:],
                    )
                    ori_label_ref = torch.argmax(logits_ref, dim=1).long()
                    pert_label_ref = torch.argmax(logits_adv_ref, dim=1).long()

                    # An attack is "shared" when both models flip their clean
                    # prediction to the same adversarial label.
                    success_attack = pert_label != ori_label
                    success_attack_ref = pert_label_ref != ori_label_ref
                    common_attack = torch.logical_and(
                        success_attack, success_attack_ref
                    )
                    shared_attack = torch.logical_and(
                        common_attack, pert_label == pert_label_ref
                    )

                    # Shared loss: Jensen-Shannon divergence between the two
                    # models' adversarial distributions, summed over samples
                    # that are NOT shared attacks.
                    # (https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence)
                    p_model = F.softmax(logits_adv, dim=1).clamp(min=1e-8)
                    p_ref = F.softmax(logits_adv_ref, dim=1).clamp(min=1e-8)
                    mix_p = 0.5 * (p_model + p_ref)
                    loss_js = 0.5 * (
                        p_model * p_model.log() + p_ref * p_ref.log()
                    ) - 0.5 * (p_model * mix_p.log() + p_ref * mix_p.log())
                    loss_cross = (
                        loss_js[torch.logical_not(shared_attack)].sum(dim=1).sum()
                        / images.shape[0]
                    )

                    # Maximization perturbation loss, using test data without
                    # true labels (clean predictions serve as pseudo-labels).
                    # ``reduction="mean"`` already yields a scalar, so no extra
                    # ``torch.mean`` is needed.
                    loss_adv = -F.cross_entropy(
                        logits_adv, ori_label, reduction="mean"
                    )

                loss = self.config.beta1 * loss_adv + self.config.beta2 * loss_cross

                total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("batch_opt_adv optimizer step"):
                batch_opt_adv.step()
                batch_opt_adv.zero_grad()

            # ------ inner optimization goes here ------
            # NOTE: the algorithmic parameters of task arithmetic are assumed to
            # be chosen on a validation set, so the inner optimization step is
            # skipped here.
            # -----------------------------------------
            # ------ mask optimization ------
            merge_model_forward = lambda *args, **kwargs: torch.func.functional_call(
                merge_model, merged_state_dict, args=args, kwargs=kwargs
            )
            total_loss = None

            for task_idx, task in enumerate(
                [c["name"] for c in self.modelpool.config.tta_datasets]
            ):
                classnames, templates = get_classnames_and_templates(
                    self.modelpool.get_train_dataset_config(task)["dataset"].name
                )
                num_classes = len(classnames)

                with self.profile("data loading"), torch.no_grad():
                    # NOTE: labels are not allowed to be used during test-time adaptation.
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    images = batch[0]

                    perturbed_images = torch.clamp(
                        images + self.pertubed_model.perturbed_input[task_idx],
                        min=0,
                        max=1,
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    num_image = images.size(0)

                    # loss_nat: entropy minimization on the clean inputs.
                    combined_logits = self.compute_logits(
                        merge_model_forward, combined_images, task
                    )
                    logits, logits_adv = (
                        combined_logits[:num_image],
                        combined_logits[num_image:],
                    )
                    ori_label = torch.argmax(logits, dim=1).long()
                    loss_nat = entropy_loss(logits)

                    # loss_regu from the noise ("regu2" variant): entropy
                    # minimization on the perturbed inputs.
                    loss_regu = entropy_loss(logits_adv)

                    # Shared loss against the frozen reference model.
                    combined_logits_ref = self.compute_logits(
                        merge_model_ref, combined_images, task
                    )
                    logits_ref, logits_adv_ref = (
                        combined_logits_ref[:num_image],
                        combined_logits_ref[num_image:],
                    )
                    ori_label_ref = torch.argmax(logits_ref, dim=1).long()
                    pert_label_ref = torch.argmax(logits_adv_ref, dim=1).long()

                    # Since only unlabeled test data is available, the clean
                    # prediction ``ori_label`` replaces the true label when
                    # deciding whether the reference model was fooled.
                    success_attack_ref = pert_label_ref != ori_label
                    success_attack_ref = success_attack_ref & (
                        pert_label_ref != ori_label_ref
                    )

                    potential_poison = success_attack_ref
                    if potential_poison.sum() == 0:
                        loss_shared = torch.tensor(0.0).to(merge_model.device)
                    else:
                        one_hot = F.one_hot(pert_label_ref, num_classes=num_classes)

                        neg_one_hot = 1 - one_hot
                        neg_p = (F.softmax(logits_adv, dim=1) * neg_one_hot).sum(dim=1)[
                            potential_poison
                        ]
                        pos_p = (F.softmax(logits_adv, dim=1) * one_hot).sum(dim=1)[
                            potential_poison
                        ]

                        # Clamp too-small values to avoid NaN and discard samples
                        # with p < 1% from being shared.
                        # NOTE: the two terms below are identical in math but
                        # differ numerically; combining them reduces the
                        # numerical issue.
                        loss_shared = (
                            -torch.sum(torch.log(1e-6 + neg_p.clamp(max=0.999)))
                            - torch.sum(torch.log(1 + 1e-6 - pos_p.clamp(min=0.001)))
                        ) / 2
                        loss_shared = loss_shared / images.shape[0]

                loss = (
                    loss_nat
                    + self.config.adv_weight * loss_regu
                    + self.config.shared_weight * loss_shared
                )
                total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

            if lr_scheduler is not None:
                lr_scheduler.step()
            metrics.update(
                {
                    "train/loss": loss.item(),
                    "loss_nat": loss_nat.item(),
                    "loss_regu": loss_regu.item(),
                    "loss_shared": loss_shared.item(),
                }
            )
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)
            self.print_profile_summary()

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a hard link to the latest checkpoint.
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

        self.print_profile_summary()

    def run(self, modelpool: HuggingFaceClipVisionPool):
        """Train (or load) the mask, apply it to the task vector, and return the model.

        Args:
            modelpool: pool providing the ``_pretrained_`` and ``merge`` models.

        Returns:
            The merged model with the learned mask folded into its weights.
        """
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")

        with self.profile("setup models"):
            merge_model, merge_model_ref, pretrained_model, mask_model = (
                self.setup_models()
            )
            mask_model: MaskModel = self.fabric.to_device(mask_model)
            merge_model = self.fabric.to_device(merge_model)
            merge_model_ref = self.fabric.to_device(merge_model_ref)
            pretrained_model = self.fabric.to_device(pretrained_model)
            self.pertubed_model = self.fabric.to_device(self.pertubed_model)
            self.setup_zero_shot_classification_head(
                task_names=[c["name"] for c in self.modelpool.config.tta_datasets]
            )

        if config.mask_checkpoint is None:
            self.train_mask(
                merge_model=merge_model,
                merge_model_ref=merge_model_ref,
                pretrained_model=pretrained_model,
                mask_model=mask_model,
            )
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", config.mask_checkpoint)
            self.fabric.load(config.mask_checkpoint, {"model": mask_model})

        with torch.no_grad():
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            mask = mask_model.sample_mask(
                mask_type=config.eval_mask_type,
                temperature=config.temperature,
            )
            mask = self._rescale_mask(mask)
            pretrained_model_dict = pretrained_model.state_dict(keep_vars=True)
            merged_state_dict = merge_model.state_dict(keep_vars=True)
            self._mask_task_vector(merged_state_dict, pretrained_model_dict, mask)
            merge_model.load_state_dict(merged_state_dict)
        return merge_model
|