fusion-bench 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. fusion_bench/dataset/clip_dataset.py +1 -0
  2. fusion_bench/method/__init__.py +4 -0
  3. fusion_bench/method/adamerging/__init__.py +28 -5
  4. fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
  5. fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
  6. fusion_bench/method/adamerging/utils.py +58 -0
  7. fusion_bench/method/classification/clip_finetune.py +6 -4
  8. fusion_bench/method/classification/image_classification_finetune.py +156 -12
  9. fusion_bench/method/dare/simple_average.py +3 -2
  10. fusion_bench/method/dare/task_arithmetic.py +3 -2
  11. fusion_bench/method/dop/__init__.py +1 -0
  12. fusion_bench/method/dop/dop.py +366 -0
  13. fusion_bench/method/dop/min_norm_solvers.py +227 -0
  14. fusion_bench/method/dop/utils.py +73 -0
  15. fusion_bench/method/simple_average.py +6 -4
  16. fusion_bench/mixins/lightning_fabric.py +9 -0
  17. fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
  18. fusion_bench/modelpool/resnet_for_image_classification.py +285 -4
  19. fusion_bench/models/hf_clip.py +4 -7
  20. fusion_bench/models/hf_utils.py +4 -1
  21. fusion_bench/taskpool/__init__.py +2 -0
  22. fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
  23. fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
  24. fusion_bench/utils/state_dict_arithmetic.py +91 -10
  25. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/METADATA +9 -3
  26. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/RECORD +140 -77
  27. fusion_bench_config/fabric/auto.yaml +1 -1
  28. fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
  29. fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
  30. fusion_bench_config/fabric_model_fusion.yaml +1 -0
  31. fusion_bench_config/method/adamerging/resnet.yaml +18 -0
  32. fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
  33. fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
  34. fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
  35. fusion_bench_config/method/depth_upscaling.yaml +9 -0
  36. fusion_bench_config/method/dop/dop.yaml +30 -0
  37. fusion_bench_config/method/dummy.yaml +6 -0
  38. fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
  39. fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
  40. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
  41. fusion_bench_config/method/linear/expo.yaml +5 -0
  42. fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
  43. fusion_bench_config/method/linear/llama_expo.yaml +5 -0
  44. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
  45. fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
  46. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
  47. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
  48. fusion_bench_config/method/linear/weighted_average.yaml +3 -0
  49. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +6 -1
  50. fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
  51. fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
  52. fusion_bench_config/method/model_recombination.yaml +8 -0
  53. fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
  54. fusion_bench_config/method/opcm/opcm.yaml +5 -0
  55. fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
  56. fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
  57. fusion_bench_config/method/opcm/weight_average.yaml +5 -0
  58. fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
  59. fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
  60. fusion_bench_config/method/regmean/regmean.yaml +3 -0
  61. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
  62. fusion_bench_config/method/simple_average.yaml +9 -0
  63. fusion_bench_config/method/slerp/slerp.yaml +9 -0
  64. fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
  65. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
  66. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  67. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
  68. fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
  69. fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
  70. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
  71. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
  72. fusion_bench_config/method/task_arithmetic.yaml +9 -0
  73. fusion_bench_config/method/ties_merging.yaml +3 -0
  74. fusion_bench_config/method/wudi/wudi.yaml +3 -0
  75. fusion_bench_config/model_fusion.yaml +2 -1
  76. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
  77. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
  78. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
  79. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
  80. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
  81. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
  82. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
  83. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
  84. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
  85. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
  86. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
  87. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
  88. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
  89. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
  90. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
  91. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
  92. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
  93. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
  94. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
  95. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
  96. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
  97. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
  98. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
  99. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
  100. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
  101. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
  102. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
  103. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
  104. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
  105. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
  106. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
  107. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
  108. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
  109. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
  110. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
  111. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
  112. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
  113. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
  114. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
  115. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
  116. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
  117. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
  118. fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
  119. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
  120. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
  121. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
  122. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
  123. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
  124. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
  125. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
  126. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
  127. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
  128. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
  129. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
  130. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
  131. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
  132. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
  133. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
  134. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
  135. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
  136. fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
  137. fusion_bench_config/method/clip_finetune.yaml +0 -26
  138. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/WHEEL +0 -0
  139. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/entry_points.txt +0 -0
  140. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/licenses/LICENSE +0 -0
  141. {fusion_bench-0.2.25.dist-info → fusion_bench-0.2.27.dist-info}/top_level.txt +0 -0
