fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (199)
  1. fusion_bench/compat/method/__init__.py +3 -1
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/gsm8k.py +2 -2
  6. fusion_bench/method/__init__.py +12 -2
  7. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/doge_ta/__init__.py +2 -0
  10. fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
  11. fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
  12. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  13. fusion_bench/method/gossip/__init__.py +3 -0
  14. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  15. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  16. fusion_bench/method/gossip/entropy_loss.py +25 -0
  17. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  18. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  19. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  20. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  21. fusion_bench/method/gossip/utils.py +74 -0
  22. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  23. fusion_bench/method/opcm/opcm.py +102 -84
  24. fusion_bench/method/opcm/task_arithmetic.py +35 -21
  25. fusion_bench/method/opcm/ties_merging.py +71 -52
  26. fusion_bench/method/pwe_moe/module.py +1 -1
  27. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  28. fusion_bench/method/regmean/regmean.py +25 -17
  29. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  30. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  31. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  32. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  33. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  34. fusion_bench/method/we_moe/we_moe.py +14 -15
  35. fusion_bench/mixins/__init__.py +6 -3
  36. fusion_bench/mixins/hydra_config.py +49 -0
  37. fusion_bench/mixins/openclip_classification.py +11 -0
  38. fusion_bench/mixins/simple_profiler.py +4 -2
  39. fusion_bench/modelpool/__init__.py +3 -1
  40. fusion_bench/modelpool/base_pool.py +2 -2
  41. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  42. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  43. fusion_bench/models/open_clip/__init__.py +6 -0
  44. fusion_bench/models/open_clip/modeling.py +176 -0
  45. fusion_bench/models/open_clip/utils.py +311 -0
  46. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  47. fusion_bench/models/parameter_dict.py +54 -13
  48. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
  49. fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
  50. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  51. fusion_bench/taskpool/__init__.py +5 -3
  52. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  53. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  54. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  55. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  56. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  57. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  58. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  59. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  60. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  61. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  62. fusion_bench/utils/data.py +12 -0
  63. fusion_bench/utils/devices.py +14 -0
  64. fusion_bench/utils/instantiate.py +12 -0
  65. fusion_bench/utils/misc.py +9 -2
  66. fusion_bench/utils/packages.py +14 -0
  67. fusion_bench/utils/parameters.py +1 -1
  68. fusion_bench/utils/tensorboard.py +1 -1
  69. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
  70. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
  71. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  72. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  73. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  74. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  75. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  76. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  77. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  78. fusion_bench_config/fabric/auto.yaml +0 -1
  79. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  80. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  81. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  82. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  83. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  84. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  85. fusion_bench_config/llama_full_finetune.yaml +0 -2
  86. fusion_bench_config/llama_model_fusion.yaml +0 -2
  87. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  88. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  89. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  90. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  91. fusion_bench_config/method/adamerging.yaml +2 -2
  92. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  93. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  94. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  95. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  96. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  97. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  98. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  99. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  100. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  101. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  102. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  103. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  104. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  105. fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
  106. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  107. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  108. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  109. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  110. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  111. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  112. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  113. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  114. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  115. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  116. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  117. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  118. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  119. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  120. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  121. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  122. fusion_bench_config/method/model_recombination.yaml +0 -1
  123. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  124. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  125. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  126. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  127. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  128. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  129. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  130. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  131. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  132. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  133. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  134. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  135. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  136. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  137. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  138. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  139. fusion_bench_config/method/ties_merging.yaml +1 -1
  140. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  141. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  142. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  146. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  147. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  148. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  149. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  150. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  151. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  152. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  161. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  162. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  163. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  164. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  165. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  166. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  167. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  169. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  170. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  171. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  172. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  173. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  174. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  175. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  176. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  177. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  178. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
  179. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
  180. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  181. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  182. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  183. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  184. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  185. fusion_bench_config/nyuv2_config.yaml +0 -2
  186. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  187. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  188. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  189. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  190. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  191. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  192. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  193. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  194. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  195. fusion_bench/method/DOGE_TA/__init__.py +0 -2
  196. fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
  197. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  198. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
  199. {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0

fusion_bench/method/opcm/ties_merging.py
@@ -20,7 +20,7 @@ from fusion_bench.method.ties_merging.ties_merging_utils import (
     ties_merging,
     vector_to_state_dict,
 )
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
 from fusion_bench.taskpool import CLIPVisionModelTaskPool
 from fusion_bench.utils.json import load_from_json, save_to_json
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_sub
@@ -29,7 +29,11 @@ if TYPE_CHECKING:
     from torch.utils.tensorboard import SummaryWriter
 
 
-class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
+class ContinualTiesMergingForCLIP(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
     def __init__(
         self,
         scaling_factor: float,
@@ -84,68 +88,83 @@ class ContinualTiesMergingForCLIP(BaseAlgorithm, LightningFabricMixin):
         )
 
         # get the average model
-        pretrained_model = modelpool.load_pretrained_model()
+        with self.profile("loading model"):
+            pretrained_model = modelpool.load_pretrained_model()
         merged_model = deepcopy(pretrained_model)
 
         for model_idx, model_name in tqdm(
             enumerate(model_names), desc="Processing models"
         ):
-            task_model = modelpool.load_model(model_name)
+            with self.profile("loading model"):
+                task_model = modelpool.load_model(model_name)
 
-            task_vector = state_dict_sub(
-                task_model.state_dict(),
-                pretrained_model.state_dict(),
-            )
-            if model_idx == 0:
-                # if is the first model, the merged task vector is equal to the task vector
-                ties_merging_state_dict = task_vector
-            else:
-                # if is not the first model, we need to merge the task vector with the previous merged task vector
-                merged_tv = state_dict_sub(
-                    merged_model.state_dict(),
+            with self.profile("merging model"):
+                task_vector = state_dict_sub(
+                    task_model.state_dict(),
                     pretrained_model.state_dict(),
                 )
-                tv_flat_checks = torch.vstack(
-                    [
-                        state_dict_to_vector(merged_tv, remove_keys=self.remove_keys),
-                        state_dict_to_vector(task_vector, remove_keys=self.remove_keys),
-                    ]
-                )
-                # perform the TIES merging
-                ties_merging_tv = ties_merging(
-                    tv_flat_checks,
-                    reset_thresh=self.threshold,
-                    merge_func=self.merge_func,
-                )
-                # convert the merged task vector back to a state dict
-                ties_merging_state_dict = vector_to_state_dict(
-                    ties_merging_tv,
-                    merged_model.state_dict(),
-                    remove_keys=self.remove_keys,
-                )
-
-            for param_name, param in task_model.named_parameters():
-                if not param.requires_grad:
-                    continue
-
-                merged_param = merged_model.get_parameter(param_name)
-                new_param = (
-                    merged_param
-                    + self.scaling_factor * ties_merging_state_dict[param_name]
-                )
-                merged_model.get_parameter(param_name).data = new_param
+                if model_idx == 0:
+                    # if is the first model, the merged task vector is equal to the task vector
+                    ties_merging_state_dict = task_vector
+                else:
+                    # if is not the first model, we need to merge the task vector with the previous merged task vector
+                    merged_tv = state_dict_sub(
+                        merged_model.state_dict(),
+                        pretrained_model.state_dict(),
+                    )
+                    tv_flat_checks = torch.vstack(
+                        [
+                            state_dict_to_vector(
+                                merged_tv, remove_keys=self.remove_keys
+                            ),
+                            state_dict_to_vector(
+                                task_vector, remove_keys=self.remove_keys
+                            ),
+                        ]
+                    )
+                    # perform the TIES merging
+                    ties_merging_tv = ties_merging(
+                        tv_flat_checks,
+                        reset_thresh=self.threshold,
+                        merge_func=self.merge_func,
+                    )
+                    # convert the merged task vector back to a state dict
+                    ties_merging_state_dict = vector_to_state_dict(
+                        ties_merging_tv,
+                        merged_model.state_dict(),
+                        remove_keys=self.remove_keys,
+                    )
+
+                for param_name, param in task_model.named_parameters():
+                    if not param.requires_grad:
+                        continue
+
+                    merged_param = merged_model.get_parameter(param_name)
+                    new_param = (
+                        merged_param
+                        + self.scaling_factor * ties_merging_state_dict[param_name]
+                    )
+                    merged_model.get_parameter(param_name).data = new_param
 
             if self.save_on_every_step:
-                self.save_merged_model(merged_model, model_idx)
+                with self.profile("saving model"):
+                    self.save_merged_model(merged_model, model_idx)
 
             if self.evaluate_on_every_step:
-                self.taskpool._is_setup = False
-                self.taskpool._test_datasets = DictConfig(
-                    {n: self._test_datasets[n] for n in model_names[: model_idx + 1]}
-                )
-                report = self.taskpool.evaluate(deepcopy(merged_model))
-                save_to_json(report, Path(self.log_dir) / f"report_{model_idx}.json")
-
+                with self.profile("evaluating model"):
+                    self.taskpool._is_setup = False
+                    self.taskpool._test_datasets = DictConfig(
+                        {
+                            n: self._test_datasets[n]
+                            for n in model_names[: model_idx + 1]
+                        }
+                    )
+                    report = self.taskpool.evaluate(deepcopy(merged_model))
+                    save_to_json(
+                        report, Path(self.log_dir) / f"report_{model_idx}.json"
+                    )
+
+        self.print_profile_summary()
         return merged_model
 
     def save_merged_model(self, merged_model: CLIPVisionModel, step: int):
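
The hunks above retrofit ContinualTiesMergingForCLIP with SimpleProfilerMixin: model loading, merging, saving, and evaluation are wrapped in named with self.profile(...) blocks, and a single summary is printed once the continual-merging loop finishes. A rough sketch of that pattern, using an illustrative stand-in that accumulates wall-clock time per label (the package's actual mixin lives in fusion_bench/mixins/simple_profiler.py and may differ in detail):

import time
from collections import defaultdict
from contextlib import contextmanager


class SimpleProfilerMixin:
    """Illustrative stand-in for fusion_bench.mixins.SimpleProfilerMixin."""

    @contextmanager
    def profile(self, name: str):
        # accumulate wall-clock time under a named action
        if not hasattr(self, "_profile_totals"):
            self._profile_totals = defaultdict(float)
        start = time.perf_counter()
        try:
            yield
        finally:
            self._profile_totals[name] += time.perf_counter() - start

    def print_profile_summary(self):
        # one line per action, mirroring the summary call at the end of run()
        for name, total in self._profile_totals.items():
            print(f"{name}: {total:.2f}s")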

fusion_bench/method/pwe_moe/module.py
@@ -1,5 +1,5 @@
 R"""
-this is adapted from 
+this is adapted from
 https://github.com/tanganke/weight-ensembling_MoE/blob/3cbd327cb28c499065f83387472a79829a2e5fee/src/module/dict_moe.py
 but with some modifications
 """

fusion_bench/method/pwe_moe/openclip_pwe_moe.py (new file)
@@ -0,0 +1,476 @@
+import logging
+from abc import abstractmethod
+from collections import defaultdict
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
+
+import lightning.fabric
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from omegaconf import DictConfig, OmegaConf
+from open_clip.model import ResidualAttentionBlock
+from torch import Tensor, nn
+from tqdm.auto import tqdm
+
+from fusion_bench import BaseAlgorithm
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.method.task_arithmetic import task_arithmetic_merge
+from fusion_bench.mixins import OpenCLIPClassificationMixin, SimpleProfilerMixin
+from fusion_bench.modelpool import OpenCLIPVisionModelPool
+from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
+from fusion_bench.utils import print_parameters, timeit_context
+from fusion_bench.utils.data import InfiniteDataLoader
+
+from .module import ParetoWeightEnsemblingModule
+from .phn.solvers import EPOSolver
+from .utils import generate_simplex_grid
+
+log = logging.getLogger(__name__)
+
+
+class PWEMoEAlgorithmForOpenCLIP(
+    BaseAlgorithm,
+    SimpleProfilerMixin,
+    OpenCLIPClassificationMixin,
+):
+    modelpool: OpenCLIPVisionModelPool
+
+    #! === Training & Validation Data ===
+    # setup the datasets and loaders by calling `load_datasets`
+    train_datasets: Dict[str, CLIPDataset]
+    train_loaders: Dict[str, torch.utils.data.DataLoader]
+    train_loader_iters: Dict[str, Iterator[Tuple[torch.Tensor, torch.Tensor]]]
+
+    test_datasets: Dict[str, CLIPDataset]
+    test_loaders: Dict[str, torch.utils.data.DataLoader]
+
+    def __init__(
+        self,
+        *,
+        #! === Model Architecture Arguments ===
+        partial: bool,
+        init_lambda: float,
+        router_hidden_layers: int,
+        checkpoint_path: str,
+        #! === Training Arguments ===
+        run_train: bool,
+        num_steps: int,
+        save_interval: int,
+        lr: float,
+        alpha: float,
+        dataloader_kwargs: DictConfig,
+        #! === Evaluation Arguments ===
+        run_eval: bool,
+        num_evaluation_samples: Union[str, int],
+        quick_evaluation: bool,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.partial = partial
+        self.init_lambda = init_lambda
+        self.router_hidden_layers = router_hidden_layers
+        self.lr = lr
+        self.num_steps = num_steps
+        self.save_interval = save_interval
+        self.alpha = alpha
+        self.checkpoint_path = checkpoint_path
+        self._dataloader_kwargs = dataloader_kwargs
+        self.run_train = run_train
+        self.run_eval = run_eval
+        self.num_evaluation_samples = num_evaluation_samples
+        self.quick_evaluation = quick_evaluation
+
+    def run(self, modelpool: OpenCLIPVisionModelPool):
+        self.modelpool = modelpool
+
+        # setup the MoE model
+        model = self.load_model()
+        if self.checkpoint_path is not None:
+            self.fabric.load(self.checkpoint_path, {"model": model})
+
+        # setup dataloaders
+        self.load_datasets()
+
+        if self.run_train:
+            model = self.train()
+        if self.run_eval:
+            self.evaluate(model)
+        return model
+
+    @torch.no_grad()
+    def load_model(self):
+        modelpool = self.modelpool
+
+        # load models and classification heads
+        pretrained_model: ImageEncoder = self.modelpool.load_pretrained_model()
+        log.info("pretrained model statistics:")
+        print_parameters(pretrained_model, print_fn=log.info)
+
+        finetuned_models: Dict[str, ImageEncoder] = {}
+        for model_name in self.modelpool.model_names:
+            finetuned_models[model_name] = modelpool.load_model(model_name)
+
+        classification_heads: Dict[str, ClassificationHead] = {}
+        for model_name in self.modelpool.model_names:
+            classification_heads[model_name] = modelpool.load_classification_head(
+                model_name
+            )
+        self.classification_heads = classification_heads
+
+        self.train_processor = modelpool.train_processor
+        self.test_processor = modelpool.test_processor
+
+        with timeit_context("Building the MoE model"):
+            model = deepcopy(pretrained_model)
+
+            if self.partial:
+                log.info("Weight ensembling only the MLPs")
+                # weight ensembling only the MLPs, merge the remaining layers using task arithmetic
+                model = task_arithmetic_merge(
+                    pretrained_model=model,
+                    finetuned_models=list(finetuned_models.values()),
+                    scaling_factor=self.init_lambda,
+                    inplace=True,
+                )
+
+                # fix all parameters
+                model.requires_grad_(False)
+
+                for layer_idx in tqdm(
+                    range(model.model.visual.transformer.layers), desc="Upscaling MLPs"
+                ):
+                    resblock: ResidualAttentionBlock = (
+                        model.model.visual.transformer.resblocks[layer_idx]
+                    )
+                    resblock.mlp = ParetoWeightEnsemblingModule(
+                        base_model=cast(
+                            ResidualAttentionBlock,
+                            pretrained_model.model.visual.transformer.resblocks[
+                                layer_idx
+                            ],
+                        ).mlp,
+                        expert_models=[
+                            cast(
+                                ResidualAttentionBlock,
+                                m.model.visual.transformer.resblocks[layer_idx],
+                            ).mlp
+                            for m in finetuned_models.values()
+                        ],
+                        init_lambda=self.init_lambda,
+                        fix_base_model_and_experts=True,
+                        router_hidden_layers=self.router_hidden_layers,
+                    )
+            else:
+                log.info("Weight ensembling all the layers")
+                # weight ensembling all the layers, merge the remaining layers using task arithmetic
+                model = task_arithmetic_merge(
+                    pretrained_model=model,
+                    finetuned_models=list(finetuned_models.values()),
+                    scaling_factor=self.init_lambda,
+                    inplace=True,
+                )
+                # fix all parameters
+                model.requires_grad_(False)
+
+                for name in [
+                    "conv1",
+                    "ln_pre",
+                    "ln_post",
+                    # "class_embedding",
+                    # "positional_embedding",
+                ]:
+                    setattr(
+                        model.model.visual,
+                        name,
+                        ParetoWeightEnsemblingModule(
+                            base_model=getattr(pretrained_model.model.visual, name),
+                            expert_models=[
+                                getattr(m.model.visual, name)
+                                for m in finetuned_models.values()
+                            ],
+                            init_lambda=self.init_lambda,
+                            fix_base_model_and_experts=True,
+                            router_hidden_layers=self.router_hidden_layers,
+                        ),
+                    )
+                for layer_idx in tqdm(
+                    range(model.model.visual.transformer.layers),
+                    desc="Upscaling the transformer layers",
+                ):
+                    for name in ["ln_1", "attn", "ln_attn", "ln_2", "mlp"]:
+                        setattr(
+                            model.model.visual.transformer.resblocks[layer_idx],
+                            name,
+                            ParetoWeightEnsemblingModule(
+                                base_model=getattr(
+                                    cast(
+                                        ResidualAttentionBlock,
+                                        pretrained_model.model.visual.transformer.resblocks[
+                                            layer_idx
+                                        ],
+                                    ),
+                                    name,
+                                ),
+                                expert_models=[
+                                    getattr(
+                                        cast(
+                                            ResidualAttentionBlock,
+                                            m.model.visual.transformer.resblocks[
+                                                layer_idx
+                                            ],
+                                        ),
+                                        name,
+                                    )
+                                    for m in finetuned_models.values()
+                                ],
+                                init_lambda=self.init_lambda,
+                                fix_base_model_and_experts=True,
+                                router_hidden_layers=self.router_hidden_layers,
+                            ),
+                        )
+                for name in ["token_embedding", "ln_final"]:
+                    setattr(
+                        model.model,
+                        name,
+                        ParetoWeightEnsemblingModule(
+                            base_model=getattr(pretrained_model.model, name),
+                            expert_models=[
+                                getattr(m.model, name)
+                                for m in finetuned_models.values()
+                            ],
+                            init_lambda=self.init_lambda,
+                            fix_base_model_and_experts=True,
+                            router_hidden_layers=self.router_hidden_layers,
+                        ),
+                    )
+
+        self.model = model
+        print_parameters(model, print_fn=log.info)
+        return model
+
+    def load_datasets(self):
+        modelpool = self.modelpool
+
+        # setup the train datasets and loaders
+        train_datasets = {}
+        train_loaders = {}
+        train_loader_iters = {}
+        for dataset_name in modelpool.train_dataset_names:
+            train_datasets[dataset_name] = modelpool.load_train_dataset(dataset_name)
+            train_datasets[dataset_name] = CLIPDataset(
+                train_datasets[dataset_name], self.train_processor
+            )
+            # sanity check
+            assert isinstance(train_datasets[dataset_name][0], tuple)
+
+            # setup the train loaders
+            train_loaders[dataset_name] = torch.utils.data.DataLoader(
+                train_datasets[dataset_name],
+                shuffle=True,
+                drop_last=True,
+                **self._dataloader_kwargs,
+            )
+            train_loaders[dataset_name] = self.fabric.setup_dataloaders(
+                train_loaders[dataset_name]
+            )
+            train_loaders[dataset_name] = InfiniteDataLoader(
+                train_loaders[dataset_name]
+            )
+
+            # setup the train loader iterators
+            train_loader_iters[dataset_name] = iter(train_loaders[dataset_name])
+
+        self.train_datasets = train_datasets
+        self.train_loaders = train_loaders
+        self.train_loader_iters = train_loader_iters
+
+        # setup the test datasets and loaders
+        test_datasets = {}
+        test_loaders = {}
+        for dataset_name in modelpool.test_dataset_names:
+            test_datasets[dataset_name] = modelpool.load_test_dataset(dataset_name)
+            test_datasets[dataset_name] = CLIPDataset(
+                test_datasets[dataset_name], self.test_processor
+            )
+            test_loaders[dataset_name] = torch.utils.data.DataLoader(
+                test_datasets[dataset_name],
+                shuffle=False,
+                **self._dataloader_kwargs,
+            )
+            test_loaders[dataset_name] = self.fabric.setup_dataloaders(
+                test_loaders[dataset_name]
+            )
+
+        self.test_datasets = test_datasets
+        self.test_loaders = test_loaders
+
+    def compute_loss(self, model: ImageEncoder, ray: Tensor):
+        losses = []
+        for dataset_idx, dataset_name in enumerate(self.train_datasets):
+            batch = next(self.train_loader_iters[dataset_name])
+            x, y = batch
+
+            features = model(x)
+            logits = self.classification_heads[dataset_name](features)
+
+            _loss = F.cross_entropy(logits, y)
+            losses.append(_loss)
+
+        loss = self.aggregate_loss(model, ray, losses)
+        return loss
+
+    @abstractmethod
+    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        pass
+
+    def train(self):
+        # setup the model
+        num_objectives = len(self.modelpool.model_names)
+        model = deepcopy(self.model)
+        self.classification_heads = {
+            t: h.to(self.fabric.device) for t, h in self.classification_heads.items()
+        }
+
+        # set up the optimizer and learning rate scheduler
+        optimizer = torch.optim.Adam(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=self.lr,
+        )
+        model, optimizer = self.fabric.setup(model, optimizer)
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer=optimizer, T_max=self.num_steps, eta_min=self.lr * 0.1
+        )
+
+        model.train()
+        device = self.fabric.device
+        for step_idx in tqdm(
+            range(1, 1 + self.num_steps), "training", dynamic_ncols=True
+        ):
+            # sample a preference ray
+            ray = torch.from_numpy(
+                np.random.dirichlet((self.alpha,) * num_objectives, 1)
+                .astype(np.float32)
+                .flatten()
+            ).to(device)
+            ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)
+
+            loss = self.compute_loss(model, ray)
+
+            optimizer.zero_grad()
+            self.fabric.backward(loss)
+            optimizer.step()
+
+            lr_scheduler.step()
+
+            self.fabric.log("loss", loss.item(), step=step_idx)
+
+            if step_idx % self.save_interval == 0 or step_idx == self.num_steps:
+                ckpt_dir = Path(self.log_dir) / "checkpoints"
+                ckpt_dir.mkdir(exist_ok=True, parents=True)
+                self.fabric.save(
+                    ckpt_dir / f"model_step={step_idx}.ckpt",
+                    {"model": model},
+                )
+        return model
+
+    def evaluate(self, model):
+        results = defaultdict(list)
+
+        num_objectives = len(self.modelpool.model_names)
+        device = self.fabric.device
+        self.classification_heads = {
+            t: h.to(self.fabric.device) for t, h in self.classification_heads.items()
+        }
+
+        if not lightning.fabric.is_wrapped(model):
+            model = self.fabric.setup_module(model)
+        model.eval()
+
+        if self.num_evaluation_samples == "equal_weight":
+            uniform_grid = np.array(
+                [[1 / num_objectives] * num_objectives], dtype=np.float32
+            )
+        else:
+            uniform_grid = generate_simplex_grid(
+                num_objectives, self.num_evaluation_samples
+            )
+        for ray_idx, ray in tqdm(enumerate(uniform_grid), "evaluating samples"):
+            results["ray_idx"].append(ray_idx)
+            # sample a preference ray
+            for i in range(len(ray)):
+                results[f"ray_{i}"].append(ray[i])
+            ray = torch.from_numpy(ray).to(device)
+            ParetoWeightEnsemblingModule.set_preferenece_vector(model, ray)
+
+            accs = []
+            for dataset_idx, dataset_name in enumerate(
+                tqdm(
+                    self.modelpool.test_dataset_names,
+                    "evaluating datasets",
+                    leave=False,
+                )
+            ):
+                test_loader = self.test_loaders[dataset_name]
+                TOTAL_CORRECT = 0
+                TOTAL_COUNT = 0
+                for batch_idx, batch in enumerate(
+                    pbar := tqdm(
+                        test_loader,
+                        f"evaluate {dataset_name}",
+                        leave=False,
+                    )
+                ):
+                    x, y = batch
+
+                    features = model(x)
+                    logits = self.classification_heads[dataset_name](features)
+                    preds = logits.argmax(-1)
+
+                    correct = (preds == y).sum().item()
+                    TOTAL_CORRECT += correct
+                    TOTAL_COUNT += len(y)
+                    acc = TOTAL_CORRECT / TOTAL_COUNT
+                    pbar.set_postfix_str(f"acc={acc:.2f}")
+
+                    if self.quick_evaluation and batch_idx > 20:
+                        break
+                results[dataset_name].append(acc)
+                accs.append(acc)
+
+            # compute the average accuracy
+            if "average" not in self.modelpool.test_dataset_names:
+                results["average"].append(np.mean(accs))
+
+        (df := pd.DataFrame(results)).to_csv(
+            Path(self.log_dir) / "result.csv", index=False
+        )
+        log.info(df)
+
+
+class PWEMoELinearScalarizationForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
+    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        loss = 0
+        for r, l in zip(ray, losses):
+            loss += r * l
+        return loss
+
+
+class PWEMoEExactParetoOptimalForOpenCLIP(PWEMoEAlgorithmForOpenCLIP):
+    epo_solver: Optional[EPOSolver] = None
+
+    def aggregate_loss(self, model: nn.Module, ray: Tensor, losses: Tuple[Tensor]):
+        if self.epo_solver is None:
+            num_objectives = len(self.modelpool.model_names)
+            self.epo_solver = EPOSolver(n_tasks=num_objectives, n_params=None)
+        epo_solver = self.epo_solver
+
+        losses = torch.stack(losses)
+        loss = epo_solver.get_weighted_loss(
+            losses,
+            ray,
+            tuple(filter(lambda p: p.requires_grad, model.parameters())),
+        )
+        return loss
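
The two subclasses at the end of the new file differ only in aggregate_loss: the linear-scalarization variant weights the per-task losses by the sampled preference ray, while the EPO variant delegates to EPOSolver. A rough, self-contained illustration of the linear path, mirroring the Dirichlet ray sampling in train() (the three toy loss values are invented for the example):

import numpy as np
import torch

num_objectives, alpha = 3, 1.0

# sample a preference ray from a symmetric Dirichlet, as train() does each step
ray = torch.from_numpy(
    np.random.dirichlet((alpha,) * num_objectives).astype(np.float32)
)

# toy per-task losses standing in for the CLIP classification losses
losses = [torch.tensor(0.7), torch.tensor(1.2), torch.tensor(0.4)]

# linear scalarization, as in PWEMoELinearScalarizationForOpenCLIP.aggregate_loss:
# loss = sum_i ray_i * loss_i
loss = sum(r * l for r, l in zip(ray, losses))
print(float(loss))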