fusion-bench 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. fusion_bench/compat/method/__init__.py +5 -0
  2. fusion_bench/dataset/fer2013.py +1 -0
  3. fusion_bench/method/DOGE_TA/DOGE_TA.py +364 -0
  4. fusion_bench/method/DOGE_TA/__init__.py +2 -0
  5. fusion_bench/method/DOGE_TA/clip_layer_wise_adamerging.py +46 -0
  6. fusion_bench/method/DOGE_TA/layer_wise_adamerging.py +250 -0
  7. fusion_bench/method/__init__.py +22 -0
  8. fusion_bench/method/classification/continual_clip_finetune.py +1 -1
  9. fusion_bench/method/concrete_subspace/__init__.py +8 -0
  10. fusion_bench/method/concrete_subspace/clip_post_defense.py +744 -0
  11. fusion_bench/method/concrete_subspace/clip_safe_concrete_adamerging.py +832 -0
  12. fusion_bench/method/isotropic_merging/__init__.py +15 -0
  13. fusion_bench/method/isotropic_merging/iso.py +114 -0
  14. fusion_bench/method/isotropic_merging/iso_utils.py +176 -0
  15. fusion_bench/method/task_singular_vector/TSVM.py +22 -2
  16. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +531 -0
  17. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/METADATA +1 -1
  18. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/RECORD +30 -13
  19. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/WHEEL +1 -1
  20. fusion_bench_config/method/DOGE_TA/DOGE_TA.yaml +4 -0
  21. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +38 -0
  22. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +41 -0
  23. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +39 -0
  24. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +40 -0
  25. fusion_bench_config/method/isotropic_merging/iso_c.yaml +4 -0
  26. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +5 -0
  27. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +6 -0
  28. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/LICENSE +0 -0
  29. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/entry_points.txt +0 -0
  30. {fusion_bench-0.2.9.dist-info → fusion_bench-0.2.11.dist-info}/top_level.txt +0 -0
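The new methods ship with Hydra-style configs (items 20–27 in the list above). As a quick orientation, a hedged sketch of upgrading and invoking one of them is shown below; it mirrors the usage examples in the module docstring of the new `clip_safe_concrete_adamerging.py` file further down, and the empty `fabric_logger.name=`, `modelpool=`, and `taskpool=` values are placeholders to be filled with configs available in your installation.

```bash
# Upgrade to the new release (assumes a PyPI install).
pip install -U "fusion-bench==0.2.11"

# Sketch: run one of the newly added methods via the fusion_bench CLI,
# following the pattern of the docstring examples below. The empty values
# are placeholders for your own logger/modelpool/taskpool config names.
fusion_bench \
    fabric_logger.name= \
    method=clip_safe_concrete_layer_wise_adamerging \
    modelpool= \
    taskpool=
```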
@@ -0,0 +1,832 @@
+ """
+ Defense-Aware Task-wise & Layer-wise Concrete AdaMerging for CLIP ViT models
+
+ Examples:
+
+ ```bash
+ fusion_bench \
+     fabric_logger.name= \
+     method=clip_safe_concrete_task_wise_adamerging \
+     modelpool= \
+     taskpool=
+ ```
+
+ ```bash
+ fusion_bench \
+     fabric_logger.name= \
+     method=clip_safe_concrete_layer_wise_adamerging \
+     modelpool= \
+     taskpool=
+ ```
+ """
+
+ import logging
+ import os
+ from copy import deepcopy
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from tqdm.autonotebook import tqdm
+
+ from fusion_bench.compat.method import ModelFusionAlgorithm
+ from fusion_bench.compat.modelpool import to_modelpool
+ from fusion_bench.compat.modelpool.huggingface_clip_vision import (
+     HuggingFaceClipVisionPool,
+ )
+ from fusion_bench.method.adamerging.entropy_loss import entropy_loss
+ from fusion_bench.mixins import CLIPClassificationMixin
+ from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+ from fusion_bench.models.masks import MaskModel, mask_sparsity
+ from fusion_bench.models.wrappers.layer_wise_fusion import (
+     LayerWiseMergedModel,
+     get_layer_wise_weights,
+ )
+ from fusion_bench.models.wrappers.task_wise_fusion import (
+     TaskWiseMergedModel,
+     get_task_wise_weights,
+ )
+ from fusion_bench.utils.dtype import parse_dtype
+ from fusion_bench.utils.parameters import print_parameters
+
+ log = logging.getLogger(__name__)
+
+
+ class ConcreteSafeTaskWiseAdaMergingForCLIP(
+     CLIPClassificationMixin,
+     SimpleProfilerMixin,
+     ModelFusionAlgorithm,
+ ):
+     @torch.no_grad()
+     def setup_models(self):
+         config = self.config
+         self.merge_dtype = parse_dtype(config.get("merge_dtype", None))
+         modelpool = self.modelpool
+         # Load the pretrained model
+         pretrained_model = modelpool.load_model("_pretrained_")
+
+         # construct PGE mask model
+         mask_model = MaskModel(
+             pretrained_model,
+             ignore_untrained_params=True,
+             parameter_type="logits",
+         )
+         if self.merge_dtype is not None:
+             mask_model.to(self.merge_dtype)
+         mask_model.fill_(self.config.initial_logits)
+         # TODO: ablation study for the initialization of mask model
+         # for param in mask_model.parameters():
+         #     param.data = param + 0.1 * torch.randn_like(param)
+         print("Summary of mask model:")
+         print_parameters(mask_model)
+
+         # Load the fine-tuned models
+         finetuned_models = [
+             modelpool.load_model(name) for name in modelpool.model_names
+         ]
+
+         task_wise_weight = get_task_wise_weights(
+             num_models=len(modelpool.model_names),
+             init_values=self.config.scaling_factor,
+         )
+         self.init_task_wise_weight = deepcopy(task_wise_weight)
+
+         # create a wrapped model
+         module = TaskWiseMergedModel(
+             task_wise_weight=task_wise_weight,
+             pretrained_model=pretrained_model,
+             finetuned_models=finetuned_models,
+             clamp_weights=self.config.clamp_weights,
+             tie_weights=self.config.tie_weights,
+             strict=self.config.strict,
+             task_vector_dtype=self.merge_dtype,
+         )
+
+         self.pertubed_model = nn.Module()
+         self.pertubed_model.perturbed_input = nn.Parameter(
+             torch.zeros([len(modelpool.model_names), 3, 224, 224]), requires_grad=True
+         )
+         return module, mask_model
+
+     def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
+         config = self.config
+         self.init_task_wise_weight = self.to_device(self.init_task_wise_weight)
+
+         # configure optimizer
+         lr_scheduler = None
+         if self.config.optimizer == "adam":
+
+             ### for merge_weight
+             base_optimizer = torch.optim.Adam(
+                 [module.merge_weight], lr=self.config.base_lr
+             )
+             module, base_optimizer = self.fabric.setup(module, base_optimizer)
+
+             ### for mask
+             optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
+             mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
+
+             ### for perturbed noise
+             batch_opt_adv = torch.optim.Adam(
+                 params=self.pertubed_model.parameters(), lr=self.config.adv_lr
+             )
+             self.pertubed_model, batch_opt_adv = self.fabric.setup(
+                 self.pertubed_model, batch_opt_adv
+             )
+         else:
+             raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+         module.train()
+         mask_model.train()
+         self.pertubed_model.train()
+         for step_idx in (
+             pbar := tqdm(
+                 range(self.config.max_steps if not self.is_debug_mode else 5),
+                 ("[DEBUG MODE] " if self.is_debug_mode else "")
+                 + "Concrete Safe AdaMerging Meta-Learn Mask (1/2)",
+                 dynamic_ncols=True,
+                 disable=not self.fabric.is_global_zero,
+             )
+         ):
+             metrics = {}
+             # sample a shared mask and merge weights
+             with self.profile("sample mask"):
+                 mask = mask_model.sample_mask(
+                     mask_type="continuous", temperature=config.temperature
+                 )
+                 metrics["train/sparsity"] = mask_sparsity(mask)
+             with self.profile("merge weights"):
+                 # rescale mask
+                 for name, m in mask.items():
+                     mask[name] = m / torch.mean(m)
+
+                 # for inner optimization, we do not optimize the mask, so we detach it
+                 module.merge_weights(
+                     task_vector_mask={name: m.detach() for name, m in mask.items()}
+                 )
+
+             # ------ inner optimization goes here ------
+
+             ### (1)optimize the merging weight
+             module.merge_weight.data = deepcopy(self.init_task_wise_weight)
+             total_loss = None
+             for task in self.modelpool.model_names:
+                 with self.profile("data loading"):
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                     # NOTE: The labels are not allowed to be used during test-time adaptation
+                     images = batch[0]
+
+                 with self.profile("forward pass"):
+                     logits = self.compute_logits(module, images, task)
+                     loss = entropy_loss(logits)
+                     total_loss = loss if total_loss is None else total_loss + loss
+
+             with self.profile("compute grad"):
+                 self.fabric.backward(total_loss)
+
+             with self.profile("base optimizer step"):
+                 base_optimizer.step()
+                 base_optimizer.zero_grad()
+
+             with self.profile("merge weights"):
+                 module.merge_weights(task_vector_mask=mask)
+
+             # (2)noise optimization based on the merging model
+
+             # detach merged state_dict
+             merged_state_dict = module._merged_state_dict
+             detached_merged_state_dict = {
+                 k: p.detach() for k, p in merged_state_dict.items()
+             }
+             module._merged_state_dict = detached_merged_state_dict
+
+             total_loss = None
+             for task_idx, task in enumerate(self.modelpool.model_names):
+                 with self.profile("data loading"):
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                     # NOTE: The labels are not allowed to be used during test-time adaptation
+                     images = batch[0]
+                     perturbed_images = (
+                         images + self.pertubed_model.perturbed_input[task_idx]
+                     )
+                     combined_images = torch.cat((images, perturbed_images), dim=0)
+
+                 with self.profile("forward pass"):
+                     combined_logits = self.compute_logits(module, combined_images, task)
+                     logits = combined_logits[: images.size(0)]
+                     logits_adv = combined_logits[images.size(0) :]
+                     ori_label = torch.argmax(logits, axis=1).long()
+                     loss = torch.mean(
+                         -F.cross_entropy(logits_adv, ori_label, reduction="mean")
+                     )
+                     total_loss = loss if total_loss is None else total_loss + loss
+
+             with self.profile("compute grad"):
+                 self.fabric.backward(total_loss)
+
+             with self.profile("batch_opt_adv optimizer step"):
+                 batch_opt_adv.step()
+                 batch_opt_adv.zero_grad()
+
+             # (3)mask optimization
+             total_loss = None
+             module._merged_state_dict = merged_state_dict
+
+             for task_idx, task in enumerate(self.modelpool.model_names):
+                 with self.profile("data loading"), torch.no_grad():
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                     # NOTE: The labels are not allowed to be used during test-time adaptation
+                     images = batch[0]
+                     perturbed_images = (
+                         images + self.pertubed_model.perturbed_input[task_idx]
+                     )
+                     perturbed_images = torch.clamp(perturbed_images, min=0, max=1)
+                     combined_images = torch.cat((images, perturbed_images), dim=0)
+
+                 with self.profile("forward pass"):
+                     combined_logits = self.compute_logits(module, combined_images, task)
+                     logits = combined_logits[: images.size(0)]
+                     logits_adv = combined_logits[images.size(0) :]
+
+                     # # ### regu1
+                     # ori_label = torch.argmax(logits, axis=1).long()
+                     # loss_nat = entropy_loss(logits)
+                     # loss_regu = torch.mean(-F.cross_entropy(logits_adv, ori_label, reduction='mean'))
+
+                     ### regu2
+                     loss_regu = entropy_loss(logits_adv)
+                     loss_nat = entropy_loss(logits)
+
+                     loss = loss_nat + self.config.adv_weight * loss_regu
+                     total_loss = loss if total_loss is None else total_loss + loss
+
+             with self.profile("compute grad"):
+                 self.fabric.backward(total_loss)
+
+             with self.profile("optimizer step"):
+                 optimizer.step()
+                 optimizer.zero_grad()
+
+             if lr_scheduler is not None:
+                 lr_scheduler.step()
+
+             # metrics.update({"train/loss": loss.item()})
+             metrics.update(
+                 {
+                     "train/loss": loss.item(),
+                     "loss_nat": loss_nat.item(),
+                     "loss_regu": loss_regu.item(),
+                 }
+             )
+             self.fabric.log_dict(metrics, step=step_idx)
+             pbar.set_postfix(metrics)
+             self.print_profile_summary()
+
+             if (step_idx + 1) % self.config.save_interval == 0:
+                 with self.profiler.profile("save checkpoint"):
+                     save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
+                     if not os.path.exists(save_dir):
+                         os.makedirs(save_dir, exist_ok=True)
+                     save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
+                     print(f"saving checkpoint to {save_path}")
+                     state = {"model": mask_model}
+                     self.fabric.save(save_path, state)
+
+                     # Create or update a symbolic link to the latest checkpoint
+                     if self.fabric.is_global_zero:
+                         symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
+                         if os.path.exists(symlink_path):
+                             os.remove(symlink_path)
+                         os.link(os.path.abspath(save_path), symlink_path)
+
+         self.print_profile_summary()
+
+     def run_adamerging(self, module: TaskWiseMergedModel, mask):
+         module.merge_weight.data = deepcopy(self.init_task_wise_weight)
+         base_optimizer = torch.optim.Adam(
+             [module.merge_weight], lr=self.config.adamerging_lr
+         )
+         module, base_optimizer = self.fabric.setup(module, base_optimizer)
+         module.train()
+         for step_idx in (
+             pbar := tqdm(
+                 range(
+                     self.config.max_adamerging_steps if not self.is_debug_mode else 5
+                 ),
+                 ("[DEBUG MODE] " if self.is_debug_mode else "")
+                 + "Concrete AdaMerging AdaMerging (2/2)",
+                 dynamic_ncols=True,
+                 disable=not self.fabric.is_global_zero,
+             )
+         ):
+             step_idx = step_idx + self.config.max_steps
+             with self.profile("merge weights"):
+                 module.merge_weights(task_vector_mask=mask)
+
+             metrics = {}
+             total_loss = None
+             for task in self.modelpool.model_names:
+                 with self.profile("data loading"):
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                     # NOTE: The labels are not allowed to be used during test-time adaptation
+                     images = batch[0]
+                 with self.profile("forward pass"):
+                     logits = self.compute_logits(module, images, task)
+                     loss = entropy_loss(logits)
+                     total_loss = loss if total_loss is None else total_loss + loss
+
+             with self.profile("compute grad"):
+                 self.fabric.backward(total_loss)
+
+             with self.profile("base optimizer step"):
+                 base_optimizer.step()
+                 base_optimizer.zero_grad()
+
+             metrics.update({"train/loss": loss.item()})
+             self.fabric.log_dict(metrics, step=step_idx)
+             pbar.set_postfix(metrics)
+
+             if (step_idx + 1) % self.config.save_interval == 0:
+                 with self.profiler.profile("save checkpoint"):
+                     save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
+                     if not os.path.exists(save_dir):
+                         os.makedirs(save_dir, exist_ok=True)
+                     save_path = os.path.join(save_dir, f"merge_weight_{step_idx}.pt")
+                     print(f"saving checkpoint to {save_path}")
+                     state = {"merge_weight": module.merge_weight}
+                     self.fabric.save(save_path, state)
+
+                     # Create or update a symbolic link to the latest checkpoint
+                     if self.fabric.is_global_zero:
+                         symlink_path = os.path.join(
+                             save_dir, "merge_weight_latest_checkpoint.pt"
+                         )
+                         if os.path.exists(symlink_path):
+                             os.remove(symlink_path)
+                         os.link(os.path.abspath(save_path), symlink_path)
+
+         self.print_profile_summary()
+         return module
+
+     def run(self, modelpool: HuggingFaceClipVisionPool):
+         self.modelpool = to_modelpool(modelpool)
+         config = self.config
+         self.log_hyperparams(config, filename="method_config.yaml")
+
+         with self.profile("setup models"):
+             module, mask_model = self.setup_models()
+             mask_model: MaskModel = self.fabric.to_device(mask_model)
+             module: TaskWiseMergedModel = self.fabric.to_device(module)
+             self.pertubed_model = self.fabric.to_device(self.pertubed_model)
+             self.setup_zero_shot_classification_head()
+
+         if config.mask_checkpoint is None:
+             self.train_mask(module=module, mask_model=mask_model)
+         else:
+             if self.fabric.is_global_zero:
+                 print("loading mask from checkpoint", config.mask_checkpoint)
+             self.fabric.load(config.mask_checkpoint, {"model": mask_model})
+
+         # run adamerging
+         with torch.no_grad():
+             mask = mask_model.sample_mask(
+                 mask_type=config.eval_mask_type,
+                 temperature=config.temperature,
+             )
+             # rescale mask
+             for name, m in mask.items():
+                 mask[name] = m / torch.mean(m)
+         module = self.run_adamerging(module, mask=mask)
+
+         with torch.no_grad():
+             model = module.merge_and_unload(mask)
+         return model
+
+
+ class ConcreteSafeLayerWiseAdaMergingForCLIP(
+     CLIPClassificationMixin,
+     SimpleProfilerMixin,
+     ModelFusionAlgorithm,
+ ):
+     @torch.no_grad()
+     def setup_models(self):
+         modelpool = self.modelpool
+         self.merge_dtype = parse_dtype(self.config.get("merge_dtype", None))
+         # Load the pretrained model
+         pretrained_model = modelpool.load_model("_pretrained_")
+
+         # construct PGE mask model
+         mask_model = MaskModel(
+             pretrained_model,
+             ignore_untrained_params=True,
+             parameter_type="logits",
+         )
+         if self.merge_dtype is not None:
+             mask_model.to(self.merge_dtype)
+         mask_model.fill_(self.config.initial_logits)
+         # TODO: ablation study for the initialization of mask model
+         # for param in mask_model.parameters():
+         #     param.data = param + 0.1 * torch.randn_like(param)
+         print("Summary of mask model:")
+         print_parameters(mask_model)
+
+         # Load the fine-tuned models
+         finetuned_models = [
+             modelpool.load_model(name) for name in modelpool.model_names
+         ]
+
+         layer_wise_weight = get_layer_wise_weights(
+             num_models=len(modelpool.model_names),
+             num_layers=len(
+                 tuple(filter(lambda p: p.requires_grad, pretrained_model.parameters()))
+             ),
+             init_values=self.config.scaling_factor,
+         )
+         self.init_layer_wise_weight = deepcopy(layer_wise_weight)
+
+         # create a wrapped model
+         module = LayerWiseMergedModel(
+             layer_wise_weight=layer_wise_weight,
+             pretrained_model=pretrained_model,
+             finetuned_models=finetuned_models,
+             clamp_weights=self.config.clamp_weights,
+             tie_weights=self.config.tie_weights,
+             strict=self.config.strict,
+             layer_vector_dtype=self.merge_dtype,
+         )
+
+         self.pertubed_model = nn.Module()
+         self.pertubed_model.perturbed_input = nn.Parameter(
+             torch.zeros([len(modelpool.model_names), 3, 224, 224]), requires_grad=True
+         )
+         return module, mask_model
+
+     def train_mask(self, module: LayerWiseMergedModel, mask_model: MaskModel):
+         config = self.config
+         self.init_layer_wise_weight = self.to_device(self.init_layer_wise_weight)
+
+         # configure optimizer
+         lr_scheduler = None
+         if self.config.optimizer == "adam":
+             base_optimizer = torch.optim.Adam(
+                 [module.merge_weight], lr=self.config.base_lr
+             )
+             optimizer = torch.optim.Adam(mask_model.parameters(), lr=self.config.lr)
+             print(f"{optimizer=}")
+             # TODO: ablation study for the learning rate scheduler. It should yield similar results.
+             # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+             #     optimizer, self.config.max_steps, eta_min=0.1
+             # )
+             module, base_optimizer = self.fabric.setup(module, base_optimizer)
+             mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
+
+             batch_opt_adv = torch.optim.Adam(
+                 params=self.pertubed_model.parameters(), lr=self.config.adv_lr
+             )
+             self.pertubed_model, batch_opt_adv = self.fabric.setup(
+                 self.pertubed_model, batch_opt_adv
+             )
+         else:
+             raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+         module.train()
+         mask_model.train()
+         # self.pertubed_model.train()
+         for step_idx in (
+             pbar := tqdm(
+                 range(self.config.max_steps if not self.is_debug_mode else 5),
+                 ("[DEBUG MODE] " if self.is_debug_mode else "")
+                 + "Concrete Safe AdaMerging Meta-Learn Mask (1/2)",
+                 dynamic_ncols=True,
+                 disable=not self.fabric.is_global_zero,
+             )
+         ):
+             metrics = {}
+             # sample a shared mask and merge weights
+             with self.profile("sample mask"):
+                 mask = mask_model.sample_mask(
+                     mask_type="continuous", temperature=config.temperature
+                 )
+                 metrics["train/sparsity"] = mask_sparsity(mask)
+             with self.profile("merge weights"):
+                 # rescale mask
+                 for name, m in mask.items():
+                     mask[name] = m / torch.mean(m)
+
+                 # for inner optimization, we do not optimize the mask, so we detach it
+                 module.merge_weights(
+                     task_vector_mask={name: m.detach() for name, m in mask.items()}
+                 )
+
+             # ------ inner optimization goes here ------
+             module.merge_weight.data = deepcopy(self.init_layer_wise_weight)
+             total_loss = None
+             for task in self.modelpool.model_names:
+                 with self.profile("data loading"):
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                     # NOTE: The labels are not allowed to be used during test-time adaptation
+                     images = batch[0]
+                 with self.profile("forward pass"):
+                     logits = self.compute_logits(module, images, task)
+                     loss = entropy_loss(logits)
+                     total_loss = loss if total_loss is None else total_loss + loss
+
+             with self.profile("compute grad"):
+                 self.fabric.backward(total_loss)
+
+             with self.profile("base optimizer step"):
+                 base_optimizer.step()
+                 base_optimizer.zero_grad()
+
+             with self.profile("merge weights"):
+                 module.merge_weights(task_vector_mask=mask)
+
+             # ------------------------------------------
+
+             # (2)noise optimization based on the merging model
+
+             # detach merged state_dict
+             merged_state_dict = module._merged_state_dict
+             detached_merged_state_dict = {
+                 k: p.detach() for k, p in merged_state_dict.items()
+             }
+             module._merged_state_dict = detached_merged_state_dict
+
+             total_loss = None
+             for task_idx, task in enumerate(self.modelpool.model_names):
+                 with self.profile("data loading"):
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                     # NOTE: The labels are not allowed to be used during test-time adaptation
+                     images = batch[0]
+                     perturbed_images = (
+                         images + self.pertubed_model.perturbed_input[task_idx]
+                     )
+                     combined_images = torch.cat((images, perturbed_images), dim=0)
+
+                 with self.profile("forward pass"):
+                     combined_logits = self.compute_logits(module, combined_images, task)
+                     logits = combined_logits[: images.size(0)]
+                     logits_adv = combined_logits[images.size(0) :]
+                     ori_label = torch.argmax(logits, axis=1).long()
+                     loss = torch.mean(
+                         -F.cross_entropy(logits_adv, ori_label, reduction="mean")
+                     )
+                     total_loss = loss if total_loss is None else total_loss + loss
+
+             with self.profile("compute grad"):
+                 self.fabric.backward(total_loss)
+
+             with self.profile("batch_opt_adv optimizer step"):
+                 batch_opt_adv.step()
+                 batch_opt_adv.zero_grad()
+
+             # (3)mask optimization
+             total_loss = None
+             module._merged_state_dict = merged_state_dict
+
+             for task_idx, task in enumerate(self.modelpool.model_names):
+                 with self.profile("data loading"), torch.no_grad():
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                     # NOTE: The labels are not allowed to be used during test-time adaptation
+                     images = batch[0]
+                     perturbed_images = (
+                         images + self.pertubed_model.perturbed_input[task_idx]
+                     )
+                     perturbed_images = torch.clamp(perturbed_images, min=0, max=1)
+                     combined_images = torch.cat((images, perturbed_images), dim=0)
+
+                 with self.profile("forward pass"):
+                     combined_logits = self.compute_logits(module, combined_images, task)
+                     logits = combined_logits[: images.size(0)]
+                     logits_adv = combined_logits[images.size(0) :]
+
+                     # # ### regu1
+                     # ori_label = torch.argmax(logits, axis=1).long()
+                     # loss_nat = entropy_loss(logits)
+                     # loss_regu = torch.mean(-F.cross_entropy(logits_adv, ori_label, reduction='mean'))
+
+                     ### regu2
+                     loss_regu = entropy_loss(logits_adv)
+                     loss_nat = entropy_loss(logits)
+
+                     loss = loss_nat + self.config.adv_weight * loss_regu
+                     total_loss = loss if total_loss is None else total_loss + loss
+
+             with self.profile("compute grad"):
+                 self.fabric.backward(total_loss)
+
+             with self.profile("optimizer step"):
+                 optimizer.step()
+                 optimizer.zero_grad()
+
+             if lr_scheduler is not None:
+                 lr_scheduler.step()
+
+             # metrics.update({"train/loss": loss.item()})
+             metrics.update(
+                 {
+                     "train/loss": loss.item(),
+                     "loss_nat": loss_nat.item(),
+                     "loss_regu": loss_regu.item(),
+                 }
+             )
+             self.fabric.log_dict(metrics, step=step_idx)
+             pbar.set_postfix(metrics)
+             self.print_profile_summary()
+
+             if (step_idx + 1) % self.config.save_interval == 0:
+                 with self.profiler.profile("save checkpoint"):
+                     save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
+                     if not os.path.exists(save_dir):
+                         os.makedirs(save_dir, exist_ok=True)
+                     save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
+                     print(f"saving checkpoint to {save_path}")
+                     state = {"model": mask_model}
+                     self.fabric.save(save_path, state)
+
+                     # Create or update a symbolic link to the latest checkpoint
+                     if self.fabric.is_global_zero:
+                         symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
+                         if os.path.exists(symlink_path):
+                             os.remove(symlink_path)
+                         os.link(os.path.abspath(save_path), symlink_path)
+
+         self.print_profile_summary()
+
+     def run_adamerging(self, module: LayerWiseMergedModel, mask):
+         module.merge_weight.data = deepcopy(self.init_layer_wise_weight)
+         base_optimizer = torch.optim.Adam(
+             [module.merge_weight], lr=self.config.adamerging_lr
+         )
+         module, base_optimizer = self.fabric.setup(module, base_optimizer)
+         module.train()
+         for step_idx in (
+             pbar := tqdm(
+                 range(
+                     self.config.max_adamerging_steps if not self.is_debug_mode else 5
+                 ),
+                 ("[DEBUG MODE] " if self.is_debug_mode else "")
+                 + "Concrete AdaMerging AdaMerging (2/2)",
+                 dynamic_ncols=True,
+                 disable=not self.fabric.is_global_zero,
+             )
+         ):
+             step_idx = step_idx + self.config.max_steps
+             with self.profile("merge weights"):
+                 module.merge_weights(task_vector_mask=mask)
+
+             metrics = {}
+             total_loss = None
+             for task in self.modelpool.model_names:
+                 with self.profile("data loading"):
+                     batch = next(self.get_shuffled_test_loader_iter(task))
+                     # NOTE: The labels are not allowed to be used during test-time adaptation
+                     images = batch[0]
+                 with self.profile("forward pass"):
+                     logits = self.compute_logits(module, images, task)
+                     loss = entropy_loss(logits)
+                     total_loss = loss if total_loss is None else total_loss + loss
+
+             with self.profile("compute grad"):
+                 self.fabric.backward(total_loss)
+
+             with self.profile("base optimizer step"):
+                 base_optimizer.step()
+                 base_optimizer.zero_grad()
+
+             metrics.update({"train/loss": loss.item()})
+             self.fabric.log_dict(metrics, step=step_idx)
+             pbar.set_postfix(metrics)
+
+             if (step_idx + 1) % self.config.save_interval == 0:
+                 with self.profiler.profile("save checkpoint"):
+                     save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
+                     if not os.path.exists(save_dir):
+                         os.makedirs(save_dir, exist_ok=True)
+                     save_path = os.path.join(save_dir, f"merge_weight_{step_idx}.pt")
+                     print(f"saving checkpoint to {save_path}")
+                     state = {"merge_weight": module.merge_weight}
+                     self.fabric.save(save_path, state)
+
+                     # Create or update a symbolic link to the latest checkpoint
+                     if self.fabric.is_global_zero:
+                         symlink_path = os.path.join(
+                             save_dir, "merge_weight_latest_checkpoint.pt"
+                         )
+                         if os.path.exists(symlink_path):
+                             os.remove(symlink_path)
+                         os.link(os.path.abspath(save_path), symlink_path)
+
+         self.print_profile_summary()
+         return module
+
+     # def run_adamerging(self, module: LayerWiseMergedModel, mask):
+     #     module.merge_weight.data = deepcopy(self.init_layer_wise_weight)
+     #     base_optimizer = torch.optim.Adam(
+     #         [module.merge_weight], lr=self.config.adamerging_lr
+     #     )
+     #     module, base_optimizer = self.fabric.setup(module, base_optimizer)
+     #     module.train()
+     #     for step_idx in (
+     #         pbar := tqdm(
+     #             range(
+     #                 self.config.max_adamerging_steps if not self.is_debug_mode else 5
+     #             ),
+     #             ("[DEBUG MODE] " if self.is_debug_mode else "")
+     #             + "Concrete AdaMerging AdaMerging (2/2)",
+     #             dynamic_ncols=True,
+     #             disable=not self.fabric.is_global_zero,
+     #         )
+     #     ):
+     #         step_idx = step_idx + self.config.max_steps
+     #         with self.profile("merge weights"):
+     #             module.merge_weights(task_vector_mask=mask)
+
+     #         metrics = {}
+     #         total_loss = None
+     #         for task_idx, task in enumerate(self.modelpool.model_names):
+     #             with self.profile("data loading"), torch.no_grad():
+     #                 batch = next(self.get_shuffled_test_loader_iter(task))
+     #                 # NOTE: The labels are not allowed to be used during test-time adaptation
+     #                 images = batch[0]
+     #                 perturbed_images = images + self.pertubed_model.perturbed_input[task_idx]
+     #                 perturbed_images = torch.clamp(perturbed_images, min=0, max=1)
+     #                 combined_images = torch.cat((images, perturbed_images), dim=0)
+
+     #             with self.profile("forward pass"):
+     #                 combined_logits = self.compute_logits(module, combined_images, task)
+     #                 logits = combined_logits[:images.size(0)]
+     #                 logits_adv = combined_logits[images.size(0):]
+
+     #                 # ### regu1
+     #                 # ori_label = torch.argmax(logits, axis=1).long()
+     #                 # loss_nat = entropy_loss(logits)
+     #                 # loss_regu = torch.mean(-F.cross_entropy(logits_adv, ori_label, reduction='mean'))
+
+     #                 ### regu2
+     #                 loss_regu = entropy_loss(logits_adv)
+     #                 loss_nat = entropy_loss(logits)
+
+     #                 loss = loss_nat + self.config.adv_weight*loss_regu
+     #                 total_loss = loss if total_loss is None else total_loss + loss
+     #                 metrics.update({"train/loss": loss.item(),"loss_nat": loss_nat.item(),"loss_regu": loss_regu.item()})
+
+     #         self.fabric.log_dict(metrics, step=step_idx)
+     #         pbar.set_postfix(metrics)
+     #         self.print_profile_summary()
+
+     #         if (step_idx + 1) % self.config.save_interval == 0:
+     #             with self.profiler.profile("save checkpoint"):
+     #                 save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
+     #                 if not os.path.exists(save_dir):
+     #                     os.makedirs(save_dir, exist_ok=True)
+     #                 save_path = os.path.join(save_dir, f"merge_weight_{step_idx}.pt")
+     #                 print(f"saving checkpoint to {save_path}")
+     #                 state = {"merge_weight": module.merge_weight}
+     #                 self.fabric.save(save_path, state)
+
+     #                 # Create or update a symbolic link to the latest checkpoint
+     #                 if self.fabric.is_global_zero:
+     #                     symlink_path = os.path.join(
+     #                         save_dir, "merge_weight_latest_checkpoint.pt"
+     #                     )
+     #                     if os.path.exists(symlink_path):
+     #                         os.remove(symlink_path)
+     #                     os.link(os.path.abspath(save_path), symlink_path)
+
+     #     self.print_profile_summary()
+     #     return module
+
+     def run(self, modelpool: HuggingFaceClipVisionPool):
+         self.modelpool = to_modelpool(modelpool)
+         config = self.config
+         self.log_hyperparams(config, filename="method_config.yaml")
+
+         with self.profile("setup models"):
+             module, mask_model = self.setup_models()
+             mask_model: MaskModel = self.fabric.to_device(mask_model)
+             module: LayerWiseMergedModel = self.fabric.to_device(module)
+             self.pertubed_model = self.fabric.to_device(self.pertubed_model)
+             self.setup_zero_shot_classification_head()
+
+         if config.mask_checkpoint is None:
+             self.train_mask(module=module, mask_model=mask_model)
+         else:
+             if self.fabric.is_global_zero:
+                 print("loading mask from checkpoint", config.mask_checkpoint)
+             self.fabric.load(config.mask_checkpoint, {"model": mask_model})
+
+         # run adamerging
+         with torch.no_grad():
+             mask = mask_model.sample_mask(
+                 mask_type=config.eval_mask_type,
+                 temperature=config.temperature,
+             )
+             # rescale mask
+             for name, m in mask.items():
+                 mask[name] = m / torch.mean(m)
+         module = self.run_adamerging(module, mask=mask)
+
+         with torch.no_grad():
+             model = module.merge_and_unload(mask)
+         return model
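Both classes above read their hyperparameters from `self.config`. For readers skimming the diff, here is a minimal sketch of the keys the code accesses, written as an OmegaConf dict; the key names are taken directly from the code, while every value is an illustrative placeholder rather than the defaults shipped in the new `concrete_subspace` YAML files.

```python
from omegaconf import OmegaConf

# Keys accessed via self.config in ConcreteSafeTaskWiseAdaMergingForCLIP and
# ConcreteSafeLayerWiseAdaMergingForCLIP. Values are placeholders only.
sketch_cfg = OmegaConf.create(
    {
        "optimizer": "adam",             # only "adam" is handled; anything else raises
        "lr": 1e-3,                      # mask (Concrete logits) learning rate
        "base_lr": 1e-3,                 # merge-weight learning rate in the inner loop
        "adv_lr": 1e-1,                  # learning rate for the perturbation noise
        "adv_weight": 1.0,               # weight of the adversarial entropy term
        "adamerging_lr": 1e-3,           # merge-weight learning rate in stage 2/2
        "max_steps": 1000,               # mask meta-learning steps (stage 1/2)
        "max_adamerging_steps": 500,     # AdaMerging steps (stage 2/2)
        "save_interval": 100,            # checkpoint every N steps
        "scaling_factor": 0.3,           # init value for task-/layer-wise weights
        "initial_logits": 5.0,           # init value for the mask logits
        "temperature": 0.5,              # Concrete-distribution temperature
        "eval_mask_type": "continuous",  # mask type sampled for the final merge
        "clamp_weights": False,
        "tie_weights": True,
        "strict": False,
        "mask_checkpoint": None,         # path to a trained mask, or None to train one
        "merge_dtype": None,             # e.g. "bfloat16"; parsed by parse_dtype
    }
)
```

The shipped defaults live in the `fusion_bench_config/method/concrete_subspace/` YAML files listed in the file table above.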