@@ -62,6 +62,7 @@ class CLIPDataset(torch.utils.data.Dataset):
62
62
  if self.processor is not None:
63
63
  if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
64
64
  # Apply the processor to the image to get the input tensor
65
+ image = image.convert("RGB") # ensure image is in RGB format
65
66
  inputs = self.processor(images=[image], return_tensors="pt")[
66
67
  "pixel_values"
67
68
  ][0]
@@ -55,6 +55,8 @@ _import_structure = {
55
55
  "GPT2LayerWiseAdaMergingAlgorithm",
56
56
  "LayerWiseAdaMergingForLlamaSFT",
57
57
  "FlanT5LayerWiseAdaMergingAlgorithm",
58
+ "ResNetLayerWiseAdamerging",
59
+ "ResNetTaskWiseAdamerging",
58
60
  ],
59
61
  "pwe_moe": [
60
62
  "PWEMoELinearScalarizationForCLIP",
@@ -70,6 +72,7 @@ _import_structure = {
70
72
  "IsotropicMergingInCommonSubspace",
71
73
  ],
72
74
  "opcm": ["OPCMForCLIP"],
75
+ "dop": ["ContinualDOPForCLIP"],
73
76
  "gossip": [
74
77
  "CLIPLayerWiseGossipAlgorithm",
75
78
  "CLIPTaskWiseGossipAlgorithm",
@@ -212,6 +215,7 @@ if TYPE_CHECKING:
212
215
  from .model_recombination import ModelRecombinationAlgorithm
213
216
  from .model_stock import ModelStock
214
217
  from .opcm import OPCMForCLIP
218
+ from .dop import ContinualDOPForCLIP
215
219
  from .pruning import (
216
220
  MagnitudeDiffPruningAlgorithm,
217
221
  MagnitudePruningForLlama,
@@ -1,6 +1,29 @@
1
1
  # flake8: noqa F401
2
- from .clip_layer_wise_adamerging import CLIPLayerWiseAdaMergingAlgorithm
3
- from .clip_task_wise_adamerging import CLIPTaskWiseAdaMergingAlgorithm
4
- from .flan_t5_layer_wise_adamerging import FlanT5LayerWiseAdaMergingAlgorithm
5
- from .gpt2_layer_wise_adamerging import GPT2LayerWiseAdaMergingAlgorithm
6
- from .llama_adamerging import LayerWiseAdaMergingForLlamaSFT
2
+ import sys
3
+ from typing import TYPE_CHECKING
4
+
5
+ from fusion_bench.utils.lazy_imports import LazyImporter
6
+
7
+ _import_structure = {
8
+ "clip_layer_wise_adamerging": ["CLIPLayerWiseAdaMergingAlgorithm"],
9
+ "clip_task_wise_adamerging": ["CLIPTaskWiseAdaMergingAlgorithm"],
10
+ "flan_t5_layer_wise_adamerging": ["FlanT5LayerWiseAdaMergingAlgorithm"],
11
+ "gpt2_layer_wise_adamerging": ["GPT2LayerWiseAdaMergingAlgorithm"],
12
+ "llama_adamerging": ["LayerWiseAdaMergingForLlamaSFT"],
13
+ "resnet_adamerging": ["ResNetLayerWiseAdamerging", "ResNetTaskWiseAdamerging"],
14
+ }
15
+
16
+ if TYPE_CHECKING:
17
+ from .clip_layer_wise_adamerging import CLIPLayerWiseAdaMergingAlgorithm
18
+ from .clip_task_wise_adamerging import CLIPTaskWiseAdaMergingAlgorithm
19
+ from .flan_t5_layer_wise_adamerging import FlanT5LayerWiseAdaMergingAlgorithm
20
+ from .gpt2_layer_wise_adamerging import GPT2LayerWiseAdaMergingAlgorithm
21
+ from .llama_adamerging import LayerWiseAdaMergingForLlamaSFT
22
+ from .resnet_adamerging import ResNetLayerWiseAdamerging, ResNetTaskWiseAdamerging
23
+
24
+ else:
25
+ sys.modules[__name__] = LazyImporter(
26
+ __name__,
27
+ globals()["__file__"],
28
+ _import_structure,
29
+ )
@@ -0,0 +1,279 @@
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from typing import TYPE_CHECKING, Dict, Iterator, Optional, Union, override
4
+
5
+ import torch
6
+ from omegaconf import DictConfig
7
+ from torch import nn
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from fusion_bench import (
12
+ BaseAlgorithm,
13
+ LightningFabricMixin,
14
+ auto_register_config,
15
+ get_rankzero_logger,
16
+ instantiate,
17
+ )
18
+ from fusion_bench.constants import RuntimeConstants
19
+ from fusion_bench.dataset import CLIPDataset
20
+ from fusion_bench.modelpool import ResNetForImageClassificationPool
21
+ from fusion_bench.models.wrappers.layer_wise_fusion import LayerWiseMergedModel
22
+ from fusion_bench.models.wrappers.task_wise_fusion import TaskWiseMergedModel
23
+ from fusion_bench.utils import load_tensor_from_file
24
+ from fusion_bench.utils.data import InfiniteDataLoader
25
+
26
+ from .entropy_loss import entropy_loss
27
+ from .utils import construct_layer_wise_merged_model, construct_task_wise_merged_model
28
+
29
+ if TYPE_CHECKING:
30
+ from transformers import ResNetForImageClassification, ResNetModel
31
+
32
+ log = get_rankzero_logger(__name__)
33
+
34
+
35
+ @auto_register_config
36
+ class _ResNetAdaMergingBase(
37
+ ABC,
38
+ LightningFabricMixin,
39
+ BaseAlgorithm,
40
+ ):
41
+ classification_heads: Dict[str, nn.Module]
42
+ shuffled_test_loader_iters: Dict[str, Iterator]
43
+
44
+ def __init__(
45
+ self,
46
+ max_steps: int,
47
+ optimizer: DictConfig,
48
+ lr_scheduler: DictConfig,
49
+ dataloader_kwargs: DictConfig,
50
+ init_values: Optional[float],
51
+ clamp_weights: bool = False,
52
+ tie_weights: bool = True,
53
+ strict: bool = False,
54
+ resume_weights_path: Union[str, None] = None,
55
+ **kwargs,
56
+ ):
57
+ super().__init__(**kwargs)
58
+ if RuntimeConstants.debug:
59
+ log.info("Debug mode is on, setting max_steps to 10")
60
+ self.max_steps = 10
61
+
62
+ @override
63
+ def run(self, modelpool: ResNetForImageClassificationPool):
64
+ self.modelpool = modelpool
65
+
66
+ # setup models
67
+ wrapped_model = self.setup_wrapped_model(modelpool)
68
+
69
+ # if max_steps <= 0, skip training and return the merged model directly
70
+ # this can be used to evaluate the merging weights loaded from `resume_weights_path`
71
+ if self.max_steps <= 0:
72
+ # skip_training
73
+ return wrapped_model.merge_and_unload()
74
+
75
+ # setup dataloaders
76
+ self.setup_dataloaders()
77
+
78
+ # configure optimizer and lr_scheduler
79
+ optimizer = instantiate(self.optimizer, params=[wrapped_model.merge_weight])
80
+ if self.lr_scheduler is not None:
81
+ lr_scheduler = instantiate(self.lr_scheduler, optimizer=optimizer)
82
+ else:
83
+ lr_scheduler = None
84
+
85
+ wrapped_model, optimizer = self.fabric.setup(wrapped_model, optimizer)
86
+ wrapped_model = self.test_time_adaptation(
87
+ wrapped_model, optimizer, lr_scheduler
88
+ )
89
+
90
+ # save merging weights
91
+ if self.log_dir is not None:
92
+ self.fabric.save(
93
+ os.path.join(self.log_dir, "checkpoints", "merge_weight.ckpt"),
94
+ {"merge_weight": wrapped_model.merge_weight},
95
+ )
96
+
97
+ merged_model = wrapped_model.merge_and_unload()
98
+ if self.log_dir is not None:
99
+ modelpool.save_model(
100
+ merged_model,
101
+ os.path.join(self.log_dir, "checkpoints", "merged_model"),
102
+ algorithm_config=self.config,
103
+ description="Merged ResNet model using AdaMerging (E Yang, 2023).",
104
+ )
105
+
106
+ return merged_model
107
+
108
+ def test_time_adaptation(
109
+ self,
110
+ wrapped_model: TaskWiseMergedModel,
111
+ optimizer: torch.optim.Optimizer,
112
+ lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler],
113
+ ):
114
+ model_names = self.modelpool.model_names
115
+ wrapped_model.train()
116
+ wrapped_model.merge_weights()
117
+
118
+ for step_idx in tqdm(
119
+ range(self.max_steps),
120
+ disable=not self.fabric.is_global_zero,
121
+ dynamic_ncols=True,
122
+ ):
123
+ metrics = {"tta/total_loss": 0.0}
124
+ for task in model_names:
125
+ batch = next(self.get_shuffled_test_loader_iter(task))
126
+ logits = self.compute_logits(wrapped_model, batch[0], task)
127
+ loss = entropy_loss(logits)
128
+ metrics[f"tta/{task}_loss"] = loss.item()
129
+ metrics["tta/total_loss"] += loss.item()
130
+ self.fabric.backward(loss, retain_graph=True)
131
+
132
+ optimizer.step()
133
+ optimizer.zero_grad()
134
+ wrapped_model.merge_weights() # merge weights for the next step
135
+ if lr_scheduler is not None:
136
+ lr_scheduler.step()
137
+
138
+ self.fabric.log_dict(metrics=metrics, step=step_idx)
139
+
140
+ return wrapped_model
141
+
142
+ def compute_logits(
143
+ self, module: Union["ResNetModel", nn.Module], images: torch.Tensor, task: str
144
+ ) -> torch.Tensor:
145
+ if self.modelpool.type == "transformers":
146
+ outputs = module(images, return_dict=True)
147
+ pooled_output = outputs.pooler_output
148
+ logits = self.classification_heads[task](pooled_output)
149
+ return logits
150
+ else:
151
+ raise NotImplementedError(
152
+ f"Model type {self.modelpool.type} is not supported."
153
+ )
154
+
155
+ def setup_dataloaders(self):
156
+ dataloader_kwargs = dict(self.dataloader_kwargs)
157
+ dataloader_kwargs["shuffle"] = True # ensure shuffling for TTA
158
+ processor = self.modelpool.load_processor()
159
+ for task in self.modelpool.test_dataset_names:
160
+ test_dataset = self.modelpool.load_test_dataset(task)
161
+ test_dataset = CLIPDataset(test_dataset, processor=processor)
162
+ test_loader = DataLoader(test_dataset, **dataloader_kwargs)
163
+ self.shuffled_test_loader_iters[task] = iter(
164
+ InfiniteDataLoader(test_loader)
165
+ )
166
+
167
+ def get_shuffled_test_loader_iter(self, task: str):
168
+ return self.shuffled_test_loader_iters[task]
169
+
170
+ @abstractmethod
171
+ def setup_wrapped_model(
172
+ self, modelpool: ResNetForImageClassificationPool
173
+ ) -> Union[TaskWiseMergedModel, LayerWiseMergedModel]:
174
+ """
175
+ Setup the wrapped merged model.
176
+
177
+ Args:
178
+ modelpool (ResNetForImageClassificationPool): The model pool containing pretrained and finetuned models.
179
+
180
+ Returns:
181
+ Union[TaskWiseMergedModel, LayerWiseMergedModel] : The wrapped merged model.
182
+ """
183
+ pass
184
+
185
+
186
+ class ResNetTaskWiseAdamerging(_ResNetAdaMergingBase):
187
+ @torch.no_grad()
188
+ def setup_wrapped_model(self, modelpool: ResNetForImageClassificationPool):
189
+ pretrained_model = modelpool.load_pretrained_model()
190
+ finetuned_models = dict(modelpool.named_models())
191
+
192
+ if modelpool.type == "transformers":
193
+ pretrained_model: "ResNetForImageClassification"
194
+ finetuned_models: Dict[str, "ResNetForImageClassification"]
195
+ for model_name in finetuned_models:
196
+ self.classification_heads[model_name] = finetuned_models[
197
+ model_name
198
+ ].classifier
199
+ # fix the classification head during merging and move to device
200
+ self.classification_heads[model_name].requires_grad_(False)
201
+ pretrained_backbone: "ResNetModel" = pretrained_model.resnet
202
+ finetuned_backbones = [
203
+ finetuned_models[model_name].resnet for model_name in finetuned_models
204
+ ]
205
+ else:
206
+ raise NotImplementedError(f"Model type {modelpool.type} is not supported.")
207
+
208
+ wrapped_model = construct_task_wise_merged_model(
209
+ pretrained_model=pretrained_backbone,
210
+ finetuned_models=finetuned_backbones,
211
+ clamp_weights=self.clamp_weights,
212
+ tie_weights=self.tie_weights,
213
+ strict=self.strict,
214
+ )
215
+
216
+ if self.init_values is not None:
217
+ log.info(f"Initializing merging weights to {self.init_values}")
218
+ wrapped_model.merge_weight.data.fill_(self.init_values)
219
+
220
+ # load merging weights if provided
221
+ if self.resume_weights_path is not None:
222
+ merging_weights = load_tensor_from_file(
223
+ self.resume_weights_path, device="cpu"
224
+ )
225
+ log.info(f"Loaded merging weights from {self.resume_weights_path}")
226
+ assert merging_weights.shape == wrapped_model.merge_weight.shape, (
227
+ f"Merging weights shape {merging_weights.shape} does not match "
228
+ f"model's merge_weight shape {wrapped_model.merge_weight.shape}."
229
+ )
230
+ wrapped_model.merge_weight.data = merging_weights
231
+ return wrapped_model
232
+
233
+
234
+ class ResNetLayerWiseAdamerging(_ResNetAdaMergingBase):
235
+ @torch.no_grad()
236
+ def setup_wrapped_model(self, modelpool: ResNetForImageClassificationPool):
237
+ pretrained_model = modelpool.load_pretrained_model()
238
+ finetuned_models = dict(modelpool.named_models())
239
+
240
+ if modelpool.type == "transformers":
241
+ pretrained_model: "ResNetForImageClassification"
242
+ finetuned_models: Dict[str, "ResNetForImageClassification"]
243
+ for model_name in finetuned_models:
244
+ self.classification_heads[model_name] = finetuned_models[
245
+ model_name
246
+ ].classifier
247
+ # fix the classification head during merging and move to device
248
+ self.classification_heads[model_name].requires_grad_(False)
249
+ pretrained_backbone: "ResNetModel" = pretrained_model.resnet
250
+ finetuned_backbones = [
251
+ finetuned_models[model_name].resnet for model_name in finetuned_models
252
+ ]
253
+ else:
254
+ raise NotImplementedError(f"Model type {modelpool.type} is not supported.")
255
+
256
+ wrapped_model = construct_layer_wise_merged_model(
257
+ pretrained_model=pretrained_backbone,
258
+ finetuned_models=finetuned_backbones,
259
+ clamp_weights=self.clamp_weights,
260
+ tie_weights=self.tie_weights,
261
+ strict=self.strict,
262
+ )
263
+
264
+ if self.init_values is not None:
265
+ log.info(f"Initializing merging weights to {self.init_values}")
266
+ wrapped_model.merge_weight.data.fill_(self.init_values)
267
+
268
+ # load merging weights if provided
269
+ if self.resume_weights_path is not None:
270
+ merging_weights = load_tensor_from_file(
271
+ self.resume_weights_path, device="cpu"
272
+ )
273
+ log.info(f"Loaded merging weights from {self.resume_weights_path}")
274
+ assert merging_weights.shape == wrapped_model.merge_weight.shape, (
275
+ f"Merging weights shape {merging_weights.shape} does not match "
276
+ f"model's merge_weight shape {wrapped_model.merge_weight.shape}."
277
+ )
278
+ wrapped_model.merge_weight.data = merging_weights
279
+ return wrapped_model
@@ -18,21 +18,9 @@ from fusion_bench.models.wrappers.task_wise_fusion import (
18
18
  get_task_wise_weights,
19
19
  )
20
20
 
21
- log = logging.getLogger(__name__)
22
-
23
-
24
- def entropy_loss(logits: Tensor) -> Tensor:
25
- """
26
- Compute the entropy loss of a set of logits.
21
+ from .entropy_loss import entropy_loss
27
22
 
28
- Args:
29
- logits (Tensor): The logits to compute the entropy loss of.
30
-
31
- Returns:
32
- Tensor: The entropy loss of the logits.
33
- """
34
- probs = torch.softmax(logits, dim=-1)
35
- return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
23
+ log = logging.getLogger(__name__)
36
24
 
37
25
 
38
26
  class TaskWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
@@ -1,4 +1,9 @@
1
+ from typing import List
2
+
1
3
  import torch
4
+ import torch.nn as nn
5
+
6
+ from fusion_bench.utils.type import TorchModelType
2
7
 
3
8
 
4
9
  def get_memory_usage(desc):
@@ -13,3 +18,56 @@ def get_memory_usage(desc):
13
18
  return (
14
19
  f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
15
20
  )
21
+
22
+
23
+ @torch.no_grad()
24
+ def construct_task_wise_merged_model(
25
+ pretrained_model: TorchModelType,
26
+ finetuned_models: List[TorchModelType],
27
+ clamp_weights: bool = False,
28
+ tie_weights: bool = True,
29
+ strict: bool = False,
30
+ ):
31
+ from fusion_bench.models.wrappers.task_wise_fusion import (
32
+ TaskWiseMergedModel,
33
+ get_task_wise_weights,
34
+ )
35
+
36
+ merging_weights = get_task_wise_weights(num_models=len(finetuned_models))
37
+ module = TaskWiseMergedModel(
38
+ task_wise_weight=merging_weights,
39
+ pretrained_model=pretrained_model,
40
+ finetuned_models=finetuned_models,
41
+ clamp_weights=clamp_weights,
42
+ tie_weights=tie_weights,
43
+ strict=strict,
44
+ )
45
+ return module
46
+
47
+
48
+ @torch.no_grad()
49
+ def construct_layer_wise_merged_model(
50
+ pretrained_model: TorchModelType,
51
+ finetuned_models: List[TorchModelType],
52
+ clamp_weights: bool = False,
53
+ tie_weights: bool = True,
54
+ strict: bool = False,
55
+ ):
56
+ from fusion_bench.models.wrappers.layer_wise_fusion import (
57
+ LayerWiseMergedModel,
58
+ get_layer_wise_weights,
59
+ )
60
+
61
+ merging_weights = get_layer_wise_weights(
62
+ num_models=len(finetuned_models),
63
+ num_layers=len([p for p in pretrained_model.parameters() if p.requires_grad]),
64
+ )
65
+ module = LayerWiseMergedModel(
66
+ layer_wise_weight=merging_weights,
67
+ pretrained_model=pretrained_model,
68
+ finetuned_models=finetuned_models,
69
+ clamp_weights=clamp_weights,
70
+ tie_weights=tie_weights,
71
+ strict=strict,
72
+ )
73
+ return module
@@ -5,8 +5,8 @@ Fine-tune CLIP-ViT-B/32:
5
5
 
6
6
  ```bash
7
7
  fusion_bench \
8
- method=clip_finetune \
9
- modelpool=clip-vit-base-patch32_mtl \
8
+ method=classification/clip_finetune \
9
+ modelpool=CLIPVisionModelPool/clip-vit-base-patch32_mtl \
10
10
  taskpool=dummy
11
11
  ```
12
12
 
@@ -15,12 +15,14 @@ Fine-tune CLIP-ViT-L/14 on eight GPUs with a per-device per-task batch size of 2
15
15
  ```bash
16
16
  fusion_bench \
17
17
  fabric.devices=8 \
18
- method=clip_finetune \
18
+ method=classification/clip_finetune \
19
19
  method.batch_size=2 \
20
- modelpool=clip-vit-base-patch32_mtl \
20
+ modelpool=CLIPVisionModelPool/clip-vit-base-patch32_mtl \
21
21
  modelpool.models.0.path=openai/clip-vit-large-patch14 \
22
22
  taskpool=dummy
23
23
  ```
24
+
25
+ See `examples/clip_finetune` for more details.
24
26
  """
25
27
 
26
28
  import os