fusion-bench 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. fusion_bench/compat/method/__init__.py +2 -0
  2. fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
  3. fusion_bench/constants/clip_vision.py +22 -0
  4. fusion_bench/dataset/clip_dataset.py +10 -2
  5. fusion_bench/dataset/fer2013.py +1 -0
  6. fusion_bench/dataset/gsm8k.py +2 -2
  7. fusion_bench/method/__init__.py +10 -0
  8. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
  9. fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
  10. fusion_bench/method/gossip/__init__.py +3 -0
  11. fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
  12. fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
  13. fusion_bench/method/gossip/entropy_loss.py +25 -0
  14. fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
  15. fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
  16. fusion_bench/method/gossip/min_norm_solvers.py +227 -0
  17. fusion_bench/method/gossip/task_wise_gossip.py +265 -0
  18. fusion_bench/method/gossip/utils.py +74 -0
  19. fusion_bench/method/isotropic_merging/__init__.py +1 -1
  20. fusion_bench/method/opcm/opcm.py +16 -7
  21. fusion_bench/method/pwe_moe/module.py +1 -1
  22. fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
  23. fusion_bench/method/regmean/regmean.py +25 -17
  24. fusion_bench/method/smile_upscaling/__init__.py +1 -1
  25. fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
  26. fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
  27. fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
  28. fusion_bench/method/ties_merging/ties_merging.py +36 -31
  29. fusion_bench/method/we_moe/we_moe.py +14 -15
  30. fusion_bench/mixins/__init__.py +6 -3
  31. fusion_bench/mixins/hydra_config.py +49 -0
  32. fusion_bench/mixins/openclip_classification.py +11 -0
  33. fusion_bench/mixins/simple_profiler.py +4 -2
  34. fusion_bench/modelpool/__init__.py +3 -1
  35. fusion_bench/modelpool/base_pool.py +2 -2
  36. fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
  37. fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
  38. fusion_bench/models/open_clip/__init__.py +6 -0
  39. fusion_bench/models/open_clip/modeling.py +176 -0
  40. fusion_bench/models/open_clip/utils.py +311 -0
  41. fusion_bench/models/open_clip/variables_and_paths.py +56 -0
  42. fusion_bench/models/parameter_dict.py +54 -13
  43. fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
  44. fusion_bench/taskpool/__init__.py +5 -3
  45. fusion_bench/taskpool/clip_vision/__init__.py +1 -0
  46. fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
  47. fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
  48. fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
  49. fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
  50. fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
  51. fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
  52. fusion_bench/taskpool/gpt2_text_classification.py +30 -1
  53. fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
  54. fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
  55. fusion_bench/utils/data.py +12 -0
  56. fusion_bench/utils/devices.py +14 -0
  57. fusion_bench/utils/instantiate.py +12 -0
  58. fusion_bench/utils/misc.py +9 -2
  59. fusion_bench/utils/packages.py +14 -0
  60. fusion_bench/utils/parameters.py +1 -1
  61. fusion_bench/utils/tensorboard.py +1 -1
  62. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +1 -1
  63. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +190 -151
  64. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
  65. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
  66. fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
  67. fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
  68. fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
  69. fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
  70. fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
  71. fusion_bench_config/fabric/auto.yaml +0 -1
  72. fusion_bench_config/fabric/llama_ddp.yaml +0 -1
  73. fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
  74. fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
  75. fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
  76. fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
  77. fusion_bench_config/fabric_model_fusion.yaml +0 -1
  78. fusion_bench_config/llama_full_finetune.yaml +0 -2
  79. fusion_bench_config/llama_model_fusion.yaml +0 -2
  80. fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
  81. fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
  82. fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
  83. fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
  84. fusion_bench_config/method/adamerging.yaml +2 -2
  85. fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
  86. fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
  87. fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
  88. fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
  89. fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
  90. fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
  91. fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
  92. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
  93. fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
  94. fusion_bench_config/method/dare/simple_average.yaml +0 -1
  95. fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
  96. fusion_bench_config/method/dare/ties_merging.yaml +0 -2
  97. fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
  98. fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
  99. fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
  100. fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
  101. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
  102. fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
  103. fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
  104. fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
  105. fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
  106. fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
  107. fusion_bench_config/method/linear/llama_expo.yaml +0 -3
  108. fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
  109. fusion_bench_config/method/linear/weighted_average.yaml +0 -1
  110. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
  111. fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
  112. fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
  113. fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
  114. fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
  115. fusion_bench_config/method/model_recombination.yaml +0 -1
  116. fusion_bench_config/method/opcm/opcm.yaml +0 -1
  117. fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
  118. fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
  119. fusion_bench_config/method/opcm/weight_average.yaml +0 -1
  120. fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
  121. fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
  122. fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
  123. fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
  124. fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
  125. fusion_bench_config/method/slerp/slerp.yaml +0 -2
  126. fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
  127. fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
  128. fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
  129. fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
  130. fusion_bench_config/method/task_arithmetic.yaml +1 -1
  131. fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
  132. fusion_bench_config/method/ties_merging.yaml +1 -1
  133. fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
  134. fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
  135. fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
  136. fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
  137. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
  138. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
  139. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
  140. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
  141. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
  142. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
  143. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
  144. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
  145. fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
  154. fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
  155. fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
  156. fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
  157. fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
  158. fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
  159. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
  160. fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
  161. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
  162. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
  163. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
  164. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
  165. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
  166. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
  167. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
  168. fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
  169. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
  170. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
  171. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
  172. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
  173. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
  174. fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
  175. fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
  176. fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
  177. fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
  178. fusion_bench_config/nyuv2_config.yaml +0 -2
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
  180. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
  181. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
  182. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
  183. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
  184. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
  185. fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
  186. fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
  187. fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
  188. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
  189. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/licenses/LICENSE +0 -0
  190. {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench/method/ties_merging/ties_merging.py

@@ -16,6 +16,7 @@ from torch import Tensor, nn
 
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
+from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.type import StateDictType
 
@@ -24,7 +25,7 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_st
 log = logging.getLogger(__name__)
 
 
-class TiesMergingAlgorithm(BaseAlgorithm):
+class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     """
     TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
 
@@ -84,34 +85,38 @@ class TiesMergingAlgorithm(BaseAlgorithm):
         scaling_factor = self.scaling_factor
         threshold = self.threshold
 
-        # Load the pretrained model
-        pretrained_model = modelpool.load_model("_pretrained_")
-
-        # Load the state dicts of the models
-        ft_checks: List[StateDictType] = [
-            modelpool.load_model(model_name).state_dict(keep_vars=True)
-            for model_name in modelpool.model_names
-        ]
-        ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
-
-        # Compute the task vectors
-        flat_ft: Tensor = torch.vstack(
-            [state_dict_to_vector(check, remove_keys) for check in ft_checks]
-        )
-        flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
-        tv_flat_checks = flat_ft - flat_ptm
-
-        # Perform TIES Merging
-        merged_tv = ties_merging(
-            tv_flat_checks,
-            reset_thresh=threshold,
-            merge_func=merge_func,
-        )
-        merged_check = flat_ptm + scaling_factor * merged_tv
-        merged_state_dict = vector_to_state_dict(
-            merged_check, ptm_check, remove_keys=remove_keys
-        )
-
-        # Load the merged state dict into the pretrained model
-        pretrained_model.load_state_dict(merged_state_dict)
+        with self.profile("loading models"):
+            # Load the pretrained model
+            pretrained_model = modelpool.load_model("_pretrained_")
+
+            # Load the state dicts of the models
+            ft_checks: List[StateDictType] = [
+                modelpool.load_model(model_name).state_dict(keep_vars=True)
+                for model_name in modelpool.model_names
+            ]
+            ptm_check: StateDictType = pretrained_model.state_dict(keep_vars=True)
+
+        with self.profile("merging models"):
+            # Compute the task vectors
+            flat_ft: Tensor = torch.vstack(
+                [state_dict_to_vector(check, remove_keys) for check in ft_checks]
+            )
+            flat_ptm: Tensor = state_dict_to_vector(ptm_check, remove_keys)
+            tv_flat_checks = flat_ft - flat_ptm
+
+            # Perform TIES Merging
+            merged_tv = ties_merging(
+                tv_flat_checks,
+                reset_thresh=threshold,
+                merge_func=merge_func,
+            )
+            merged_check = flat_ptm + scaling_factor * merged_tv
+            merged_state_dict = vector_to_state_dict(
+                merged_check, ptm_check, remove_keys=remove_keys
+            )
+
+            # Load the merged state dict into the pretrained model
+            pretrained_model.load_state_dict(merged_state_dict)
+
+        self.print_profile_summary()
         return pretrained_model
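The only behavioral change in this file is the added profiling instrumentation. A minimal sketch of the pattern SimpleProfilerMixin enables (the class below is illustrative and not part of the package; only `profile`, `print_profile_summary`, and the model pool API are taken from this diff):

from fusion_bench.mixins import SimpleProfilerMixin


class MyMergeAlgorithm(SimpleProfilerMixin):
    """Illustrative only: times each phase of a merge the same way TiesMergingAlgorithm now does."""

    def run(self, modelpool):
        with self.profile("loading models"):
            models = [modelpool.load_model(name) for name in modelpool.model_names]
        with self.profile("merging models"):
            merged = models[0]  # placeholder for an actual merging step
        # Print the timing table collected by the underlying profiler.
        self.print_profile_summary()
        return merged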
fusion_bench/method/we_moe/we_moe.py

@@ -5,7 +5,6 @@ from typing import cast # noqa: F401
 import lightning as L
 import lightning.fabric.wrappers
 import torch
-from lightning.pytorch.profilers import SimpleProfiler
 from omegaconf import DictConfig
 from torch import Tensor
 from torch.utils.data import DataLoader
@@ -13,6 +12,7 @@ from tqdm.autonotebook import tqdm
 
 from fusion_bench.compat.method.base_algorithm import ModelFusionAlgorithm
 from fusion_bench.compat.modelpool import ModelPool
+from fusion_bench.mixins import SimpleProfilerMixin
 from fusion_bench.models.we_moe import WeightEnsemblingMoE
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.parameters import print_parameters
@@ -34,7 +34,10 @@ def entropy_loss(logits: Tensor) -> Tensor:
     return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
 
 
-class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
+class WeightEnsemblingMoEAlgorithm(
+    ModelFusionAlgorithm,
+    SimpleProfilerMixin,
+):
     """
     Algorithm for fusing models using Weight Ensembling Mixture of Experts (MoE).
 
@@ -44,7 +47,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
     Attributes:
         _fabric (L.Fabric): The fabric for distributed training.
         modelpool (ModelPool): The pool of models to be fused.
-        profiler (SimpleProfiler): The profiler for measuring performance.
     """
 
     _fabric: L.Fabric = None
@@ -66,9 +68,6 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
             self._fabric.launch()
        else:
            assert "No CUDA device available."
-        self.profiler = SimpleProfiler(
-            self.config.get("cache_dir", "outputs"), "we_moe_profiler.txt"
-        )
 
    @abstractmethod
    def load_checkpoint(self, model, checkpoint):
@@ -177,9 +176,9 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
         for step_idx in pbar:
             if self.config.use_grad_accumulate:
                 for task in self.modelpool.model_names:
-                    with self.profiler.profile("data time"):
+                    with self.profile("data time"):
                         batch = next(self.get_shuffled_test_loader_iter(task))
-                    with self.profiler.profile("forward pass"):
+                    with self.profile("forward pass"):
                         logits = self.compute_logits(module, batch, task)
                         assert (
                             logits.dim() == 2
@@ -187,23 +186,23 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
                         loss = entropy_loss(logits)
                     # .backward() accumulates when .zero_grad() wasn't called
                     # this can save memory
-                    with self.profiler.profile("backward pass"):
+                    with self.profile("backward pass"):
                         self._fabric.backward(loss, retain_graph=True)
             else:
                 loss = 0
                 for task in self.modelpool.model_names:
-                    with self.profiler.profile("data time"):
+                    with self.profile("data time"):
                         batch = next(self.get_shuffled_test_loader_iter(task))
-                    with self.profiler.profile("forward pass"):
+                    with self.profile("forward pass"):
                         logits = self.compute_logits(module, batch, task)
                         assert (
                             logits.dim() == 2
                         ), f"Expected logits to be 2D, got {logits.dim()}"
                         loss = loss + entropy_loss(logits)
-                with self.profiler.profile("backward pass"):
+                with self.profile("backward pass"):
                     self._fabric.backward(loss, retain_graph=True)
 
-            with self.profiler.profile("optimizer step"):
+            with self.profile("optimizer step"):
                 optimizer.step()
                 optimizer.zero_grad()
 
@@ -232,7 +231,7 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
             )
             self.load_checkpoint(moe_model, self.config.checkpoint)
         else:
-            with self.profiler.profile("test-time adaptation"):
+            with self.profile("test-time adaptation"):
                 moe_model = self.test_time_adaptation(moe_model)
             if self.config.get("save_checkpoint", False):
                 log.info(f"save checkpoint to {self.config.save_checkpoint}")
@@ -243,5 +242,5 @@ class WeightEnsemblingMoEAlgorithm(ModelFusionAlgorithm):
 
         # enable sample-wise adaptation
         moe_model.batch_reduce = False
-        print(self.profiler.summary())
+        self.print_profile_summary()
         return moe_model
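The test-time adaptation loop above tunes the MoE routing weights by minimizing the entropy of the merged model's predictions on unlabeled test batches. A self-contained illustration of that objective, assuming `probs` in `entropy_loss` is the softmax of the logits (only the `return` line of the function is visible in this hunk):

import torch


def entropy_loss(logits: torch.Tensor) -> torch.Tensor:
    # Mean Shannon entropy of the predicted class distributions.
    probs = torch.softmax(logits, dim=-1)
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()


logits = torch.tensor([[8.0, 0.0, 0.0], [0.1, 0.0, -0.1]])
# The confident first row contributes near-zero entropy, the nearly uniform second
# row contributes roughly log(3); minimizing the loss therefore pushes the merged
# model toward confident predictions.
print(entropy_loss(logits))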
fusion_bench/mixins/__init__.py

@@ -6,20 +6,23 @@ from typing_extensions import TYPE_CHECKING
 from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
+    "clip_classification": ["CLIPClassificationMixin"],
+    "fabric_training": ["FabricTrainingMixin"],
+    "hydra_config": ["HydraConfigMixin"],
     "lightning_fabric": ["LightningFabricMixin"],
+    "openclip_classification": ["OpenCLIPClassificationMixin"],
     "serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
     "simple_profiler": ["SimpleProfilerMixin"],
-    "clip_classification": ["CLIPClassificationMixin"],
-    "fabric_training": ["FabricTrainingMixin"],
 }
 
 if TYPE_CHECKING:
     from .clip_classification import CLIPClassificationMixin
     from .fabric_training import FabricTrainingMixin
+    from .hydra_config import HydraConfigMixin
     from .lightning_fabric import LightningFabricMixin
+    from .openclip_classification import OpenCLIPClassificationMixin
     from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
     from .simple_profiler import SimpleProfilerMixin
-
 else:
     sys.modules[__name__] = LazyImporter(
         __name__,
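With the expanded lazy-import table above, the new mixins resolve at package level without eagerly importing their dependencies; for example (illustrative check, not shown in this diff):

# Both names are resolved on first access by LazyImporter.
from fusion_bench.mixins import HydraConfigMixin, OpenCLIPClassificationMixin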
fusion_bench/mixins/hydra_config.py

@@ -0,0 +1,49 @@
+import logging
+import os
+from copy import deepcopy
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import hydra.core.global_hydra
+from hydra import compose, initialize
+from omegaconf import DictConfig, OmegaConf
+
+from fusion_bench.utils import import_object, instantiate
+from fusion_bench.utils.instantiate import set_print_function_call
+
+log = logging.getLogger(__name__)
+
+
+class HydraConfigMixin:
+    """
+    A mixin for classes that need to be instantiated from a config file.
+    """
+
+    @classmethod
+    def from_config(
+        cls,
+        config_name: Union[str, Path],
+        overrides: Optional[List[str]] = None,
+    ):
+        if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
+            raise RuntimeError("Hydra is not initialized.")
+        else:
+            cfg = compose(config_name=config_name, overrides=overrides)
+
+        config_groups = config_name.split("/")[:-1]
+        for config_group in config_groups:
+            cfg = cfg[config_group]
+
+        if "_target_" in cfg:
+            # if the config has a _target_ key, check if it is equal to the class name
+            target_cls = import_object(cfg["_target_"])
+            if target_cls != cls:
+                log.warning(
+                    f"The _target_ key in the config is {cfg['_target_']}, but the class name is {cls.__name__}."
+                )
+            with set_print_function_call(False):
+                obj = instantiate(cfg)
+        else:
+            obj = cls(**cfg)
+
+        return obj
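A hedged sketch of how `from_config` might be called. It requires an already-initialized Hydra context, and the config directory and config name below are illustrative (the YAML file does appear in this release's fusion_bench_config tree, but this exact call is not shown in the diff). BaseModelPool picks up this method through the HydraConfigMixin change further down.

from hydra import initialize

from fusion_bench.modelpool import BaseModelPool

# from_config raises RuntimeError unless GlobalHydra has been initialized first.
with initialize(config_path="fusion_bench_config", version_base=None):
    # compose() builds the full config, the leading "modelpool/..." groups are
    # stripped, and the `_target_` class (or cls itself) is instantiated.
    pool = BaseModelPool.from_config(
        "modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual"
    )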
fusion_bench/mixins/openclip_classification.py

@@ -0,0 +1,11 @@
+import logging
+
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.models.open_clip import ImageClassifier, ImageEncoder
+
+log = logging.getLogger(__name__)
+
+
+class OpenCLIPClassificationMixin(LightningFabricMixin):
+    _train_processor = None
+    _test_processor = None
fusion_bench/mixins/simple_profiler.py

@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Generator
+from typing import Generator, Optional
 
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from lightning.pytorch.profilers import SimpleProfiler
@@ -70,7 +70,9 @@ class SimpleProfilerMixin:
         self.profiler.stop(action_name)
 
     @rank_zero_only
-    def print_profile_summary(self):
+    def print_profile_summary(self, title: Optional[str] = None):
+        if title is not None:
+            print(title)
         print(self.profiler.summary())
 
     def __del__(self):
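The new optional `title` argument simply labels the report; a minimal illustration (the class and title string are made up, and this assumes the mixin works standalone, as its uses elsewhere in this diff suggest):

from fusion_bench.mixins import SimpleProfilerMixin


class Demo(SimpleProfilerMixin):
    def run(self):
        with self.profile("step"):
            pass
        # Prints "Demo profile" on its own line, then the usual summary table.
        self.print_profile_summary(title="Demo profile")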
fusion_bench/modelpool/__init__.py

@@ -6,12 +6,13 @@ from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
     "base_pool": ["BaseModelPool"],
+    "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
     "clip_vision": ["CLIPVisionModelPool"],
     "nyuv2_modelpool": ["NYUv2ModelPool"],
     "huggingface_automodel": ["AutoModelPool"],
-    "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
     "seq2seq_lm": ["Seq2SeqLMPool"],
     "PeftModelForSeq2SeqLM": ["PeftModelForSeq2SeqLMPool"],
+    "openclip_vision": ["OpenCLIPVisionModelPool"],
     "huggingface_gpt2_classification": [
         "HuggingFaceGPT2ClassificationPool",
         "GPT2ForSequenceClassificationPool",
@@ -30,6 +31,7 @@ if TYPE_CHECKING:
         HuggingFaceGPT2ClassificationPool,
     )
     from .nyuv2_modelpool import NYUv2ModelPool
+    from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
     from .seq2seq_lm import Seq2SeqLMPool
     from .seq_classification_lm import SeqenceClassificationModelPool
fusion_bench/modelpool/base_pool.py

@@ -7,7 +7,7 @@ from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import Dataset
 
-from fusion_bench.mixins import BaseYAMLSerializableModel
+from fusion_bench.mixins import BaseYAMLSerializableModel, HydraConfigMixin
 from fusion_bench.utils import instantiate, timeit_context
 
 __all__ = ["BaseModelPool"]
@@ -15,7 +15,7 @@ __all__ = ["BaseModelPool"]
 log = logging.getLogger(__name__)
 
 
-class BaseModelPool(BaseYAMLSerializableModel):
+class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
     """
     A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.
 
fusion_bench/modelpool/openclip_vision/__init__.py

@@ -0,0 +1 @@
+from .modelpool import OpenCLIPVisionModelPool
fusion_bench/modelpool/openclip_vision/modelpool.py

@@ -0,0 +1,255 @@
+import logging
+import pickle
+import sys
+from typing import Callable, Optional, Union, cast
+
+import torch
+from datasets import load_dataset
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.expr import is_expr_match
+from fusion_bench.utils.packages import _get_package_version, compare_versions
+
+log = logging.getLogger(__name__)
+
+# Add flag to track if warning has been shown
+_openclip_version_warning_shown = False
+
+
+def _check_and_redirect_open_clip_modeling():
+    global _openclip_version_warning_shown
+    if compare_versions(_get_package_version("open-clip-torch").__str__(), "2.0.2") > 0:
+        if not _openclip_version_warning_shown:
+            log.warning(
+                "OpenCLIP version is greater than 2.0.2. This may cause issues with the modelpool."
+            )
+            _openclip_version_warning_shown = True
+        import open_clip.model
+        import open_clip.transformer
+
+        if not hasattr(open_clip.model, "VisualTransformer"):
+            open_clip.model.VisualTransformer = open_clip.model.VisionTransformer
+        if not hasattr(open_clip.model, "Transformer"):
+            open_clip.model.Transformer = open_clip.transformer.Transformer
+        if not hasattr(open_clip.model, "ResidualAttentionBlock"):
+            open_clip.model.ResidualAttentionBlock = (
+                open_clip.transformer.ResidualAttentionBlock
+            )
+
+    try:
+        import src
+        import src.modeling
+    except ImportError:
+        if "src" not in sys.modules:
+            # redirect the import of `src` to `fusion_bench.models.open_clip`
+            import fusion_bench.models.open_clip as open_clip
+
+            sys.modules["src"] = open_clip
+            log.warning(
+                "`src` is not imported."
+                "Redirecting the import to `fusion_bench.models.open_clip`"
+            )
+        if "src.modeling" not in sys.modules:
+            # redirect the import of `src.modeling` to `fusion_bench.models.open_clip.modeling`
+            import fusion_bench.models.open_clip.modeling as open_clip_modeling
+
+            sys.modules["src.modeling"] = open_clip_modeling
+            log.warning(
+                "`src.modeling` is not imported."
+                "Redirecting the import to `fusion_bench.models.open_clip.modeling`"
+            )
+
+
+def load_classifier_head(model_config: Union[str, DictConfig], *args, **kwargs):
+    if isinstance(model_config, str):
+        _check_and_redirect_open_clip_modeling()
+        log.info(f"Loading `ClassificationHead` from {model_config}")
+        weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+        head = torch.load(model_config, weights_only=weights_only, *args, **kwargs)
+    elif isinstance(model_config, nn.Module):
+        log.info(f"Returning existing model: {model_config}")
+        head = model_config
+    else:
+        head = instantiate(model_config, *args, **kwargs)
+    head = cast(ClassificationHead, head)
+    return head
+
+
+class OpenCLIPVisionModelPool(BaseModelPool):
+    """
+    A model pool for managing OpenCLIP Vision models (models from task vector paper).
+    """
+
+    _train_processor = None
+    _test_processor = None
+
+    def __init__(
+        self,
+        models: DictConfig,
+        classification_heads: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        super().__init__(models, **kwargs)
+        self._classification_heads = classification_heads
+
+    @property
+    def train_processor(self):
+        if self._train_processor is None:
+            encoder: ImageEncoder = self.load_pretrained_or_first_model()
+            self._train_processor = encoder.train_preprocess
+            if self._test_processor is None:
+                self._test_processor = encoder.val_preprocess
+        return self._train_processor
+
+    @property
+    def test_processor(self):
+        if self._test_processor is None:
+            encoder: ImageEncoder = self.load_pretrained_or_first_model()
+            if self._train_processor is None:
+                self._train_processor = encoder.train_preprocess
+            self._test_processor = encoder.val_preprocess
+        return self._test_processor
+
+    def load_model(
+        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+    ) -> ImageEncoder:
+        R"""
+        The model config can be:
+
+        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+        - {"model_name": str, "pickle_path": str}, load the model from the binary file (pickle format). This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from model located in the pickle file.
+        - {"model_name": str, "state_dict_path": str}, load the model from the state dict file. This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from the file.
+        - Default, load the model using `instantiate` from hydra.
+        """
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._models
+        ):
+            model_config = self._models[model_name_or_config]
+        else:
+            model_config = model_name_or_config
+        if isinstance(model_config, DictConfig):
+            model_config = OmegaConf.to_container(model_config, resolve=True)
+
+        if isinstance(model_config, str):
+            # the model config is a string, which is the path to the model checkpoint in pickle format
+            # load the model using `torch.load`
+            # this is the original usage in the task arithmetic codebase
+            _check_and_redirect_open_clip_modeling()
+            log.info(f"loading ImageEncoder from {model_config}")
+            weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+            try:
+                encoder = torch.load(
+                    model_config, weights_only=weights_only, *args, **kwargs
+                )
+            except RuntimeError as e:
+                encoder = pickle.load(open(model_config, "rb"))
+        elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
+            # the model config is a dictionary with the following keys:
+            # - model_name: str, the name of the model
+            # - pickle_path: str, the path to the binary file (pickle format)
+            # load the model from the binary file (pickle format)
+            # this is useful when you use a newer version of torchvision
+            _check_and_redirect_open_clip_modeling()
+            log.info(
+                f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
+            )
+            weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+            try:
+                encoder = torch.load(
+                    model_config["pickle_path"],
+                    weights_only=weights_only,
+                    *args,
+                    **kwargs,
+                )
+            except RuntimeError as e:
+                encoder = pickle.load(open(model_config["pickle_path"], "rb"))
+            _encoder = ImageEncoder(model_config["model_name"])
+            _encoder.load_state_dict(encoder.state_dict())
+            encoder = _encoder
+        elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
+            # the model config is a dictionary with the following keys:
+            # - model_name: str, the name of the model
+            # - state_dict_path: str, the path to the state dict file
+            # load the model from the state dict file
+            log.info(
+                f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
+            )
+            encoder = ImageEncoder(model_config["model_name"])
+            encoder.load_state_dict(
+                torch.load(
+                    model_config["state_dict_path"], weights_only=True, *args, **kwargs
+                )
+            )
+        elif isinstance(model_config, nn.Module):
+            # the model config is an existing model
+            log.info(f"Returning existing model: {model_config}")
+            encoder = model_config
+        else:
+            encoder = super().load_model(model_name_or_config, *args, **kwargs)
+        encoder = cast(ImageEncoder, encoder)
+
+        # setup the train and test processors
+        if self._train_processor is None and hasattr(encoder, "train_preprocess"):
+            self._train_processor = encoder.train_preprocess
+        if self._test_processor is None and hasattr(encoder, "val_preprocess"):
+            self._test_processor = encoder.val_preprocess
+
+        return encoder
+
+    def load_classification_head(
+        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+    ) -> ClassificationHead:
+        R"""
+        The model config can be:
+
+        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+        - Default, load the model using `instantiate` from hydra.
+        """
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._classification_heads
+        ):
+            model_config = self._classification_heads[model_name_or_config]
+        else:
+            model_config = model_name_or_config
+
+        head = load_classifier_head(model_config, *args, **kwargs)
+        return head
+
+    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._train_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="train")
+        else:
+            dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._val_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="validation")
+        else:
+            dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._test_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="test")
+        else:
+            dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
+        return dataset
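To make the three accepted model-config shapes of `load_model` concrete, a hedged sketch (the checkpoint paths and task names below are placeholders, not taken from this diff; "ViT-B-32" matches the configs listed above):

from omegaconf import OmegaConf

from fusion_bench.modelpool import OpenCLIPVisionModelPool

# Placeholder checkpoint locations, for illustration only.
models = OmegaConf.create(
    {
        # 1. plain string: a pickled ImageEncoder, loaded via torch.load (pickle fallback on failure)
        "_pretrained_": "checkpoints/ViT-B-32/zeroshot.pt",
        # 2. model_name + pickle_path: build ImageEncoder(model_name), then copy the pickled state dict
        "svhn": {
            "model_name": "ViT-B-32",
            "pickle_path": "checkpoints/ViT-B-32/svhn/finetuned.pt",
        },
        # 3. model_name + state_dict_path: build ImageEncoder(model_name), then load a plain state dict
        "dtd": {
            "model_name": "ViT-B-32",
            "state_dict_path": "checkpoints/ViT-B-32/dtd/state_dict.pt",
        },
    }
)
pool = OpenCLIPVisionModelPool(models=models)
encoder = pool.load_model("svhn")  # dispatched on the config shape, as documented above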
fusion_bench/models/open_clip/__init__.py

@@ -0,0 +1,6 @@
+"""
+This module contains the support for the open_clip model.
+Modified from https://github.com/nik-dim/tall_masks/
+"""
+
+from .modeling import ClassificationHead, ImageClassifier, ImageEncoder