fusion-bench 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. fusion_bench/compat/method/__init__.py +5 -0
  2. fusion_bench/method/DOGE_TA/DOGE_TA.py +364 -0
  3. fusion_bench/method/DOGE_TA/__init__.py +2 -0
  4. fusion_bench/method/DOGE_TA/clip_layer_wise_adamerging.py +46 -0
  5. fusion_bench/method/DOGE_TA/layer_wise_adamerging.py +250 -0
  6. fusion_bench/method/__init__.py +10 -0
  7. fusion_bench/method/concrete_subspace/__init__.py +8 -0
  8. fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
  9. fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
  10. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  11. fusion_bench/method/isotropic_merging/iso.py +2 -2
  12. fusion_bench/method/task_singular_vector/TSVM.py +3 -3
  13. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +531 -0
  14. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/METADATA +1 -1
  15. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/RECORD +24 -12
  16. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/WHEEL +1 -1
  17. fusion_bench_config/method/DOGE_TA/DOGE_TA.yaml +4 -0
  18. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
  19. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
  20. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
  21. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
  22. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/LICENSE +0 -0
  23. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/entry_points.txt +0 -0
  24. {fusion_bench-0.2.10.dist-info → fusion_bench-0.2.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,744 @@
1
+ """
2
+ Post-Defense Methods on the merged models (CLIP ViT)
3
+
4
+ Examples:
5
+
6
+ ```bash
7
+ fusion_bench \
8
+ fabric.loggers.name= \
9
+ method=clip_post_defense_AWM \
10
+ modelpool= \
11
+ taskpool=
12
+ ```
13
+
14
+ ```bash
15
+ fusion_bench \
16
+ fabric.loggers.name= \
17
+ method=clip_post_defense_SAU \
18
+ modelpool= \
19
+ taskpool=
20
+ ```
21
+ """
22
+
23
+ import logging
24
+ import os
25
+ from typing import cast
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from tqdm.autonotebook import tqdm
31
+
32
+ from fusion_bench.compat.method import ModelFusionAlgorithm
33
+ from fusion_bench.compat.modelpool import to_modelpool
34
+ from fusion_bench.compat.modelpool.huggingface_clip_vision import (
35
+ HuggingFaceClipVisionPool,
36
+ )
37
+ from fusion_bench.method.adamerging.entropy_loss import entropy_loss
38
+ from fusion_bench.mixins import CLIPClassificationMixin
39
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
40
+ from fusion_bench.models.masks import MaskModel, mask_sparsity
41
+ from fusion_bench.models.wrappers.layer_wise_fusion import (
42
+ LayerWiseMergedModel,
43
+ get_layer_wise_weights,
44
+ )
45
+ from fusion_bench.models.wrappers.task_wise_fusion import (
46
+ TaskWiseMergedModel,
47
+ get_task_wise_weights,
48
+ )
49
+ from fusion_bench.tasks.clip_classification import get_classnames_and_templates
50
+ from fusion_bench.utils.dtype import parse_dtype
51
+ from fusion_bench.utils.parameters import print_parameters
52
+
53
+ log = logging.getLogger(__name__)
54
+
55
+
56
class PostDefenseAWMAlgorithmForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    """Adversarial Weight Masking (AWM) post-defense for a merged CLIP model.

    Each training step alternates two phases:

    1. optimize one learnable input perturbation per TTA dataset so that it
       flips the merged model's own predictions (no ground-truth labels are
       used -- the model's argmax on clean images serves as pseudo-labels);
    2. optimize a continuous mask over the task vector
       (``merged - pretrained``) that minimizes prediction entropy on both
       clean and perturbed images.
    """

    @torch.no_grad()
    def setup_models(self):
        """Load the frozen models and build the trainable modules.

        Returns:
            tuple: ``(merge_model, pretrained_model, mask_model)``. The
            per-dataset input perturbations are stored on
            ``self.pertubed_model``.
        """
        config = self.config
        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
        modelpool = self.modelpool

        # Load the pretrained (base) model and the merged model to defend.
        pretrained_model = modelpool.load_model("_pretrained_")
        merge_model = modelpool.load_model("merge")

        # Construct the PGE mask model over the merged model's parameters.
        mask_model = MaskModel(
            merge_model,
            ignore_untrained_params=True,
            parameter_type="logits",
        )
        if self.merge_dtype is not None:
            mask_model.to(self.merge_dtype)
        mask_model.fill_(self.config.initial_logits)
        # TODO: ablation study for the initialization of mask model
        print("Summary of mask model:")
        print_parameters(mask_model)

        # One learnable 3x224x224 additive perturbation per TTA dataset.
        # NOTE(review): the attribute name keeps the original spelling
        # ("pertubed") for backward compatibility with existing configs
        # and checkpoints.
        self.pertubed_model = nn.Module()
        self.pertubed_model.perturbed_input = nn.Parameter(
            torch.zeros([len(self.modelpool.config.tta_datasets), 3, 224, 224]),
            requires_grad=True,
        )

        return merge_model, pretrained_model, mask_model

    def train_mask(self, merge_model, pretrained_model, mask_model: MaskModel):
        """Jointly train the adversarial perturbations and the mask model.

        Args:
            merge_model: merged model to defend (kept frozen).
            pretrained_model: pretrained base model (kept frozen).
            mask_model: trainable mask over the task vector.

        Raises:
            ValueError: if ``config.optimizer`` is not a supported optimizer.
        """
        config = self.config

        # configure optimizers
        lr_scheduler = None
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
        else:
            # Previously an unsupported optimizer left `optimizer` unbound and
            # crashed later with a NameError; fail fast with a clear message.
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        batch_opt_adv = torch.optim.Adam(
            params=self.pertubed_model.parameters(), lr=self.config.adv_lr
        )
        self.pertubed_model, batch_opt_adv = self.fabric.setup(
            self.pertubed_model, batch_opt_adv
        )

        # Only the mask and the perturbations are trainable.
        merge_model.requires_grad_(False)
        pretrained_model.requires_grad_(False)

        mask_model.train()
        optimizer.zero_grad()

        self.pertubed_model.train()
        batch_opt_adv.zero_grad()

        pretrained_model_dict = pretrained_model.state_dict(keep_vars=True)
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "clip_post_defense_AWM",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # sample a shared mask and merge weights
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                # rescale mask to unit mean so the expected weight is preserved
                for name, m in mask.items():
                    mask[name] = m / torch.mean(m)

                merged_state_dict = merge_model.state_dict(keep_vars=True)
                for name in merged_state_dict:
                    # Mask the task vector (merged - pretrained), similar to a
                    # concrete mask; initial logits can then be set to 0.
                    # (Directly pruning the merged model instead would need
                    # initial logits larger than 3.)
                    merged_state_dict[name] = (
                        merged_state_dict[name] - pretrained_model_dict[name]
                    ) * mask[name] + pretrained_model_dict[name]

            # ------ phase 1: perturbation optimization (mask detached) ------
            detached_merged_state_dict = {
                k: p.detach() for k, p in merged_state_dict.items()
            }
            merge_model_forward = lambda *args, **kwargs: torch.func.functional_call(
                merge_model, detached_merged_state_dict, args=args, kwargs=kwargs
            )

            total_loss = None
            for task_idx, task in enumerate(
                [c["name"] for c in self.modelpool.config.tta_datasets]
            ):
                with self.profile("data loading"):
                    # NOTE: labels must not be used during test-time adaptation
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    images = batch[0]
                    perturbed_images = (
                        images + self.pertubed_model.perturbed_input[task_idx]
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    combined_logits = self.compute_logits(
                        merge_model_forward, combined_images, task
                    )
                    num_image = images.size(0)
                    logits, logits_adv = (
                        combined_logits[:num_image],
                        combined_logits[num_image:],
                    )
                    # Pseudo-labels from the clean predictions.
                    ori_label = torch.argmax(logits, dim=1).long()

                    # Maximize the cross-entropy of the perturbed predictions
                    # w.r.t. the pseudo-labels. (The original wrapped this
                    # scalar in a redundant torch.mean; removed here.)
                    loss = -F.cross_entropy(logits_adv, ori_label, reduction="mean")
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("batch_opt_adv optimizer step"):
                batch_opt_adv.step()
                batch_opt_adv.zero_grad()

            # ------ phase 2: mask optimization (perturbation frozen) ------
            # NOTE: algorithmic parameters of task arithmetic are assumed to
            # be chosen on a validation set, so no inner optimization is done.
            merge_model_forward = lambda *args, **kwargs: torch.func.functional_call(
                merge_model, merged_state_dict, args=args, kwargs=kwargs
            )
            total_loss = None

            for task_idx, task in enumerate(
                [c["name"] for c in self.modelpool.config.tta_datasets]
            ):
                with self.profile("data loading"), torch.no_grad():
                    # NOTE: labels must not be used during test-time adaptation
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    images = batch[0]

                    # Clamp perturbed inputs to the valid image range.
                    perturbed_images = torch.clamp(
                        images + self.pertubed_model.perturbed_input[task_idx],
                        min=0,
                        max=1,
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    combined_logits = self.compute_logits(
                        merge_model_forward, combined_images, task
                    )
                    num_image = images.size(0)
                    logits, logits_adv = (
                        combined_logits[:num_image],
                        combined_logits[num_image:],
                    )

                    # Entropy on clean images (adaptation objective) ...
                    loss_nat = entropy_loss(logits)
                    # ... plus entropy on perturbed images (robustness
                    # regularizer, "regu2" variant).
                    loss_regu = entropy_loss(logits_adv)

                    loss = loss_nat + self.config.adv_weight * loss_regu
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

            if lr_scheduler is not None:
                lr_scheduler.step()
            # loss/loss_nat/loss_regu refer to the last task of this step.
            metrics.update(
                {
                    "train/loss": loss.item(),
                    "loss_nat": loss_nat.item(),
                    "loss_regu": loss_regu.item(),
                }
            )
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)
            self.print_profile_summary()

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a hard link to the latest checkpoint.
                    # (os.link creates a hard link, not a symlink; the
                    # original comment said "symbolic link".)
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

        self.print_profile_summary()

    def run(self, modelpool: HuggingFaceClipVisionPool):
        """Run the AWM post-defense and return the masked merged model.

        Args:
            modelpool: pool providing the ``_pretrained_`` and ``merge``
                models plus ``tta_datasets``.

        Returns:
            The merged model with the learned mask applied to its task vector.
        """
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")

        with self.profile("setup models"):
            merge_model, pretrained_model, mask_model = self.setup_models()
            mask_model: MaskModel = self.fabric.to_device(mask_model)
            merge_model = self.fabric.to_device(merge_model)
            pretrained_model = self.fabric.to_device(pretrained_model)
            self.pertubed_model = self.fabric.to_device(self.pertubed_model)
            self.setup_zero_shot_classification_head(
                task_names=[c["name"] for c in self.modelpool.config.tta_datasets]
            )

        if config.mask_checkpoint is None:
            self.train_mask(
                merge_model=merge_model,
                pretrained_model=pretrained_model,
                mask_model=mask_model,
            )
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", config.mask_checkpoint)
            self.fabric.load(config.mask_checkpoint, {"model": mask_model})

        with torch.no_grad():
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            mask = mask_model.sample_mask(
                mask_type=config.eval_mask_type,
                temperature=config.temperature,
            )
            # rescale mask to unit mean, matching training
            for name, m in mask.items():
                mask[name] = m / torch.mean(m)
            pretrained_model_dict = pretrained_model.state_dict(keep_vars=True)
            merged_state_dict = merge_model.state_dict(keep_vars=True)
            for name in merged_state_dict:
                # Apply the mask to the task vector, exactly as in training.
                merged_state_dict[name] = (
                    merged_state_dict[name] - pretrained_model_dict[name]
                ) * mask[name] + pretrained_model_dict[name]
            merge_model.load_state_dict(merged_state_dict)
        return merge_model
348
+
349
+
350
class PostDefenseSAUAlgorithmForCLIP(
    CLIPClassificationMixin,
    SimpleProfilerMixin,
    ModelFusionAlgorithm,
):
    """Shared Adversarial Unlearning (SAU) post-defense for a merged CLIP model.

    Like the AWM defense, each step alternates perturbation optimization and
    mask optimization, but the perturbation is additionally steered towards
    attacks that are *shared* between the masked model and a frozen reference
    copy of the merged model (a Jensen-Shannon agreement term), and the mask
    loss penalizes following the reference model on potentially poisoned
    (successfully attacked) samples. No ground-truth labels are used.
    """

    @torch.no_grad()
    def setup_models(self):
        """Load the frozen models and build the trainable modules.

        Returns:
            tuple: ``(merge_model, merge_model_ref, pretrained_model,
            mask_model)``; ``merge_model_ref`` is a second, frozen copy of the
            merged model used as the reference. The per-dataset input
            perturbations are stored on ``self.pertubed_model``.
        """
        config = self.config
        self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
        modelpool = self.modelpool

        # Load the pretrained (base) model, the merged model to defend, and a
        # second frozen copy of the merged model as the reference.
        pretrained_model = modelpool.load_model("_pretrained_")
        merge_model = modelpool.load_model("merge")
        merge_model_ref = modelpool.load_model("merge")

        # Construct the PGE mask model over the merged model's parameters.
        mask_model = MaskModel(
            merge_model,
            ignore_untrained_params=True,
            parameter_type="logits",
        )
        if self.merge_dtype is not None:
            mask_model.to(self.merge_dtype)
        mask_model.fill_(self.config.initial_logits)
        # TODO: ablation study for the initialization of mask model
        print("Summary of mask model:")
        print_parameters(mask_model)

        # One learnable 3x224x224 additive perturbation per TTA dataset.
        # NOTE(review): the attribute name keeps the original spelling
        # ("pertubed") for backward compatibility with existing configs
        # and checkpoints.
        self.pertubed_model = nn.Module()
        self.pertubed_model.perturbed_input = nn.Parameter(
            torch.zeros([len(self.modelpool.config.tta_datasets), 3, 224, 224]),
            requires_grad=True,
        )

        return merge_model, merge_model_ref, pretrained_model, mask_model

    def train_mask(
        self, merge_model, merge_model_ref, pretrained_model, mask_model: MaskModel
    ):
        """Jointly train the adversarial perturbations and the mask model.

        Args:
            merge_model: merged model to defend (kept frozen).
            merge_model_ref: frozen reference copy of the merged model.
            pretrained_model: pretrained base model (kept frozen).
            mask_model: trainable mask over the task vector.

        Raises:
            ValueError: if ``config.optimizer`` is not a supported optimizer.
        """
        config = self.config

        # configure optimizers
        lr_scheduler = None
        if self.config.optimizer == "adam":
            optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
            mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
        else:
            # Previously an unsupported optimizer left `optimizer` unbound and
            # crashed later with a NameError; fail fast with a clear message.
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")

        batch_opt_adv = torch.optim.Adam(
            params=self.pertubed_model.parameters(), lr=self.config.adv_lr
        )
        self.pertubed_model, batch_opt_adv = self.fabric.setup(
            self.pertubed_model, batch_opt_adv
        )

        # Only the mask and the perturbations are trainable.
        merge_model.requires_grad_(False)
        merge_model_ref.requires_grad_(False)
        pretrained_model.requires_grad_(False)

        mask_model.train()
        optimizer.zero_grad()

        self.pertubed_model.train()
        batch_opt_adv.zero_grad()

        pretrained_model_dict = pretrained_model.state_dict(keep_vars=True)
        for step_idx in (
            pbar := tqdm(
                range(self.config.max_steps if not self.is_debug_mode else 5),
                ("[DEBUG MODE] " if self.is_debug_mode else "")
                + "clip_post_defense_SAU",
                dynamic_ncols=True,
                disable=not self.fabric.is_global_zero,
            )
        ):
            metrics = {}
            # sample a shared mask and merge weights
            with self.profile("sample mask"):
                mask = mask_model.sample_mask(
                    mask_type="continuous", temperature=config.temperature
                )
                metrics["train/sparsity"] = mask_sparsity(mask)
            with self.profile("merge weights"):
                # rescale mask to unit mean so the expected weight is preserved
                for name, m in mask.items():
                    mask[name] = m / torch.mean(m)

                merged_state_dict = merge_model.state_dict(keep_vars=True)
                for name in merged_state_dict:
                    # Mask the task vector (merged - pretrained), similar to a
                    # concrete mask; initial logits can then be set to 0.
                    # (Directly pruning the merged model instead would need
                    # initial logits larger than 3.)
                    merged_state_dict[name] = (
                        merged_state_dict[name] - pretrained_model_dict[name]
                    ) * mask[name] + pretrained_model_dict[name]

            # ------ phase 1: perturbation optimization (mask detached) ------
            detached_merged_state_dict = {
                k: p.detach() for k, p in merged_state_dict.items()
            }
            merge_model_forward = lambda *args, **kwargs: torch.func.functional_call(
                merge_model, detached_merged_state_dict, args=args, kwargs=kwargs
            )

            total_loss = None
            for task_idx, task in enumerate(
                [c["name"] for c in self.modelpool.config.tta_datasets]
            ):
                with self.profile("data loading"):
                    # NOTE: labels must not be used during test-time adaptation
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    images = batch[0]
                    perturbed_images = (
                        images + self.pertubed_model.perturbed_input[task_idx]
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    num_image = images.size(0)

                    combined_logits = self.compute_logits(
                        merge_model_forward, combined_images, task
                    )
                    logits, logits_adv = (
                        combined_logits[:num_image],
                        combined_logits[num_image:],
                    )
                    ori_label = torch.argmax(logits, dim=1).long()
                    pert_label = torch.argmax(logits_adv, dim=1).long()

                    combined_logits_ref = self.compute_logits(
                        merge_model_ref, combined_images, task
                    )
                    logits_ref, logits_adv_ref = (
                        combined_logits_ref[:num_image],
                        combined_logits_ref[num_image:],
                    )
                    ori_label_ref = torch.argmax(logits_ref, dim=1).long()
                    pert_label_ref = torch.argmax(logits_adv_ref, dim=1).long()

                    # A "shared" attack fools both models into the same
                    # wrong class.
                    success_attack = pert_label != ori_label
                    success_attack_ref = pert_label_ref != ori_label_ref
                    common_attack = torch.logical_and(
                        success_attack, success_attack_ref
                    )
                    shared_attack = torch.logical_and(
                        common_attack, pert_label == pert_label_ref
                    )

                    # Shared loss: JS divergence between the two models'
                    # perturbed predictions, summed over non-shared samples
                    # (https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence).
                    p_model = F.softmax(logits_adv, dim=1).clamp(min=1e-8)
                    p_ref = F.softmax(logits_adv_ref, dim=1).clamp(min=1e-8)
                    mix_p = 0.5 * (p_model + p_ref)
                    loss_js = 0.5 * (
                        p_model * p_model.log() + p_ref * p_ref.log()
                    ) - 0.5 * (p_model * mix_p.log() + p_ref * mix_p.log())
                    loss_cross = (
                        loss_js[torch.logical_not(shared_attack)].sum(dim=1).sum()
                        / images.shape[0]
                    )

                    # Maximization perturbation loss, using test data without
                    # true labels (clean argmax as pseudo-labels). (The
                    # original wrapped this scalar in a redundant torch.mean;
                    # removed here.)
                    loss_adv = -F.cross_entropy(
                        logits_adv, ori_label, reduction="mean"
                    )

                    loss = self.config.beta1 * loss_adv + self.config.beta2 * loss_cross
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("batch_opt_adv optimizer step"):
                batch_opt_adv.step()
                batch_opt_adv.zero_grad()

            # ------ phase 2: mask optimization (perturbation frozen) ------
            # NOTE: algorithmic parameters of task arithmetic are assumed to
            # be chosen on a validation set, so no inner optimization is done.
            merge_model_forward = lambda *args, **kwargs: torch.func.functional_call(
                merge_model, merged_state_dict, args=args, kwargs=kwargs
            )
            total_loss = None

            for task_idx, task in enumerate(
                [c["name"] for c in self.modelpool.config.tta_datasets]
            ):
                # num_classes is needed for the one-hot in loss_shared.
                classnames, _templates = get_classnames_and_templates(
                    self.modelpool.get_train_dataset_config(task)["dataset"].name
                )
                num_classes = len(classnames)

                with self.profile("data loading"), torch.no_grad():
                    # NOTE: labels must not be used during test-time adaptation
                    batch = next(self.get_shuffled_test_loader_iter(task))
                    images = batch[0]

                    # Clamp perturbed inputs to the valid image range.
                    perturbed_images = torch.clamp(
                        images + self.pertubed_model.perturbed_input[task_idx],
                        min=0,
                        max=1,
                    )
                    combined_images = torch.cat((images, perturbed_images), dim=0)

                with self.profile("forward pass"):
                    num_image = images.size(0)

                    # Natural loss: entropy on clean predictions.
                    combined_logits = self.compute_logits(
                        merge_model_forward, combined_images, task
                    )
                    logits, logits_adv = (
                        combined_logits[:num_image],
                        combined_logits[num_image:],
                    )
                    ori_label = torch.argmax(logits, dim=1).long()
                    loss_nat = entropy_loss(logits)

                    # Robustness regularizer from the perturbation
                    # ("regu2" variant): entropy on perturbed predictions.
                    loss_regu = entropy_loss(logits_adv)

                    # Reference model predictions for the shared loss.
                    combined_logits_ref = self.compute_logits(
                        merge_model_ref, combined_images, task
                    )
                    logits_ref, logits_adv_ref = (
                        combined_logits_ref[:num_image],
                        combined_logits_ref[num_image:],
                    )
                    ori_label_ref = torch.argmax(logits_ref, dim=1).long()
                    pert_label_ref = torch.argmax(logits_adv_ref, dim=1).long()

                    # Since only unlabeled test data is available here, the
                    # clean argmax (ori_label) stands in for the true label.
                    # (Dead computations of pert_label / success_attack /
                    # common_attack / shared_attack from the original were
                    # removed -- none of them fed into the loss.)
                    success_attack_ref = pert_label_ref != ori_label
                    success_attack_ref = success_attack_ref & (
                        pert_label_ref != ori_label_ref
                    )

                    potential_poison = success_attack_ref
                    if potential_poison.sum() == 0:
                        loss_shared = torch.tensor(0.0).to(merge_model.device)
                    else:
                        one_hot = F.one_hot(pert_label_ref, num_classes=num_classes)

                        neg_one_hot = 1 - one_hot
                        neg_p = (F.softmax(logits_adv, dim=1) * neg_one_hot).sum(
                            dim=1
                        )[potential_poison]
                        pos_p = (F.softmax(logits_adv, dim=1) * one_hot).sum(dim=1)[
                            potential_poison
                        ]

                        # Clamp too-small values to avoid NaN and discard
                        # samples with p < 1% as "shared".
                        # Note: the two terms below are identical in math but
                        # differ numerically; combining them reduces the
                        # numerical issue.
                        loss_shared = (
                            -torch.sum(torch.log(1e-6 + neg_p.clamp(max=0.999)))
                            - torch.sum(torch.log(1 + 1e-6 - pos_p.clamp(min=0.001)))
                        ) / 2
                        loss_shared = loss_shared / images.shape[0]

                    loss = (
                        loss_nat
                        + self.config.adv_weight * loss_regu
                        + self.config.shared_weight * loss_shared
                    )
                    total_loss = loss if total_loss is None else total_loss + loss

            with self.profile("compute grad"):
                self.fabric.backward(total_loss)

            with self.profile("optimizer step"):
                optimizer.step()
                optimizer.zero_grad()

            if lr_scheduler is not None:
                lr_scheduler.step()
            # loss/loss_nat/loss_regu/loss_shared refer to the last task.
            metrics.update(
                {
                    "train/loss": loss.item(),
                    "loss_nat": loss_nat.item(),
                    "loss_regu": loss_regu.item(),
                    "loss_shared": loss_shared.item(),
                }
            )
            self.fabric.log_dict(metrics, step=step_idx)
            pbar.set_postfix(metrics)
            self.print_profile_summary()

            if (step_idx + 1) % self.config.save_interval == 0:
                with self.profiler.profile("save checkpoint"):
                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
                    os.makedirs(save_dir, exist_ok=True)
                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
                    print(f"saving checkpoint to {save_path}")
                    state = {"model": mask_model}
                    self.fabric.save(save_path, state)

                    # Create or update a hard link to the latest checkpoint.
                    # (os.link creates a hard link, not a symlink; the
                    # original comment said "symbolic link".)
                    if self.fabric.is_global_zero:
                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
                        if os.path.exists(symlink_path):
                            os.remove(symlink_path)
                        os.link(os.path.abspath(save_path), symlink_path)

        self.print_profile_summary()

    def run(self, modelpool: HuggingFaceClipVisionPool):
        """Run the SAU post-defense and return the masked merged model.

        Args:
            modelpool: pool providing the ``_pretrained_`` and ``merge``
                models plus ``tta_datasets``.

        Returns:
            The merged model with the learned mask applied to its task vector.
        """
        self.modelpool = to_modelpool(modelpool)
        config = self.config
        self.log_hyperparams(config, filename="method_config.yaml")

        with self.profile("setup models"):
            merge_model, merge_model_ref, pretrained_model, mask_model = (
                self.setup_models()
            )
            mask_model: MaskModel = self.fabric.to_device(mask_model)
            merge_model = self.fabric.to_device(merge_model)
            merge_model_ref = self.fabric.to_device(merge_model_ref)
            pretrained_model = self.fabric.to_device(pretrained_model)
            self.pertubed_model = self.fabric.to_device(self.pertubed_model)
            self.setup_zero_shot_classification_head(
                task_names=[c["name"] for c in self.modelpool.config.tta_datasets]
            )

        if config.mask_checkpoint is None:
            self.train_mask(
                merge_model=merge_model,
                merge_model_ref=merge_model_ref,
                pretrained_model=pretrained_model,
                mask_model=mask_model,
            )
        else:
            if self.fabric.is_global_zero:
                print("loading mask from checkpoint", config.mask_checkpoint)
            self.fabric.load(config.mask_checkpoint, {"model": mask_model})

        with torch.no_grad():
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            mask = mask_model.sample_mask(
                mask_type=config.eval_mask_type,
                temperature=config.temperature,
            )
            # rescale mask to unit mean, matching training
            for name, m in mask.items():
                mask[name] = m / torch.mean(m)
            pretrained_model_dict = pretrained_model.state_dict(keep_vars=True)
            merged_state_dict = merge_model.state_dict(keep_vars=True)
            for name in merged_state_dict:
                # Apply the mask to the task vector, exactly as in training.
                merged_state_dict[name] = (
                    merged_state_dict[name] - pretrained_model_dict[name]
                ) * mask[name] + pretrained_model_dict[name]
            merge_model.load_state_dict(merged_state_dict)
        return merge_model