fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  10. fusion_bench/method/gossip/__init__.py +3 -0
  11. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  12. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  13. fusion_bench/method/gossip/entropy_loss.py +25 -0
  14. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  15. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  16. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  17. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  18. fusion_bench/method/gossip/utils.py +74 -0
  19. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  20. fusion_bench/method/opcm/opcm.py +16 -7
  21. fusion_bench/method/pwe_moe/module.py +1 -1
  22. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  23. fusion_bench/method/regmean/regmean.py +25 -17
  24. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  25. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  26. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  27. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  28. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  29. fusion_bench/method/we_moe/we_moe.py +14 -15
  30. fusion_bench/mixins/__init__.py +6 -3
  31. fusion_bench/mixins/hydra_config.py +49 -0
  32. fusion_bench/mixins/openclip_classification.py +11 -0
  33. fusion_bench/mixins/simple_profiler.py +4 -2
  34. fusion_bench/modelpool/__init__.py +3 -1
  35. fusion_bench/modelpool/base_pool.py +2 -2
  36. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  37. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  38. fusion_bench/models/open_clip/__init__.py +6 -0
  39. fusion_bench/models/open_clip/modeling.py +176 -0
  40. fusion_bench/models/open_clip/utils.py +311 -0
  41. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  42. fusion_bench/models/parameter_dict.py +54 -13
  43. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  44. fusion_bench/taskpool/__init__.py +5 -3
  45. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  46. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  47. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  48. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  49. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  50. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  51. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  52. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  53. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  54. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  55. fusion_bench/utils/data.py +12 -0
  56. fusion_bench/utils/devices.py +14 -0
  57. fusion_bench/utils/instantiate.py +12 -0
  58. fusion_bench/utils/misc.py +9 -2
  59. fusion_bench/utils/packages.py +14 -0
  60. fusion_bench/utils/parameters.py +1 -1
  61. fusion_bench/utils/tensorboard.py +1 -1
  62. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
  63. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
  64. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  65. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  66. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  67. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  68. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  69. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  70. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  71. fusion_bench_config/fabric/auto.yaml +0 -1
  72. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  73. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  74. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  75. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  76. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  77. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  78. fusion_bench_config/llama_full_finetune.yaml +0 -2
  79. fusion_bench_config/llama_model_fusion.yaml +0 -2
  80. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  81. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  82. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  83. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  84. fusion_bench_config/method/adamerging.yaml +2 -2
  85. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  86. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  87. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  88. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  89. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  90. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  91. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  92. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  93. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  94. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  95. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  96. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  97. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  98. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  99. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  100. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  101. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  102. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  103. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  104. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  105. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  106. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  107. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  108. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  109. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  110. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  111. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  112. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  113. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  114. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  115. fusion_bench_config/method/model_recombination.yaml +0 -1
  116. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  117. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  118. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  119. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  120. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  121. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  122. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  123. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  124. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  125. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  126. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  127. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  128. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  129. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  130. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  131. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  132. fusion_bench_config/method/ties_merging.yaml +1 -1
  133. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  134. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  135. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  154. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  155. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  156. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  157. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  158. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  159. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  160. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  161. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  162. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  163. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  164. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  165. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  166. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  167. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  169. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  170. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  171. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  172. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  173. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  174. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  175. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  176. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  177. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  178. fusion_bench_config/nyuv2_config.yaml +0 -2
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  183. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  184. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  185. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  186. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  187. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  188. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  189. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
  190. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench/method/pwe_moe/openclip_pwe_moe.py (new file)
@@ -0,0 +1,476 @@
+ import logging
+ from abc import abstractmethod
+ from collections import defaultdict
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
+
+ import lightning.fabric
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from omegaconf import DictConfig, OmegaConf
+ from open_clip.model import ResidualAttentionBlock
+ from torch import Tensor, nn
+ from tqdm.auto import tqdm
+
+ from fusion_bench import BaseAlgorithm
+ from fusion_bench.dataset.clip_dataset import CLIPDataset
+ from fusion_bench.method.task_arithmetic import task_arithmetic_merge
+ from fusion_bench.mixins import OpenCLIPClassificationMixin, SimpleProfilerMixin
+ from fusion_bench.modelpool import OpenCLIPVisionModelPool
+ from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
+ from fusion_bench.utils import print_parameters, timeit_context
+ from fusion_bench.utils.data import InfiniteDataLoader
+
+ from .module import ParetoWeightEnsemblingModule
+ from .phn.solvers import EPOSolver
+ from .utils import generate_simplex_grid
+
+ log = logging.getLogger(__name__)
+
+
+ class PWEMoEAlgorithmForOpenCLIP(
+     BaseAlgorithm,
+     SimpleProfilerMixin,
+     OpenCLIPClassificationMixin,
+ ):
+     modelpool: OpenCLIPVisionModelPool
+
+     #! === Training & Validation Data ===
+     # setup the datasets and loaders by calling `load_datasets`
+     train_datasets: Dict[str, CLIPDataset]
+     train_loaders: Dict[str, torch.utils.data.DataLoader]
+     train_loader_iters: Dict[str, Iterator[Tuple[torch.Tensor, torch.Tensor]]]
+
+     test_datasets: Dict[str, CLIPDataset]
+     test_loaders: Dict[str, torch.utils.data.DataLoader]
+
+     def __init__(
+         self,
+         *,
+         #! === Model Architecture Arguments ===
+         partial: bool,
+         init_lambda: float,
+         router_hidden_layers: int,
+         checkpoint_path: str,
+         #! === Training Arguments ===
+         run_train: bool,
+         num_steps: int,
+         save_interval: int,
+         lr: float,
+         alpha: float,
+         dataloader_kwargs: DictConfig,
+         #! === Evaluation Arguments ===
+         run_eval: bool,
+         num_evaluation_samples: Union[str, int],
+         quick_evaluation: bool,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.partial = partial
+         self.init_lambda = init_lambda
+         self.router_hidden_layers = router_hidden_layers
+         self.lr = lr
+         self.num_steps = num_steps
+         self.save_interval = save_interval
+         self.alpha = alpha
+         self.checkpoint_path = checkpoint_path
+         self._dataloader_kwargs = dataloader_kwargs
+         self.run_train = run_train
+         self.run_eval = run_eval
+         self.num_evaluation_samples = num_evaluation_samples
+         self.quick_evaluation = quick_evaluation
+
+     def run(self, modelpool: OpenCLIPVisionModelPool):
+         self.modelpool = modelpool
+
+         # setup the MoE model
+         model = self.load_model()
+         if self.checkpoint_path is not None:
+             self.fabric.load(self.checkpoint_path, {"model": model})
+
+         # setup dataloaders
+         self.load_datasets()
+
+         if self.run_train:
+             model = self.train()
+         if self.run_eval:
+             self.evaluate(model)
+         return model
+
+     @torch.no_grad()
+     def load_model(self):
+         modelpool = self.modelpool
+
+         # load models and classification heads
+         pretrained_model: ImageEncoder = self.modelpool.load_pretrained_model()
+         log.info("pretrained model statistics:")
+         print_parameters(pretrained_model, print_fn=log.info)
+
+         finetuned_models: Dict[str, ImageEncoder] = {}
+         for model_name in self.modelpool.model_names:
+             finetuned_models[model_name] = modelpool.load_model(model_name)
+
+         classification_heads: Dict[str, ClassificationHead] = {}
+         for model_name in self.modelpool.model_names:
+             classification_heads[model_name] = modelpool.load_classification_head(
+                 model_name
+             )
+         self.classification_heads = classification_heads
+
+         self.train_processor = modelpool.train_processor
+         self.test_processor = modelpool.test_processor
+
+         with timeit_context("Building the MoE model"):
+             model = deepcopy(pretrained_model)
+
+             if self.partial:
+                 log.info("Weight ensembling only the MLPs")
+                 # weight ensembling only the MLPs, merge the remaining layers using task arithmetic
+                 model = task_arithmetic_merge(
+                     pretrained_model=model,
+                     finetuned_models=list(finetuned_models.values()),
+                     scaling_factor=self.init_lambda,
+                     inplace=True,
+                 )
+
+                 # fix all parameters
+                 model.requires_grad_(False)
+
+                 for layer_idx in tqdm(
+                     range(model.model.visual.transformer.layers), desc="Upscaling MLPs"
+                 ):
+                     resblock: ResidualAttentionBlock = (
+                         model.model.visual.transformer.resblocks[layer_idx]
+                     )
+                     resblock.mlp = ParetoWeightEnsemblingModule(
+                         base_model=cast(
+                             ResidualAttentionBlock,
+                             pretrained_model.model.visual.transformer.resblocks[
+                                 layer_idx
+                             ],
+                         ).mlp,
+                         expert_models=[
+                             cast(
+                                 ResidualAttentionBlock,
+                                 m.model.visual.transformer.resblocks[layer_idx],
+                             ).mlp
+                             for m in finetuned_models.values()
+                         ],
+                         init_lambda=self.init_lambda,
+                         fix_base_model_and_experts=True,
+                         router_hidden_layers=self.router_hidden_layers,
+                     )
+             else:
+                 log.info("Weight ensembling all the layers")
+                 # weight ensembling all the layers, merge the remaining layers using task arithmetic
+                 model = task_arithmetic_merge(
+                     pretrained_model=model,
+                     finetuned_models=list(finetuned_models.values()),
+                     scaling_factor=self.init_lambda,
+                     inplace=True,
+                 )
+                 # fix all parameters
+                 model.requires_grad_(False)
+
+                 for name in [
+                     "conv1",
+                     "ln_pre",
+                     "ln_post",
+                     # "class_embedding",
+                     # "positional_embedding",
+                 ]:
+                     setattr(
+                         model.model.visual,
+                         name,
+                         ParetoWeightEnsemblingModule(
+                             base_model=getattr(pretrained_model.model.visual, name),
+                             expert_models=[
+                                 getattr(m.model.visual, name)
+                                 for m in finetuned_models.values()
+                             ],
+                             init_lambda=self.init_lambda,
+                             fix_base_model_and_experts=True,
+                             router_hidden_layers=self.router_hidden_layers,
+                         ),
+                     )
+                 for layer_idx in tqdm(
+                     range(model.model.visual.transformer.layers),
+                     desc="Upscaling the transformer layers",
+                 ):
+                     for name in ["ln_1", "attn", "ln_attn", "ln_2", "mlp"]:
+                         setattr(
+                             model.model.visual.transformer.resblocks[layer_idx],
+                             name,
+                             ParetoWeightEnsemblingModule(
+                                 base_model=getattr(
+                                     cast(
+                                         ResidualAttentionBlock,
+                                         pretrained_model.model.visual.transformer.resblocks[
+                                             layer_idx
+                                         ],
+                                     ),
+                                     name,
+                                 ),
+                                 expert_models=[
+                                     getattr(
+                                         cast(
+                                             ResidualAttentionBlock,
+                                             m.model.visual.transformer.resblocks[
+                                                 layer_idx
+                                             ],
+                                         ),
+                                         name,
+                                     )
+                                     for m in finetuned_models.values()
+                                 ],
+                                 init_lambda=self.init_lambda,
+                                 fix_base_model_and_experts=True,
+                                 router_hidden_layers=self.router_hidden_layers,
+                             ),
+                         )
+                 for name in ["token_embedding", "ln_final"]:
+                     setattr(
+                         model.model,
+                         name,
+                         ParetoWeightEnsemblingModule(
+                             base_model=getattr(pretrained_model.model, name),
+                             expert_models=[
+                                 getattr(m.model, name)
+                                 for m in finetuned_models.values()
+                             ],
+                             init_lambda=self.init_lambda,
+                             fix_base_model_and_experts=True,
+                             router_hidden_layers=self.router_hidden_layers,
+                         ),
+                     )
+
+         self.model = model
+         print_parameters(model, print_fn=log.info)
+         return model
+
+     def load_datasets(self):
+         modelpool = self.modelpool
+
+         # setup the train datasets and loaders
+         train_datasets = {}
+         train_loaders = {}
+         train_loader_iters = {}
+         for dataset_name in modelpool.train_dataset_names:
+             train_datasets[dataset_name] = modelpool.load_train_dataset(dataset_name)
+             train_datasets[dataset_name] = CLIPDataset(
+                 train_datasets[dataset_name], self.train_processor
+             )
+             # sanity check
+             assert isinstance(train_datasets[dataset_name][0], tuple)
+
+             # setup the train loaders
+             train_loaders[dataset_name] = torch.utils.data.DataLoader(
+                 train_datasets[dataset_name],
+                 shuffle=True,
+                 drop_last=True,
+                 **self._dataloader_kwargs,
+             )
+             train_loaders[dataset_name] = self.fabric.setup_dataloaders(
+                 train_loaders[dataset_name]
+             )
+             train_loaders[dataset_name] = InfiniteDataLoader(
+                 train_loaders[dataset_name]
+             )
+
+             # setup the train loader iterators
+             train_loader_iters[dataset_name] = iter(train_loaders[dataset_name])
+
+         self.train_datasets = train_datasets
+         self.train_loaders = train_loaders
+         self.train_loader_iters = train_loader_iters
+
+         # setup the test datasets and loaders
+         test_datasets = {}
+         test_loaders = {}
+         for dataset_name in modelpool.test_dataset_names:
+             test_datasets[dataset_name] = modelpool.load_test_dataset(dataset_name)
+             test_datasets[dataset_name] = CLIPDataset(
+                 test_datasets[dataset_name], self.test_processor
+             )
+             test_loaders[dataset_name] = torch.utils.data.DataLoader(
+                 test_datasets[dataset_name],
+                 shuffle=False,
+                 **self._dataloader_kwargs,
+             )
+             test_loaders[dataset_name] = self.fabric.setup_dataloaders(
+                 test_loaders[dataset_name]
+             )
+
+         self.test_datasets = test_datasets
+         self.test_loaders = test_loaders
+
+     def compute_loss(self, model: ImageEncoder, ray: Tensor):
+         losses = []
+         for dataset_idx, dataset_name in enumerate(self.train_datasets):
+             batch = next(self.train_loader_iters[dataset_name])
+             x, y = batch
+
+             features = model(x)
+             logits = self.classification_heads[dataset_name](features)
+
+             _loss = F.cross_entropy(logits, y)
+             losses.append(_loss)
+
+         loss = self.aggregate_loss(model, ray, losses)
+         return loss
+
+     @abstractmethod
+     def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+         pass
+
+     def train(self):
+         # setup the model
+         num_objectives = len(self.modelpool.model_names)
+         model = deepcopy(self.model)
+         self.classification_heads = {
+             t: h.to(self.fabric.device) for t, h in self.classification_heads.items()
+         }
+
+         # set up the optimizer and learning rate scheduler
+         optimizer = torch.optim.Adam(
+             filter(lambda p: p.requires_grad, model.parameters()),
+             lr=self.lr,
+         )
+         model, optimizer = self.fabric.setup(model, optimizer)
+         lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+             optimizer=optimizer, T_max=self.num_steps, eta_min=self.lr * 0.1
+         )
+
+         model.train()
+         device = self.fabric.device
+         for step_idx in tqdm(
+             range(1, 1 + self.num_steps), "training", dynamic_ncols=True
+         ):
+             # sample a preference ray
+             ray = torch.from_numpy(
+                 np.random.dirichlet((self.alpha,) * num_objectives, 1)
+                 .astype(np.float32)
+                 .flatten()
+             ).to(device)
+             ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)
+
+             loss = self.compute_loss(model, ray)
+
+             optimizer.zero_grad()
+             self.fabric.backward(loss)
+             optimizer.step()
+
+             lr_scheduler.step()
+
+             self.fabric.log("loss", loss.item(), step=step_idx)
+
+             if step_idx % self.save_interval == 0 or step_idx == self.num_steps:
+                 ckpt_dir = Path(self.log_dir) / "checkpoints"
+                 ckpt_dir.mkdir(exist_ok=True, parents=True)
+                 self.fabric.save(
+                     ckpt_dir / f"model_step={step_idx}.ckpt",
+                     {"model": model},
+                 )
+         return model
+
+     def evaluate(self, model):
+         results = defaultdict(list)
+
+         num_objectives = len(self.modelpool.model_names)
+         device = self.fabric.device
+         self.classification_heads = {
+             t: h.to(self.fabric.device) for t, h in self.classification_heads.items()
+         }
+
+         if not lightning.fabric.is_wrapped(model):
+             model = self.fabric.setup_module(model)
+         model.eval()
+
+         if self.num_evaluation_samples == "equal_weight":
+             uniform_grid = np.array(
+                 [[1 / num_objectives] * num_objectives], dtype=np.float32
+             )
+         else:
+             uniform_grid = generate_simplex_grid(
+                 num_objectives, self.num_evaluation_samples
+             )
+         for ray_idx, ray in tqdm(enumerate(uniform_grid), "evaluating samples"):
+             results["ray_idx"].append(ray_idx)
+             # sample a preference ray
+             for i in range(len(ray)):
+                 results[f"ray_{i}"].append(ray[i])
+             ray = torch.from_numpy(ray).to(device)
+             ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)
+
+             accs = []
+             for dataset_idx, dataset_name in enumerate(
+                 tqdm(
+                     self.modelpool.test_dataset_names,
+                     "evaluating datasets",
+                     leave=False,
+                 )
+             ):
+                 test_loader = self.test_loaders[dataset_name]
+                 TOTAL_CORRECT = 0
+                 TOTAL_COUNT = 0
+                 for batch_idx, batch in enumerate(
+                     pbar := tqdm(
+                         test_loader,
+                         f"evaluate {dataset_name}",
+                         leave=False,
+                     )
+                 ):
+                     x, y = batch
+
+                     features = model(x)
+                     logits = self.classification_heads[dataset_name](features)
+                     preds = logits.argmax(-1)
+
+                     correct = (preds == y).sum().item()
+                     TOTAL_CORRECT += correct
+                     TOTAL_COUNT += len(y)
+                     acc = TOTAL_CORRECT / TOTAL_COUNT
+                     pbar.set_postfix_str(f"acc={acc:.2f}")
+
+                     if self.quick_evaluation and batch_idx > 20:
+                         break
+                 results[dataset_name].append(acc)
+                 accs.append(acc)
+
+             # compute the average accuracy
+             if "average" not in self.modelpool.test_dataset_names:
+                 results["average"].append(np.mean(accs))
+
+         (df := pd.DataFrame(results)).to_csv(
+             Path(self.log_dir) / "result.csv", index=False
+         )
+         log.info(df)
+
+
+ class PWEMoELinearScalarizationForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
+     def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+         loss = 0
+         for r, l in zip(ray, losses):
+             loss += r * l
+         return loss
+
+
+ class PWEMoEExactParetoOptimalForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
+     epo_solver: Optional[EPOSolver] = None
+
+     def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+         if self.epo_solver is None:
+             num_objectives = len(self.modelpool.model_names)
+             self.epo_solver = EPOSolver(n_tasks=num_objectives, n_params=None)
+         epo_solver = self.epo_solver
+
+         losses = torch.stack(losses)
+         loss = epo_solver.get_weighted_loss(
+             losses,
+             ray,
+             tuple(filter(lambda p: p.requires_grad, model.parameters())),
+         )
+         return loss
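The training loop of the new PWEMoEAlgorithmForOpenCLIP above hinges on two small pieces: a preference ray drawn from a Dirichlet(alpha) prior and a loss-aggregation rule. A minimal, self-contained sketch of the linear-scalarization variant follows (illustrative values only; just the Dirichlet sampling and the weighted sum are taken from the code above, and the EPO variant instead delegates the weighting to EPOSolver.get_weighted_loss):

import numpy as np
import torch

# illustrative settings: eight tasks, concentration alpha = 0.2 as a stand-in
alpha, num_objectives = 0.2, 8
# sample a preference ray on the simplex (its entries sum to 1)
ray = torch.from_numpy(
    np.random.dirichlet((alpha,) * num_objectives, 1).astype(np.float32).flatten()
)
# stand-in per-task losses; in the algorithm these are cross-entropy losses
per_task_losses = [torch.rand(()) for _ in range(num_objectives)]
# linear scalarization, as in PWEMoELinearScalarizationForOpenCLIP.aggregate_loss
loss = sum(r * l for r, l in zip(ray, per_task_losses))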
fusion_bench/method/regmean/regmean.py
@@ -13,6 +13,7 @@ from torch import Tensor, nn
  from tqdm.autonotebook import tqdm

  from fusion_bench.method import BaseAlgorithm
+ from fusion_bench.mixins import SimpleProfilerMixin
  from fusion_bench.modelpool import BaseModelPool

  log = logging.getLogger(__name__)
@@ -279,7 +280,7 @@ def regmean_merging(
      return merged_params


- class RegMeanAlgorithm(BaseAlgorithm):
+ class RegMeanAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
      _include_module_type = [nn.Linear]
      _config_mapping = {
          "num_regmean_examples": "num_regmean_examples",
@@ -342,24 +343,31 @@ class RegMeanAlgorithm(BaseAlgorithm):
              )
              assert len(linear_modules_to_merge) > 0, "No linear modules to merge"

-             regmean_weights = self.get_regmean_weights(
-                 name,
-                 model,
-                 train_dataset=modelpool.load_train_dataset(name),
-                 linear_modules_to_merge=linear_modules_to_merge,
-             )
-             models_to_merge_regmean_weights_list.append(regmean_weights)
+             with (
+                 self.profile("merging models"),
+                 self.profile("computing regmean weights"),
+             ):
+                 regmean_weights = self.get_regmean_weights(
+                     name,
+                     model,
+                     train_dataset=modelpool.load_train_dataset(name),
+                     linear_modules_to_merge=linear_modules_to_merge,
+                 )
+                 models_to_merge_regmean_weights_list.append(regmean_weights)
+
+         with self.profile("merging models"):
+             # merging with regmean weights
+             merged_params = merging_with_regmean_weights(
+                 models_to_merge_param_dict=models_to_merge_param_dict,
+                 models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
+                 reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
+                 weight_transpose=self.config.get("weight_transpose", True),
+             )

-         # merging with regmean weights
-         merged_params = merging_with_regmean_weights(
-             models_to_merge_param_dict=models_to_merge_param_dict,
-             models_to_merge_regmean_weights_list=models_to_merge_regmean_weights_list,
-             reduce_non_diagonal_ratio=self.reduce_non_diagonal_ratio,
-             weight_transpose=self.config.get("weight_transpose", True),
-         )
+         merged_model = modelpool.load_model("_pretrained_")
+         merged_model.load_state_dict(merged_params, strict=False)

-         merged_model = modelpool.load_model("_pretrained_")
-         merged_model.load_state_dict(merged_params, strict=False)
+         self.print_profile_summary()
          return merged_model

      def on_regmean_start(self):
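The RegMean change above mixes in SimpleProfilerMixin and wraps the expensive steps in profiling scopes. A hedged sketch of that pattern in isolation (the class TimedMerge and its body are made up for illustration; profile() and print_profile_summary() are the mixin calls used in the diff):

from fusion_bench.mixins import SimpleProfilerMixin


class TimedMerge(SimpleProfilerMixin):
    # illustrative only: nested profile() scopes accumulate wall-clock time per label
    def run(self):
        with self.profile("merging models"):
            with self.profile("computing statistics"):
                pass  # e.g. collect per-model regression statistics, as RegMean does
            pass  # combine parameters using the collected statistics
        # report the accumulated timings once at the end
        self.print_profile_summary()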
fusion_bench/method/smile_upscaling/__init__.py
@@ -1,3 +1,3 @@
  # flake8: noqa F401
  from .singular_projection_merging import SingularProjectionMergingAlgorithm
- from .smile_upscaling import SmileUpscalingAlgorithm
+ from .smile_upscaling import SmileMoELinear, SmileUpscalingAlgorithm
fusion_bench/method/smile_upscaling/smile_upscaling.py
@@ -442,16 +442,19 @@ class SmileUpscalingAlgorithm(
              print_parameters(model)
              return model

-         with self.profile("load pretrained model"):
-             pretrained_model = modelpool.load_model("_pretrained_")
-         with self.profile("load fine-tuned model"):
-             finetuned_models = [
-                 m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
-             ]
-
-         if self.config.device == "cuda" and torch.cuda.is_available():
-             pretrained_model = pretrained_model.cuda()
-             finetuned_models = [m.cuda() for m in finetuned_models]
+         with self.profile("loading model"):
+             # load models and move to GPU if available
+             with self.profile("load pretrained model"):
+                 pretrained_model = modelpool.load_model("_pretrained_")
+             with self.profile("load fine-tuned model"):
+                 finetuned_models = [
+                     m
+                     for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
+                 ]
+
+             if self.config.device == "cuda" and torch.cuda.is_available():
+                 pretrained_model = pretrained_model.cuda()
+                 finetuned_models = [m.cuda() for m in finetuned_models]

          with self.profile("merge model"):
              model = self.merge(pretrained_model, finetuned_models)
fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py
@@ -85,7 +85,14 @@ class CLIPLayerWiseAdaMergingSurgeryAlgorithm(

          if self.config.weights is not None:
              # skip the test-time adaptation
+             merge_weight: torch.Tensor = torch.load(self.config.weights)
+             module.merge_weight.data = merge_weight.to(
+                 device=module.merge_weight.device
+             )
              merged_model = copy.deepcopy(module.merge_and_unload())
+             # setup the zero-shot classification head
+             self.on_test_time_adaptation_start()
+
          else:
              with self.profile("test-time adaptation"):
                  module = self.test_time_adaptation(module)
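With the surgery change above, setting the method's weights option skips test-time adaptation: the file is read with torch.load and copied into module.merge_weight before merging. A hedged sketch of producing such a file (the tensor shape and path are illustrative assumptions; the diff only implies a tensor that torch.load returns and that matches merge_weight):

import torch

# assumed layout: one merging weight per (task, layer) pair; adjust to the real
# merge_weight shape of the trained module
merge_weight = torch.full((8, 12), 0.3)
torch.save(merge_weight, "outputs/merge_weight.pt")
# pointing the method's `weights` option at this file reuses the saved merging
# weights instead of running test-time adaptation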
fusion_bench/method/task_arithmetic/task_arithmetic.py
@@ -6,7 +6,7 @@ http://arxiv.org/abs/2212.04089

  import logging
  from copy import deepcopy
- from typing import Dict, List, Mapping, TypeVar, Union  # noqa: F401
+ from typing import Dict, List, Mapping, Optional, TypeVar, Union  # noqa: F401

  import torch
  from torch import nn
@@ -19,18 +19,18 @@ from fusion_bench.utils.state_dict_arithmetic import (
      state_dict_mul,
      state_dict_sub,
  )
- from fusion_bench.utils.type import StateDictType
+ from fusion_bench.utils.type import StateDictType, TorchModelType

  log = logging.getLogger(__name__)


  @torch.no_grad()
  def task_arithmetic_merge(
-     pretrained_model: nn.Module,
-     finetuned_models: List[nn.Module],
+     pretrained_model: TorchModelType,
+     finetuned_models: List[TorchModelType],
      scaling_factor: float,
      inplace: bool = True,
- ) -> nn.Module:
+ ) -> TorchModelType:
      """
      Merges the task vectors from multiple fine-tuned models into a single pre-trained model.

@@ -46,15 +46,17 @@
      """
      if not inplace:
          pretrained_model = deepcopy(pretrained_model)
-     task_vector: StateDictType = None
+     task_vector: Optional[StateDictType] = None
      # Calculate the total task vector
      for model in finetuned_models:
          if task_vector is None:
+             # calculate the task vector for the first model
              task_vector = state_dict_sub(
                  model.state_dict(keep_vars=True),
                  pretrained_model.state_dict(keep_vars=True),
              )
          else:
+             # calculate the task vector for the remaining models
              task_vector = state_dict_add(
                  task_vector,
                  state_dict_sub(